Commit dafb67f3 authored by Boris Pismenny's avatar Boris Pismenny Committed by David S. Miller

tls: Split decrypt_skb to two functions

Previously, decrypt_skb also updated the TLS context.
Now, decrypt_skb only decrypts the payload using the current context,
while decrypt_skb_update also updates the state.

Later, in the tls_device Rx flow, we will use decrypt_skb directly.
Signed-off-by: default avatarBoris Pismenny <borisp@mellanox.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent d80a1b9d
...@@ -390,6 +390,8 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, ...@@ -390,6 +390,8 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type); unsigned char *record_type);
void tls_register_device(struct tls_device *device); void tls_register_device(struct tls_device *device);
void tls_unregister_device(struct tls_device *device); void tls_unregister_device(struct tls_device *device);
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout);
struct sk_buff *tls_validate_xmit_skb(struct sock *sk, struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
struct net_device *dev, struct net_device *dev,
......
...@@ -53,7 +53,6 @@ static int tls_do_decryption(struct sock *sk, ...@@ -53,7 +53,6 @@ static int tls_do_decryption(struct sock *sk,
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct strp_msg *rxm = strp_msg(skb);
struct aead_request *aead_req; struct aead_request *aead_req;
int ret; int ret;
...@@ -71,18 +70,6 @@ static int tls_do_decryption(struct sock *sk, ...@@ -71,18 +70,6 @@ static int tls_do_decryption(struct sock *sk,
ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
if (ret < 0)
goto out;
rxm->offset += tls_ctx->rx.prepend_size;
rxm->full_len -= tls_ctx->rx.overhead_size;
tls_advance_record_sn(sk, &tls_ctx->rx);
ctx->decrypted = true;
ctx->saved_data_ready(sk);
out:
aead_request_free(aead_req); aead_request_free(aead_req);
return ret; return ret;
} }
...@@ -666,7 +653,28 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, ...@@ -666,7 +653,28 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
return skb; return skb;
} }
static int decrypt_skb(struct sock *sk, struct sk_buff *skb, static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct strp_msg *rxm = strp_msg(skb);
int err = 0;
err = decrypt_skb(sk, skb, sgout);
if (err < 0)
return err;
rxm->offset += tls_ctx->rx.prepend_size;
rxm->full_len -= tls_ctx->rx.overhead_size;
tls_advance_record_sn(sk, &tls_ctx->rx);
ctx->decrypted = true;
ctx->saved_data_ready(sk);
return err;
}
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout) struct scatterlist *sgout)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
...@@ -812,7 +820,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -812,7 +820,7 @@ int tls_sw_recvmsg(struct sock *sk,
if (err < 0) if (err < 0)
goto fallback_to_reg_recv; goto fallback_to_reg_recv;
err = decrypt_skb(sk, skb, sgin); err = decrypt_skb_update(sk, skb, sgin);
for (; pages > 0; pages--) for (; pages > 0; pages--)
put_page(sg_page(&sgin[pages])); put_page(sg_page(&sgin[pages]));
if (err < 0) { if (err < 0) {
...@@ -821,7 +829,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -821,7 +829,7 @@ int tls_sw_recvmsg(struct sock *sk,
} }
} else { } else {
fallback_to_reg_recv: fallback_to_reg_recv:
err = decrypt_skb(sk, skb, NULL); err = decrypt_skb_update(sk, skb, NULL);
if (err < 0) { if (err < 0) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
goto recv_end; goto recv_end;
...@@ -892,7 +900,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -892,7 +900,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
} }
if (!ctx->decrypted) { if (!ctx->decrypted) {
err = decrypt_skb(sk, skb, NULL); err = decrypt_skb_update(sk, skb, NULL);
if (err < 0) { if (err < 0) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment