Commit d1f66ac6 authored by David S. Miller's avatar David S. Miller

Merge branch 'tls-rx-refactor-part-1'

Jakub Kicinski says:

====================
tls: rx: random refactoring part 1

TLS Rx refactoring. Part 1 of 3. A couple of features to follow.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents dc2e0617 71471ca3
...@@ -70,6 +70,10 @@ struct sk_skb_cb { ...@@ -70,6 +70,10 @@ struct sk_skb_cb {
* when dst_reg == src_reg. * when dst_reg == src_reg.
*/ */
u64 temp_reg; u64 temp_reg;
struct tls_msg {
u8 control;
u8 decrypted;
} tls;
}; };
static inline struct strp_msg *strp_msg(struct sk_buff *skb) static inline struct strp_msg *strp_msg(struct sk_buff *skb)
......
...@@ -64,6 +64,7 @@ ...@@ -64,6 +64,7 @@
#define TLS_AAD_SPACE_SIZE 13 #define TLS_AAD_SPACE_SIZE 13
#define MAX_IV_SIZE 16 #define MAX_IV_SIZE 16
#define TLS_TAG_SIZE 16
#define TLS_MAX_REC_SEQ_SIZE 8 #define TLS_MAX_REC_SEQ_SIZE 8
/* For CCM mode, the full 16-bytes of IV is made of '4' fields of given sizes. /* For CCM mode, the full 16-bytes of IV is made of '4' fields of given sizes.
...@@ -117,11 +118,6 @@ struct tls_rec { ...@@ -117,11 +118,6 @@ struct tls_rec {
u8 aead_req_ctx[]; u8 aead_req_ctx[];
}; };
struct tls_msg {
struct strp_msg rxm;
u8 control;
};
struct tx_work { struct tx_work {
struct delayed_work work; struct delayed_work work;
struct sock *sk; struct sock *sk;
...@@ -152,9 +148,7 @@ struct tls_sw_context_rx { ...@@ -152,9 +148,7 @@ struct tls_sw_context_rx {
void (*saved_data_ready)(struct sock *sk); void (*saved_data_ready)(struct sock *sk);
struct sk_buff *recv_pkt; struct sk_buff *recv_pkt;
u8 control;
u8 async_capable:1; u8 async_capable:1;
u8 decrypted:1;
atomic_t decrypt_pending; atomic_t decrypt_pending;
/* protect crypto_wait with decrypt_pending*/ /* protect crypto_wait with decrypt_pending*/
spinlock_t decrypt_compl_lock; spinlock_t decrypt_compl_lock;
...@@ -411,7 +405,9 @@ void tls_free_partial_record(struct sock *sk, struct tls_context *ctx); ...@@ -411,7 +405,9 @@ void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
static inline struct tls_msg *tls_msg(struct sk_buff *skb) static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{ {
return (struct tls_msg *)strp_msg(skb); struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
return &scb->tls;
} }
static inline bool tls_is_partially_sent_record(struct tls_context *ctx) static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
......
...@@ -962,11 +962,9 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, ...@@ -962,11 +962,9 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
tls_ctx->rx.rec_seq, rxm->full_len, tls_ctx->rx.rec_seq, rxm->full_len,
is_encrypted, is_decrypted); is_encrypted, is_decrypted);
ctx->sw.decrypted |= is_decrypted;
if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) { if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
if (likely(is_encrypted || is_decrypted)) if (likely(is_encrypted || is_decrypted))
return 0; return is_decrypted;
/* After tls_device_down disables the offload, the next SKB will /* After tls_device_down disables the offload, the next SKB will
* likely have initial fragments decrypted, and final ones not * likely have initial fragments decrypted, and final ones not
...@@ -981,7 +979,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, ...@@ -981,7 +979,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
*/ */
if (is_decrypted) { if (is_decrypted) {
ctx->resync_nh_reset = 1; ctx->resync_nh_reset = 1;
return 0; return is_decrypted;
} }
if (is_encrypted) { if (is_encrypted) {
tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb); tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
......
...@@ -128,32 +128,31 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len) ...@@ -128,32 +128,31 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
return __skb_nsg(skb, offset, len, 0); return __skb_nsg(skb, offset, len, 0);
} }
static int padding_length(struct tls_sw_context_rx *ctx, static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
struct tls_prot_info *prot, struct sk_buff *skb)
{ {
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
int sub = 0; int sub = 0;
/* Determine zero-padding length */ /* Determine zero-padding length */
if (prot->version == TLS_1_3_VERSION) { if (prot->version == TLS_1_3_VERSION) {
int offset = rxm->full_len - TLS_TAG_SIZE - 1;
char content_type = 0; char content_type = 0;
int err; int err;
int back = 17;
while (content_type == 0) { while (content_type == 0) {
if (back > rxm->full_len - prot->prepend_size) if (offset < prot->prepend_size)
return -EBADMSG; return -EBADMSG;
err = skb_copy_bits(skb, err = skb_copy_bits(skb, rxm->offset + offset,
rxm->offset + rxm->full_len - back,
&content_type, 1); &content_type, 1);
if (err) if (err)
return err; return err;
if (content_type) if (content_type)
break; break;
sub++; sub++;
back++; offset--;
} }
ctx->control = content_type; tlm->control = content_type;
} }
return sub; return sub;
} }
...@@ -187,7 +186,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -187,7 +186,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int pad; int pad;
pad = padding_length(ctx, prot, skb); pad = padding_length(prot, skb);
if (pad < 0) { if (pad < 0) {
ctx->async_wait.err = pad; ctx->async_wait.err = pad;
tls_err_abort(skb->sk, pad); tls_err_abort(skb->sk, pad);
...@@ -1421,6 +1420,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1421,6 +1420,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot = &tls_ctx->prot_info;
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
struct aead_request *aead_req; struct aead_request *aead_req;
struct sk_buff *unused; struct sk_buff *unused;
...@@ -1505,7 +1505,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1505,7 +1505,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
/* Prepare AAD */ /* Prepare AAD */
tls_make_aad(aad, rxm->full_len - prot->overhead_size + tls_make_aad(aad, rxm->full_len - prot->overhead_size +
prot->tail_size, prot->tail_size,
tls_ctx->rx.rec_seq, ctx->control, prot); tls_ctx->rx.rec_seq, tlm->control, prot);
/* Prepare sgin */ /* Prepare sgin */
sg_init_table(sgin, n_sgin); sg_init_table(sgin, n_sgin);
...@@ -1561,36 +1561,38 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1561,36 +1561,38 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
bool async) bool async)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot = &tls_ctx->prot_info;
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int pad, err = 0; struct tls_msg *tlm = tls_msg(skb);
int pad, err;
if (tlm->decrypted) {
*zc = false;
return 0;
}
if (!ctx->decrypted) {
if (tls_ctx->rx_conf == TLS_HW) { if (tls_ctx->rx_conf == TLS_HW) {
err = tls_device_decrypted(sk, tls_ctx, skb, rxm); err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
if (err < 0) if (err < 0)
return err; return err;
if (err > 0) {
tlm->decrypted = 1;
*zc = false;
goto decrypt_done;
}
} }
/* Still not decrypted after tls_device */ err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
if (!ctx->decrypted) {
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
async);
if (err < 0) { if (err < 0) {
if (err == -EINPROGRESS) if (err == -EINPROGRESS)
tls_advance_record_sn(sk, prot, tls_advance_record_sn(sk, prot, &tls_ctx->rx);
&tls_ctx->rx);
else if (err == -EBADMSG) else if (err == -EBADMSG)
TLS_INC_STATS(sock_net(sk), TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
LINUX_MIB_TLSDECRYPTERROR);
return err; return err;
} }
} else {
*zc = false;
}
pad = padding_length(ctx, prot, skb); decrypt_done:
pad = padding_length(prot, skb);
if (pad < 0) if (pad < 0)
return pad; return pad;
...@@ -1598,13 +1600,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1598,13 +1600,9 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
rxm->offset += prot->prepend_size; rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size; rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, prot, &tls_ctx->rx); tls_advance_record_sn(sk, prot, &tls_ctx->rx);
ctx->decrypted = 1; tlm->decrypted = 1;
ctx->saved_data_ready(sk);
} else {
*zc = false;
}
return err; return 0;
} }
int decrypt_skb(struct sock *sk, struct sk_buff *skb, int decrypt_skb(struct sock *sk, struct sk_buff *skb,
...@@ -1760,6 +1758,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1760,6 +1758,7 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot = &tls_ctx->prot_info;
struct sk_psock *psock; struct sk_psock *psock;
int num_async, pending;
unsigned char control = 0; unsigned char control = 0;
ssize_t decrypted = 0; ssize_t decrypted = 0;
struct strp_msg *rxm; struct strp_msg *rxm;
...@@ -1772,8 +1771,6 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1772,8 +1771,6 @@ int tls_sw_recvmsg(struct sock *sk,
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK; bool is_peek = flags & MSG_PEEK;
bool bpf_strp_enabled; bool bpf_strp_enabled;
int num_async = 0;
int pending;
flags |= nonblock; flags |= nonblock;
...@@ -1790,17 +1787,18 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1790,17 +1787,18 @@ int tls_sw_recvmsg(struct sock *sk,
if (err < 0) { if (err < 0) {
tls_err_abort(sk, err); tls_err_abort(sk, err);
goto end; goto end;
} else {
copied = err;
} }
copied = err;
if (len <= copied) if (len <= copied)
goto recv_end; goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
len = len - copied; len = len - copied;
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
decrypted = 0;
num_async = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) { while (len && (decrypted + copied < target || ctx->recv_pkt)) {
bool retain_skb = false; bool retain_skb = false;
bool zc = false; bool zc = false;
...@@ -1822,26 +1820,21 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1822,26 +1820,21 @@ int tls_sw_recvmsg(struct sock *sk,
} }
} }
goto recv_end; goto recv_end;
} else {
tlm = tls_msg(skb);
if (prot->version == TLS_1_3_VERSION)
tlm->control = 0;
else
tlm->control = ctx->control;
} }
rxm = strp_msg(skb); rxm = strp_msg(skb);
tlm = tls_msg(skb);
to_decrypt = rxm->full_len - prot->overhead_size; to_decrypt = rxm->full_len - prot->overhead_size;
if (to_decrypt <= len && !is_kvec && !is_peek && if (to_decrypt <= len && !is_kvec && !is_peek &&
ctx->control == TLS_RECORD_TYPE_DATA && tlm->control == TLS_RECORD_TYPE_DATA &&
prot->version != TLS_1_3_VERSION && prot->version != TLS_1_3_VERSION &&
!bpf_strp_enabled) !bpf_strp_enabled)
zc = true; zc = true;
/* Do not use async mode if record is non-data */ /* Do not use async mode if record is non-data */
if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled) if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
async_capable = ctx->async_capable; async_capable = ctx->async_capable;
else else
async_capable = false; async_capable = false;
...@@ -1856,8 +1849,6 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1856,8 +1849,6 @@ int tls_sw_recvmsg(struct sock *sk,
if (err == -EINPROGRESS) { if (err == -EINPROGRESS) {
async = true; async = true;
num_async++; num_async++;
} else if (prot->version == TLS_1_3_VERSION) {
tlm->control = ctx->control;
} }
/* If the type of records being processed is not known yet, /* If the type of records being processed is not known yet,
...@@ -2005,6 +1996,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2005,6 +1996,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct strp_msg *rxm = NULL; struct strp_msg *rxm = NULL;
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct tls_msg *tlm;
struct sk_buff *skb; struct sk_buff *skb;
ssize_t copied = 0; ssize_t copied = 0;
bool from_queue; bool from_queue;
...@@ -2033,14 +2025,15 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2033,14 +2025,15 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
} }
} }
rxm = strp_msg(skb);
tlm = tls_msg(skb);
/* splice does not support reading control messages */ /* splice does not support reading control messages */
if (ctx->control != TLS_RECORD_TYPE_DATA) { if (tlm->control != TLS_RECORD_TYPE_DATA) {
err = -EINVAL; err = -EINVAL;
goto splice_read_end; goto splice_read_end;
} }
rxm = strp_msg(skb);
chunk = min_t(unsigned int, rxm->full_len, len); chunk = min_t(unsigned int, rxm->full_len, len);
copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
if (copied < 0) if (copied < 0)
...@@ -2084,10 +2077,10 @@ bool tls_sw_sock_is_readable(struct sock *sk) ...@@ -2084,10 +2077,10 @@ bool tls_sw_sock_is_readable(struct sock *sk)
static int tls_read_size(struct strparser *strp, struct sk_buff *skb) static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{ {
struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot = &tls_ctx->prot_info;
char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
size_t cipher_overhead; size_t cipher_overhead;
size_t data_len = 0; size_t data_len = 0;
int ret; int ret;
...@@ -2104,11 +2097,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) ...@@ -2104,11 +2097,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
/* Linearize header to local buffer */ /* Linearize header to local buffer */
ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size); ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
if (ret < 0) if (ret < 0)
goto read_failure; goto read_failure;
ctx->control = header[0]; tlm->decrypted = 0;
tlm->control = header[0];
data_len = ((header[4] & 0xFF) | (header[3] << 8)); data_len = ((header[4] & 0xFF) | (header[3] << 8));
...@@ -2149,8 +2142,6 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb) ...@@ -2149,8 +2142,6 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)
struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
ctx->decrypted = 0;
ctx->recv_pkt = skb; ctx->recv_pkt = skb;
strp_pause(strp); strp_pause(strp);
...@@ -2501,7 +2492,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -2501,7 +2492,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
/* Sanity-check the sizes for stack allocations. */ /* Sanity-check the sizes for stack allocations. */
if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE || if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
rec_seq_size > TLS_MAX_REC_SEQ_SIZE) { rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
rc = -EINVAL; rc = -EINVAL;
goto free_priv; goto free_priv;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment