Commit 2897041e authored by David S. Miller's avatar David S. Miller

Merge branch 'tls-fixes'

Jakub Kicinski says:

====================
tls: rx: strp: fix inline crypto offload

The local strparser version I added to TLS does not preserve
decryption status, which breaks inline crypto (NIC offload).
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 7e01c7f7 74836ec8
......@@ -1587,6 +1587,16 @@ static inline void skb_copy_hash(struct sk_buff *to, const struct sk_buff *from)
to->l4_hash = from->l4_hash;
};
static inline int skb_cmp_decrypted(const struct sk_buff *skb1,
const struct sk_buff *skb2)
{
#ifdef CONFIG_TLS_DEVICE
return skb2->decrypted - skb1->decrypted;
#else
return 0;
#endif
}
static inline void skb_copy_decrypted(struct sk_buff *to,
const struct sk_buff *from)
{
......
......@@ -126,6 +126,7 @@ struct tls_strparser {
u32 mark : 8;
u32 stopped : 1;
u32 copy_mode : 1;
u32 mixed_decrypted : 1;
u32 msg_ready : 1;
struct strp_msg stm;
......
......@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx)
return ctx->strp.msg_ready;
}
static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx)
{
return ctx->strp.mixed_decrypted;
}
#ifdef CONFIG_TLS_DEVICE
int tls_device_init(void);
void tls_device_cleanup(void);
......
......@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb = tls_strp_msg(sw_ctx);
struct strp_msg *rxm = strp_msg(skb);
int is_decrypted = skb->decrypted;
int is_encrypted = !is_decrypted;
struct sk_buff *skb_iter;
int left;
left = rxm->full_len - skb->len;
/* Check if all the data is decrypted already */
skb_iter = skb_shinfo(skb)->frag_list;
while (skb_iter && left > 0) {
is_decrypted &= skb_iter->decrypted;
is_encrypted &= !skb_iter->decrypted;
left -= skb_iter->len;
skb_iter = skb_iter->next;
int is_decrypted, is_encrypted;
if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
is_decrypted = skb->decrypted;
is_encrypted = !is_decrypted;
} else {
is_decrypted = 0;
is_encrypted = 0;
}
trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
......
......@@ -29,34 +29,50 @@ static void tls_strp_anchor_free(struct tls_strparser *strp)
struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
if (!strp->copy_mode)
shinfo->frag_list = NULL;
consume_skb(strp->anchor);
strp->anchor = NULL;
}
/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
static struct sk_buff *
tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
int offset, int len)
{
struct strp_msg *rxm;
struct sk_buff *skb;
int i, err, offset;
int i, err;
skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
&err, strp->sk->sk_allocation);
if (!skb)
return NULL;
offset = strp->stm.offset;
for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
skb_frag_address(frag),
skb_frag_size(frag)));
offset += skb_frag_size(frag);
}
skb_copy_header(skb, strp->anchor);
skb->len = len;
skb->data_len = len;
skb_copy_header(skb, in_skb);
return skb;
}
/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
struct strp_msg *rxm;
struct sk_buff *skb;
skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
strp->stm.full_len);
if (!skb)
return NULL;
rxm = strp_msg(skb);
rxm->offset = 0;
return skb;
......@@ -180,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
for (i = 0; i < shinfo->nr_frags; i++)
__skb_frag_unref(&shinfo->frags[i], false);
shinfo->nr_frags = 0;
if (strp->copy_mode) {
kfree_skb_list(shinfo->frag_list);
shinfo->frag_list = NULL;
}
strp->copy_mode = 0;
strp->mixed_decrypted = 0;
}
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
unsigned int offset, size_t in_len)
static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{
struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
struct sk_buff *skb;
skb_frag_t *frag;
size_t len, chunk;
skb_frag_t *frag;
int sz;
if (strp->msg_ready)
return 0;
skb = strp->anchor;
frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
len = in_len;
......@@ -208,19 +224,26 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
skb_frag_size(frag),
chunk));
sz = tls_rx_msg_size(strp, strp->anchor);
if (sz < 0) {
desc->error = sz;
return 0;
}
/* We may have over-read, sz == 0 is guaranteed under-read */
if (sz > 0)
chunk = min_t(size_t, chunk, sz - skb->len);
skb->len += chunk;
skb->data_len += chunk;
skb_frag_size_add(frag, chunk);
sz = tls_rx_msg_size(strp, skb);
if (sz < 0)
return sz;
/* We may have over-read, sz == 0 is guaranteed under-read */
if (unlikely(sz && sz < skb->len)) {
int over = skb->len - sz;
WARN_ON_ONCE(over > chunk);
skb->len -= over;
skb->data_len -= over;
skb_frag_size_add(frag, -over);
chunk -= over;
}
frag++;
len -= chunk;
offset += chunk;
......@@ -247,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
offset += chunk;
}
if (strp->stm.full_len == skb->len) {
read_done:
return in_len - len;
}
static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{
struct sk_buff *nskb, *first, *last;
struct skb_shared_info *shinfo;
size_t chunk;
int sz;
if (strp->stm.full_len)
chunk = strp->stm.full_len - skb->len;
else
chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
chunk = min(chunk, in_len);
nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
if (!nskb)
return -ENOMEM;
shinfo = skb_shinfo(skb);
if (!shinfo->frag_list) {
shinfo->frag_list = nskb;
nskb->prev = nskb;
} else {
first = shinfo->frag_list;
last = first->prev;
last->next = nskb;
first->prev = nskb;
}
skb->len += chunk;
skb->data_len += chunk;
if (!strp->stm.full_len) {
sz = tls_rx_msg_size(strp, skb);
if (sz < 0)
return sz;
/* We may have over-read, sz == 0 is guaranteed under-read */
if (unlikely(sz && sz < skb->len)) {
int over = skb->len - sz;
WARN_ON_ONCE(over > chunk);
skb->len -= over;
skb->data_len -= over;
__pskb_trim(nskb, nskb->len - over);
chunk -= over;
}
strp->stm.full_len = sz;
}
return chunk;
}
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
unsigned int offset, size_t in_len)
{
struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
struct sk_buff *skb;
int ret;
if (strp->msg_ready)
return 0;
skb = strp->anchor;
if (!skb->len)
skb_copy_decrypted(skb, in_skb);
else
strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);
if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
else
ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
if (ret < 0) {
desc->error = ret;
ret = 0;
}
if (strp->stm.full_len && strp->stm.full_len == skb->len) {
desc->count = 0;
strp->msg_ready = 1;
tls_rx_msg_ready(strp);
}
read_done:
return in_len - len;
return ret;
}
static int tls_strp_read_copyin(struct tls_strparser *strp)
......@@ -315,15 +422,19 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
return 0;
}
static bool tls_strp_check_no_dup(struct tls_strparser *strp)
static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
{
unsigned int len = strp->stm.offset + strp->stm.full_len;
struct sk_buff *skb;
struct sk_buff *first, *skb;
u32 seq;
skb = skb_shinfo(strp->anchor)->frag_list;
seq = TCP_SKB_CB(skb)->seq;
first = skb_shinfo(strp->anchor)->frag_list;
skb = first;
seq = TCP_SKB_CB(first)->seq;
/* Make sure there's no duplicate data in the queue,
* and the decrypted status matches.
*/
while (skb->len < len) {
seq += skb->len;
len -= skb->len;
......@@ -331,6 +442,8 @@ static bool tls_strp_check_no_dup(struct tls_strparser *strp)
if (TCP_SKB_CB(skb)->seq != seq)
return false;
if (skb_cmp_decrypted(first, skb))
return false;
}
return true;
......@@ -411,7 +524,7 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
return tls_strp_read_copy(strp, true);
}
if (!tls_strp_check_no_dup(strp))
if (!tls_strp_check_queue_ok(strp))
return tls_strp_read_copy(strp, false);
strp->msg_ready = 1;
......
......@@ -2304,10 +2304,14 @@ static void tls_data_ready(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
gfp_t alloc_save;
trace_sk_data_ready(sk);
alloc_save = sk->sk_allocation;
sk->sk_allocation = GFP_ATOMIC;
tls_strp_data_ready(&ctx->strp);
sk->sk_allocation = alloc_save;
psock = sk_psock_get(sk);
if (psock) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment