Commit 8f1c3850 authored by David S. Miller

Merge branch 'tls-rx-refactor-part-3'

Jakub Kicinski says:

====================
tls: rx: random refactoring part 3

TLS Rx refactoring. Part 3 of 3. This set is mostly around rx_list
and async processing. The last two patches are minor optimizations.
A couple of features to follow.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents f45ba67e a4ae58cd
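
The diff below is against net/tls/tls_sw.c. Its central change is threading the decryption flags through a single in/out argument block instead of a bare bool plus -EINPROGRESS return codes. A minimal sketch of that convention, with the field set assumed from part 1 of the series:

	/* Sketch only: in/out argument block for the TLS RX decrypt path.
	 * The exact fields are assumed from earlier patches in the series.
	 */
	struct tls_decrypt_arg {
		bool zc;	/* in: caller requests zero-copy decrypt */
		bool async;	/* in: async crypto allowed;
				 * out: the request actually went async
				 */
	};

A return value of 0 with darg->async set means the record was handed to the crypto layer and will be fixed up in the completion callback.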
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -188,18 +188,13 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 		tls_err_abort(skb->sk, err);
 	} else {
 		struct strp_msg *rxm = strp_msg(skb);
-		int pad;
 
-		pad = padding_length(prot, skb);
-		if (pad < 0) {
-			ctx->async_wait.err = pad;
-			tls_err_abort(skb->sk, pad);
-		} else {
-			rxm->full_len -= pad;
-			rxm->offset += prot->prepend_size;
-			rxm->full_len -= prot->overhead_size;
-		}
+		/* No TLS 1.3 support with async crypto */
+		WARN_ON(prot->tail_size);
+
+		rxm->offset += prot->prepend_size;
+		rxm->full_len -= prot->overhead_size;
 	}
 
 	/* After using skb->sk to propagate sk through crypto async callback
 	 * we need to NULL it again.
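
The new comment and WARN_ON() encode an invariant rather than new behavior: prot->tail_size is non-zero only for TLS 1.3, whose inner plaintext carries a content-type byte plus zero padding that must be scanned off the tail, and async crypto is not supported for TLS 1.3, so the callback can skip padding_length() entirely. For reference, a hedged sketch of the trailer scan the sync path still performs (simplified from RFC 8446, not the kernel helper itself):

	/* Simplified TLS 1.3 inner plaintext: content || type || 0-padding.
	 * Returns the number of tail bytes to strip (padding plus the
	 * content-type byte), or an error if the record is all padding.
	 */
	static int tls13_tail_len(const unsigned char *pt, int len)
	{
		int i = len;

		while (i > 0 && pt[i - 1] == 0)	/* skip zero padding */
			i--;
		if (!i)
			return -EBADMSG;	/* no content type found */
		return len - i + 1;		/* padding + type byte */
	}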
@@ -232,7 +227,7 @@ static int tls_do_decryption(struct sock *sk,
 			    char *iv_recv,
 			    size_t data_len,
 			    struct aead_request *aead_req,
-			    bool async)
+			    struct tls_decrypt_arg *darg)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -245,7 +240,7 @@ static int tls_do_decryption(struct sock *sk,
 			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
 
-	if (async) {
+	if (darg->async) {
 		/* Using skb->sk to push sk through to crypto async callback
 		 * handler. This allows propagating errors up to the socket
 		 * if needed. It _must_ be cleared in the async handler
@@ -265,14 +260,15 @@ static int tls_do_decryption(struct sock *sk,
 
 	ret = crypto_aead_decrypt(aead_req);
 	if (ret == -EINPROGRESS) {
-		if (async)
-			return ret;
+		if (darg->async)
+			return 0;
 
 		ret = crypto_wait_req(ret, &ctx->async_wait);
 	}
+	darg->async = false;
 
-	if (async)
-		atomic_dec(&ctx->decrypt_pending);
+	if (ret == -EBADMSG)
+		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
 
 	return ret;
 }
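
After this hunk a 0 return covers both synchronous success and a request that went async; -EINPROGRESS no longer escapes tls_do_decryption(), and darg->async tells the two cases apart. A hedged sketch of the resulting caller pattern (names taken from the diff, error handling compressed):

	/* Sketch: caller-side handling after the signature change. */
	struct tls_decrypt_arg darg = { .async = ctx->async_capable, };
	int err;

	err = tls_do_decryption(sk, skb, sgin, sgout, iv, data_len,
				aead_req, &darg);
	if (err < 0)
		return err;	/* hard failure; EBADMSG counted above */
	if (darg.async)
		return 0;	/* tls_decrypt_done() finishes the record */
	/* synchronous completion: payload is decrypted in place */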
@@ -1456,7 +1452,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
 	mem_size = mem_size + prot->aad_size;
-	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
+	mem_size = mem_size + MAX_IV_SIZE;
 
 	/* Allocate a single block of memory which contains
 	 * aead_req || sgin[] || sgout[] || aad || iv.
@@ -1486,6 +1482,11 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	}
 
 	/* Prepare IV */
+	if (prot->version == TLS_1_3_VERSION ||
+	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
+		memcpy(iv + iv_offset, tls_ctx->rx.iv,
+		       prot->iv_size + prot->salt_size);
+	} else {
 		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
 				    iv + iv_offset + prot->salt_size,
 				    prot->iv_size);
@@ -1493,13 +1494,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			kfree(mem);
 			return err;
 		}
-	if (prot->version == TLS_1_3_VERSION ||
-	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305)
-		memcpy(iv + iv_offset, tls_ctx->rx.iv,
-		       prot->iv_size + prot->salt_size);
-	else
 		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
+	}
 
 	xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
 
 	/* Prepare AAD */
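
The reordered block makes the two nonce recipes explicit: TLS 1.3 and ChaCha20-Poly1305 use the full static IV from the handshake and mix in the record sequence number, while TLS 1.2 AES-GCM splices the 4-byte implicit salt together with the 8-byte explicit nonce read off the wire. An illustrative stand-alone version of the same derivation (sizes fixed at the 12-byte nonce layout; the kernel does this via the memcpy()/skb_copy_bits() calls above plus xor_iv_with_seq()):

	#include <string.h>

	/* Illustration only: build the 12-byte AEAD nonce.  For TLS 1.2
	 * AES-GCM the static IV is salt(4) and the wire carries nonce(8);
	 * TLS 1.3 and ChaCha treat all 12 bytes as static IV and XOR the
	 * record sequence into the tail.
	 */
	static void build_nonce(unsigned char nonce[12],
				const unsigned char *static_iv,
				const unsigned char *wire_nonce,
				const unsigned char rec_seq[8],
				int tls13_style)
	{
		int i;

		if (tls13_style) {
			memcpy(nonce, static_iv, 12);
			for (i = 0; i < 8; i++)
				nonce[4 + i] ^= rec_seq[i];
		} else {
			memcpy(nonce, static_iv, 4);	  /* implicit salt */
			memcpy(nonce + 4, wire_nonce, 8); /* from record */
		}
	}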
@@ -1542,9 +1538,9 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 
 	/* Prepare and submit AEAD request */
 	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-				data_len, aead_req, darg->async);
-	if (err == -EINPROGRESS)
-		return err;
+				data_len, aead_req, darg);
+	if (darg->async)
+		return 0;
 
 	/* Release the pages in case iov was mapped to pages */
 	for (; pages > 0; pages--)
@@ -1581,13 +1577,10 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 	}
 
 	err = decrypt_internal(sk, skb, dest, NULL, darg);
-	if (err < 0) {
-		if (err == -EINPROGRESS)
-			tls_advance_record_sn(sk, prot, &tls_ctx->rx);
-		else if (err == -EBADMSG)
-			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
+	if (err < 0)
 		return err;
-	}
+	if (darg->async)
+		goto decrypt_next;
 
 decrypt_done:
 	pad = padding_length(prot, skb);
@@ -1597,8 +1590,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 	rxm->full_len -= pad;
 	rxm->offset += prot->prepend_size;
 	rxm->full_len -= prot->overhead_size;
-	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
 	tlm->decrypted = 1;
+decrypt_next:
+	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
 
 	return 0;
 }
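
Taken together, the two hunks above split the record bookkeeping by when it can safely run: tls_advance_record_sn() moves behind a label shared by both paths, because the next record's nonce depends on the sequence number and must advance in submission order, while the length fix-ups and tlm->decrypted are deferred to the callback when the request went async. A condensed view of the resulting tail of decrypt_skb_update() (a sketch, details elided):

	err = decrypt_internal(sk, skb, dest, NULL, darg);
	if (err < 0)
		return err;
	if (darg->async)
		goto decrypt_next;	/* callback does the fix-ups */

decrypt_done:	/* also reached directly for already-decrypted records */
	/* ... trim padding, strip header, tlm->decrypted = 1 ... */
decrypt_next:
	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
	return 0;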
@@ -1658,7 +1652,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 
 		err = tls_record_content_type(msg, tlm, control);
 		if (err <= 0)
-			return err;
+			goto out;
 
 		if (skip < rxm->full_len)
 			break;
@@ -1676,13 +1670,13 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 		err = tls_record_content_type(msg, tlm, control);
 		if (err <= 0)
-			return err;
+			goto out;
 
 		if (!zc || (rxm->full_len - skip) > len) {
 			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
 						    msg, chunk);
 			if (err < 0)
-				return err;
+				goto out;
 		}
 
 		len = len - chunk;
@@ -1709,14 +1703,16 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 		next_skb = skb_peek_next(skb, &ctx->rx_list);
 
 		if (!is_peek) {
-			skb_unlink(skb, &ctx->rx_list);
+			__skb_unlink(skb, &ctx->rx_list);
 			consume_skb(skb);
 		}
 
 		skb = next_skb;
 	}
+	err = 0;
 
-	return copied;
+out:
+	return copied ? : err;
 }
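
The new tail uses the GNU `x ? : y` shorthand: report the bytes already copied if there were any, and surface the error only when nothing was consumed. This is what makes process_rx_list() errors transient; the caller gets its partial read first and sees the error on the next call. The idiom expanded into plain C:

	/* "return copied ? : err;" expanded: progress beats an error. */
	static int rx_list_result(int copied, int err)
	{
		return copied ? copied : err;
	}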
@@ -1750,12 +1746,15 @@ int tls_sw_recvmsg(struct sock *sk,
 	lock_sock(sk);
 	bpf_strp_enabled = sk_psock_strp_enabled(psock);
 
+	/* If crypto failed the connection is broken */
+	err = ctx->async_wait.err;
+	if (err)
+		goto end;
+
 	/* Process pending decrypted records. It must be non-zero-copy */
 	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
-	if (err < 0) {
-		tls_err_abort(sk, err);
+	if (err < 0)
 		goto end;
-	}
 
 	copied = err;
 	if (len <= copied)
@@ -1775,14 +1774,10 @@ int tls_sw_recvmsg(struct sock *sk,
 		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
 		if (!skb) {
 			if (psock) {
-				int ret = sk_msg_recvmsg(sk, psock, msg, len,
-							 flags);
-
-				if (ret > 0) {
-					decrypted += ret;
-					len -= ret;
-					continue;
-				}
+				chunk = sk_msg_recvmsg(sk, psock, msg, len,
						       flags);
+				if (chunk > 0)
+					goto leave_on_list;
 			}
 			goto recv_end;
 		}
@@ -1803,13 +1798,12 @@ int tls_sw_recvmsg(struct sock *sk,
 		darg.async = false;
 
 		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
-		if (err < 0 && err != -EINPROGRESS) {
+		if (err < 0) {
 			tls_err_abort(sk, -EBADMSG);
 			goto recv_end;
 		}
 
-		if (err == -EINPROGRESS)
-			async = true;
+		async |= darg.async;
 
 		/* If the type of records being processed is not known yet,
 		 * set it to record type just dequeued. If it is already known,
@@ -1824,7 +1818,7 @@ int tls_sw_recvmsg(struct sock *sk,
 		ctx->recv_pkt = NULL;
 		__strp_unpause(&ctx->strp);
-		skb_queue_tail(&ctx->rx_list, skb);
+		__skb_queue_tail(&ctx->rx_list, skb);
 
 		if (async) {
 			/* TLS 1.2-only, to_decrypt must be text length */
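
This conversion, and the matching __skb_unlink()/__skb_queue_purge() changes below, switch rx_list to the unlocked skb-queue helpers: every access runs under lock_sock(), so the per-queue IRQ-safe spinlock taken by the plain variants was pure overhead. The pattern, sketched with names from the diff:

	/* rx_list is private to the TLS RX path and serialized by the
	 * socket lock, so the unlocked helpers are sufficient.
	 */
	lock_sock(sk);
	__skb_queue_tail(&ctx->rx_list, skb);	/* no spinlock, no IRQ save */
	/* ... later, still under the socket lock ... */
	__skb_unlink(skb, &ctx->rx_list);
	release_sock(sk);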
@@ -1845,7 +1839,7 @@ int tls_sw_recvmsg(struct sock *sk,
 			if (err != __SK_PASS) {
 				rxm->offset = rxm->offset + rxm->full_len;
 				rxm->full_len = 0;
-				skb_unlink(skb, &ctx->rx_list);
+				__skb_unlink(skb, &ctx->rx_list);
 				if (err == __SK_DROP)
 					consume_skb(skb);
 				continue;
@@ -1873,7 +1867,7 @@ int tls_sw_recvmsg(struct sock *sk,
 		decrypted += chunk;
 		len -= chunk;
-		skb_unlink(skb, &ctx->rx_list);
+		__skb_unlink(skb, &ctx->rx_list);
 		consume_skb(skb);
 
 		/* Return full control message to userspace before trying
@@ -1886,7 +1880,7 @@ int tls_sw_recvmsg(struct sock *sk,
 recv_end:
 	if (async) {
-		int pending;
+		int ret, pending;
 
 		/* Wait for all previously submitted records to be decrypted */
 		spin_lock_bh(&ctx->decrypt_compl_lock);
@@ -1894,11 +1888,10 @@ int tls_sw_recvmsg(struct sock *sk,
 		pending = atomic_read(&ctx->decrypt_pending);
 		spin_unlock_bh(&ctx->decrypt_compl_lock);
 		if (pending) {
-			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
-			if (err) {
-				/* one of async decrypt failed */
-				tls_err_abort(sk, err);
-				copied = 0;
+			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+			if (ret) {
+				if (err >= 0 || err == -EINPROGRESS)
+					err = ret;
 				decrypted = 0;
 				goto end;
 			}
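
Keeping the wait result in a local ret stops it from clobbering err, which at this point may hold the byte count (or a real error) from earlier processing; a previously recorded failure is preserved. The drain itself is the usual pending-counter pattern, sketched below: decrypt_pending is incremented per submitted request and dropped by the completion callback, and the spinlock orders the read against completions so the final wake-up cannot be missed.

	spin_lock_bh(&ctx->decrypt_compl_lock);
	/* ... completion re-arm elided between the two hunks ... */
	pending = atomic_read(&ctx->decrypt_pending);
	spin_unlock_bh(&ctx->decrypt_compl_lock);
	if (pending)
		ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);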
@@ -1911,11 +1904,7 @@ int tls_sw_recvmsg(struct sock *sk,
 		else
 			err = process_rx_list(ctx, msg, &control, 0,
 					      decrypted, true, is_peek);
-		if (err < 0) {
-			tls_err_abort(sk, err);
-			copied = 0;
-			goto end;
-		}
+		decrypted = max(err, 0);
 	}
 
 	copied += decrypted;
@@ -2173,7 +2162,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
-		skb_queue_purge(&ctx->rx_list);
+		__skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
 		/* If tls_sw_strparser_arm() was not called (cleanup paths)