Commit 2897041e authored by David S. Miller's avatar David S. Miller

Merge branch 'tls-fixes'

Jakub Kicinski says:

====================
tls: rx: strp: fix inline crypto offload

The local strparser version I added to TLS does not preserve
decryption status, which breaks inline crypto (NIC offload).
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 7e01c7f7 74836ec8
...@@ -1587,6 +1587,16 @@ static inline void skb_copy_hash(struct sk_buff *to, const struct sk_buff *from) ...@@ -1587,6 +1587,16 @@ static inline void skb_copy_hash(struct sk_buff *to, const struct sk_buff *from)
to->l4_hash = from->l4_hash; to->l4_hash = from->l4_hash;
}; };
static inline int skb_cmp_decrypted(const struct sk_buff *skb1,
const struct sk_buff *skb2)
{
#ifdef CONFIG_TLS_DEVICE
return skb2->decrypted - skb1->decrypted;
#else
return 0;
#endif
}
static inline void skb_copy_decrypted(struct sk_buff *to, static inline void skb_copy_decrypted(struct sk_buff *to,
const struct sk_buff *from) const struct sk_buff *from)
{ {
......
...@@ -126,6 +126,7 @@ struct tls_strparser { ...@@ -126,6 +126,7 @@ struct tls_strparser {
u32 mark : 8; u32 mark : 8;
u32 stopped : 1; u32 stopped : 1;
u32 copy_mode : 1; u32 copy_mode : 1;
u32 mixed_decrypted : 1;
u32 msg_ready : 1; u32 msg_ready : 1;
struct strp_msg stm; struct strp_msg stm;
......
...@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx) ...@@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx)
return ctx->strp.msg_ready; return ctx->strp.msg_ready;
} }
static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx)
{
return ctx->strp.mixed_decrypted;
}
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
int tls_device_init(void); int tls_device_init(void);
void tls_device_cleanup(void); void tls_device_cleanup(void);
......
...@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) ...@@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb = tls_strp_msg(sw_ctx); struct sk_buff *skb = tls_strp_msg(sw_ctx);
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int is_decrypted = skb->decrypted; int is_decrypted, is_encrypted;
int is_encrypted = !is_decrypted;
struct sk_buff *skb_iter; if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
int left; is_decrypted = skb->decrypted;
is_encrypted = !is_decrypted;
left = rxm->full_len - skb->len; } else {
/* Check if all the data is decrypted already */ is_decrypted = 0;
skb_iter = skb_shinfo(skb)->frag_list; is_encrypted = 0;
while (skb_iter && left > 0) {
is_decrypted &= skb_iter->decrypted;
is_encrypted &= !skb_iter->decrypted;
left -= skb_iter->len;
skb_iter = skb_iter->next;
} }
trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len, trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
......
...@@ -29,34 +29,50 @@ static void tls_strp_anchor_free(struct tls_strparser *strp) ...@@ -29,34 +29,50 @@ static void tls_strp_anchor_free(struct tls_strparser *strp)
struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
shinfo->frag_list = NULL; if (!strp->copy_mode)
shinfo->frag_list = NULL;
consume_skb(strp->anchor); consume_skb(strp->anchor);
strp->anchor = NULL; strp->anchor = NULL;
} }
/* Create a new skb with the contents of input copied to its page frags */ static struct sk_buff *
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp) tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
int offset, int len)
{ {
struct strp_msg *rxm;
struct sk_buff *skb; struct sk_buff *skb;
int i, err, offset; int i, err;
skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER, skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
&err, strp->sk->sk_allocation); &err, strp->sk->sk_allocation);
if (!skb) if (!skb)
return NULL; return NULL;
offset = strp->stm.offset;
for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset, WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
skb_frag_address(frag), skb_frag_address(frag),
skb_frag_size(frag))); skb_frag_size(frag)));
offset += skb_frag_size(frag); offset += skb_frag_size(frag);
} }
skb_copy_header(skb, strp->anchor); skb->len = len;
skb->data_len = len;
skb_copy_header(skb, in_skb);
return skb;
}
/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
struct strp_msg *rxm;
struct sk_buff *skb;
skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
strp->stm.full_len);
if (!skb)
return NULL;
rxm = strp_msg(skb); rxm = strp_msg(skb);
rxm->offset = 0; rxm->offset = 0;
return skb; return skb;
...@@ -180,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) ...@@ -180,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
for (i = 0; i < shinfo->nr_frags; i++) for (i = 0; i < shinfo->nr_frags; i++)
__skb_frag_unref(&shinfo->frags[i], false); __skb_frag_unref(&shinfo->frags[i], false);
shinfo->nr_frags = 0; shinfo->nr_frags = 0;
if (strp->copy_mode) {
kfree_skb_list(shinfo->frag_list);
shinfo->frag_list = NULL;
}
strp->copy_mode = 0; strp->copy_mode = 0;
strp->mixed_decrypted = 0;
} }
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
unsigned int offset, size_t in_len) struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{ {
struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
struct sk_buff *skb;
skb_frag_t *frag;
size_t len, chunk; size_t len, chunk;
skb_frag_t *frag;
int sz; int sz;
if (strp->msg_ready)
return 0;
skb = strp->anchor;
frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
len = in_len; len = in_len;
...@@ -208,19 +224,26 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, ...@@ -208,19 +224,26 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
skb_frag_size(frag), skb_frag_size(frag),
chunk)); chunk));
sz = tls_rx_msg_size(strp, strp->anchor);
if (sz < 0) {
desc->error = sz;
return 0;
}
/* We may have over-read, sz == 0 is guaranteed under-read */
if (sz > 0)
chunk = min_t(size_t, chunk, sz - skb->len);
skb->len += chunk; skb->len += chunk;
skb->data_len += chunk; skb->data_len += chunk;
skb_frag_size_add(frag, chunk); skb_frag_size_add(frag, chunk);
sz = tls_rx_msg_size(strp, skb);
if (sz < 0)
return sz;
/* We may have over-read, sz == 0 is guaranteed under-read */
if (unlikely(sz && sz < skb->len)) {
int over = skb->len - sz;
WARN_ON_ONCE(over > chunk);
skb->len -= over;
skb->data_len -= over;
skb_frag_size_add(frag, -over);
chunk -= over;
}
frag++; frag++;
len -= chunk; len -= chunk;
offset += chunk; offset += chunk;
...@@ -247,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, ...@@ -247,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
offset += chunk; offset += chunk;
} }
if (strp->stm.full_len == skb->len) { read_done:
return in_len - len;
}
static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{
struct sk_buff *nskb, *first, *last;
struct skb_shared_info *shinfo;
size_t chunk;
int sz;
if (strp->stm.full_len)
chunk = strp->stm.full_len - skb->len;
else
chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
chunk = min(chunk, in_len);
nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
if (!nskb)
return -ENOMEM;
shinfo = skb_shinfo(skb);
if (!shinfo->frag_list) {
shinfo->frag_list = nskb;
nskb->prev = nskb;
} else {
first = shinfo->frag_list;
last = first->prev;
last->next = nskb;
first->prev = nskb;
}
skb->len += chunk;
skb->data_len += chunk;
if (!strp->stm.full_len) {
sz = tls_rx_msg_size(strp, skb);
if (sz < 0)
return sz;
/* We may have over-read, sz == 0 is guaranteed under-read */
if (unlikely(sz && sz < skb->len)) {
int over = skb->len - sz;
WARN_ON_ONCE(over > chunk);
skb->len -= over;
skb->data_len -= over;
__pskb_trim(nskb, nskb->len - over);
chunk -= over;
}
strp->stm.full_len = sz;
}
return chunk;
}
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
unsigned int offset, size_t in_len)
{
struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
struct sk_buff *skb;
int ret;
if (strp->msg_ready)
return 0;
skb = strp->anchor;
if (!skb->len)
skb_copy_decrypted(skb, in_skb);
else
strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);
if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
else
ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
if (ret < 0) {
desc->error = ret;
ret = 0;
}
if (strp->stm.full_len && strp->stm.full_len == skb->len) {
desc->count = 0; desc->count = 0;
strp->msg_ready = 1; strp->msg_ready = 1;
tls_rx_msg_ready(strp); tls_rx_msg_ready(strp);
} }
read_done: return ret;
return in_len - len;
} }
static int tls_strp_read_copyin(struct tls_strparser *strp) static int tls_strp_read_copyin(struct tls_strparser *strp)
...@@ -315,15 +422,19 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) ...@@ -315,15 +422,19 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
return 0; return 0;
} }
static bool tls_strp_check_no_dup(struct tls_strparser *strp) static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
{ {
unsigned int len = strp->stm.offset + strp->stm.full_len; unsigned int len = strp->stm.offset + strp->stm.full_len;
struct sk_buff *skb; struct sk_buff *first, *skb;
u32 seq; u32 seq;
skb = skb_shinfo(strp->anchor)->frag_list; first = skb_shinfo(strp->anchor)->frag_list;
seq = TCP_SKB_CB(skb)->seq; skb = first;
seq = TCP_SKB_CB(first)->seq;
/* Make sure there's no duplicate data in the queue,
* and the decrypted status matches.
*/
while (skb->len < len) { while (skb->len < len) {
seq += skb->len; seq += skb->len;
len -= skb->len; len -= skb->len;
...@@ -331,6 +442,8 @@ static bool tls_strp_check_no_dup(struct tls_strparser *strp) ...@@ -331,6 +442,8 @@ static bool tls_strp_check_no_dup(struct tls_strparser *strp)
if (TCP_SKB_CB(skb)->seq != seq) if (TCP_SKB_CB(skb)->seq != seq)
return false; return false;
if (skb_cmp_decrypted(first, skb))
return false;
} }
return true; return true;
...@@ -411,7 +524,7 @@ static int tls_strp_read_sock(struct tls_strparser *strp) ...@@ -411,7 +524,7 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
return tls_strp_read_copy(strp, true); return tls_strp_read_copy(strp, true);
} }
if (!tls_strp_check_no_dup(strp)) if (!tls_strp_check_queue_ok(strp))
return tls_strp_read_copy(strp, false); return tls_strp_read_copy(strp, false);
strp->msg_ready = 1; strp->msg_ready = 1;
......
...@@ -2304,10 +2304,14 @@ static void tls_data_ready(struct sock *sk) ...@@ -2304,10 +2304,14 @@ static void tls_data_ready(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock; struct sk_psock *psock;
gfp_t alloc_save;
trace_sk_data_ready(sk); trace_sk_data_ready(sk);
alloc_save = sk->sk_allocation;
sk->sk_allocation = GFP_ATOMIC;
tls_strp_data_ready(&ctx->strp); tls_strp_data_ready(&ctx->strp);
sk->sk_allocation = alloc_save;
psock = sk_psock_get(sk); psock = sk_psock_get(sk);
if (psock) { if (psock) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment