Commit f66de3ee authored by Boris Pismenny's avatar Boris Pismenny Committed by David S. Miller

net/tls: Split conf to rx + tx

In TLS inline crypto, we can have one direction in software
and another in hardware. Thus, we split the TLS configuration to separate
structures for receive and transmit.
Signed-off-by: default avatarBoris Pismenny <borisp@mellanox.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 2342a851
...@@ -83,21 +83,10 @@ struct tls_device { ...@@ -83,21 +83,10 @@ struct tls_device {
void (*unhash)(struct tls_device *device, struct sock *sk); void (*unhash)(struct tls_device *device, struct sock *sk);
}; };
struct tls_sw_context { struct tls_sw_context_tx {
struct crypto_aead *aead_send; struct crypto_aead *aead_send;
struct crypto_aead *aead_recv;
struct crypto_wait async_wait; struct crypto_wait async_wait;
/* Receive context */
struct strparser strp;
void (*saved_data_ready)(struct sock *sk);
unsigned int (*sk_poll)(struct file *file, struct socket *sock,
struct poll_table_struct *wait);
struct sk_buff *recv_pkt;
u8 control;
bool decrypted;
/* Sending context */
char aad_space[TLS_AAD_SPACE_SIZE]; char aad_space[TLS_AAD_SPACE_SIZE];
unsigned int sg_plaintext_size; unsigned int sg_plaintext_size;
...@@ -114,6 +103,19 @@ struct tls_sw_context { ...@@ -114,6 +103,19 @@ struct tls_sw_context {
struct scatterlist sg_aead_out[2]; struct scatterlist sg_aead_out[2];
}; };
struct tls_sw_context_rx {
struct crypto_aead *aead_recv;
struct crypto_wait async_wait;
struct strparser strp;
void (*saved_data_ready)(struct sock *sk);
unsigned int (*sk_poll)(struct file *file, struct socket *sock,
struct poll_table_struct *wait);
struct sk_buff *recv_pkt;
u8 control;
bool decrypted;
};
enum { enum {
TLS_PENDING_CLOSED_RECORD TLS_PENDING_CLOSED_RECORD
}; };
...@@ -138,9 +140,15 @@ struct tls_context { ...@@ -138,9 +140,15 @@ struct tls_context {
struct tls12_crypto_info_aes_gcm_128 crypto_recv_aes_gcm_128; struct tls12_crypto_info_aes_gcm_128 crypto_recv_aes_gcm_128;
}; };
void *priv_ctx; struct list_head list;
struct net_device *netdev;
refcount_t refcount;
void *priv_ctx_tx;
void *priv_ctx_rx;
u8 conf:3; u8 tx_conf:3;
u8 rx_conf:3;
struct cipher_context tx; struct cipher_context tx;
struct cipher_context rx; struct cipher_context rx;
...@@ -177,7 +185,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); ...@@ -177,7 +185,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_sw_sendpage(struct sock *sk, struct page *page, int tls_sw_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags); int offset, size_t size, int flags);
void tls_sw_close(struct sock *sk, long timeout); void tls_sw_close(struct sock *sk, long timeout);
void tls_sw_free_resources(struct sock *sk); void tls_sw_free_resources_tx(struct sock *sk);
void tls_sw_free_resources_rx(struct sock *sk);
int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len); int nonblock, int flags, int *addr_len);
unsigned int tls_sw_poll(struct file *file, struct socket *sock, unsigned int tls_sw_poll(struct file *file, struct socket *sock,
...@@ -297,16 +306,22 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk) ...@@ -297,16 +306,22 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
return icsk->icsk_ulp_data; return icsk->icsk_ulp_data;
} }
static inline struct tls_sw_context *tls_sw_ctx( static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
const struct tls_context *tls_ctx)
{
return (struct tls_sw_context_rx *)tls_ctx->priv_ctx_rx;
}
static inline struct tls_sw_context_tx *tls_sw_ctx_tx(
const struct tls_context *tls_ctx) const struct tls_context *tls_ctx)
{ {
return (struct tls_sw_context *)tls_ctx->priv_ctx; return (struct tls_sw_context_tx *)tls_ctx->priv_ctx_tx;
} }
static inline struct tls_offload_context *tls_offload_ctx( static inline struct tls_offload_context *tls_offload_ctx(
const struct tls_context *tls_ctx) const struct tls_context *tls_ctx)
{ {
return (struct tls_offload_context *)tls_ctx->priv_ctx; return (struct tls_offload_context *)tls_ctx->priv_ctx_tx;
} }
int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
......
...@@ -51,12 +51,9 @@ enum { ...@@ -51,12 +51,9 @@ enum {
TLSV6, TLSV6,
TLS_NUM_PROTS, TLS_NUM_PROTS,
}; };
enum { enum {
TLS_BASE, TLS_BASE,
TLS_SW_TX, TLS_SW,
TLS_SW_RX,
TLS_SW_RXTX,
TLS_HW_RECORD, TLS_HW_RECORD,
TLS_NUM_CONFIG, TLS_NUM_CONFIG,
}; };
...@@ -65,14 +62,14 @@ static struct proto *saved_tcpv6_prot; ...@@ -65,14 +62,14 @@ static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex); static DEFINE_MUTEX(tcpv6_prot_mutex);
static LIST_HEAD(device_list); static LIST_HEAD(device_list);
static DEFINE_MUTEX(device_mutex); static DEFINE_MUTEX(device_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops; static struct proto_ops tls_sw_proto_ops;
static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{ {
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
sk->sk_prot = &tls_prots[ip_ver][ctx->conf]; sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
} }
int wait_on_pending_writer(struct sock *sk, long *timeo) int wait_on_pending_writer(struct sock *sk, long *timeo)
...@@ -245,10 +242,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -245,10 +242,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
lock_sock(sk); lock_sock(sk);
sk_proto_close = ctx->sk_proto_close; sk_proto_close = ctx->sk_proto_close;
if (ctx->conf == TLS_HW_RECORD) if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
goto skip_tx_cleanup; goto skip_tx_cleanup;
if (ctx->conf == TLS_BASE) { if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
kfree(ctx); kfree(ctx);
ctx = NULL; ctx = NULL;
goto skip_tx_cleanup; goto skip_tx_cleanup;
...@@ -270,15 +267,17 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -270,15 +267,17 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
} }
} }
kfree(ctx->tx.rec_seq); /* We need these for tls_sw_fallback handling of other packets */
kfree(ctx->tx.iv); if (ctx->tx_conf == TLS_SW) {
kfree(ctx->rx.rec_seq); kfree(ctx->tx.rec_seq);
kfree(ctx->rx.iv); kfree(ctx->tx.iv);
tls_sw_free_resources_tx(sk);
}
if (ctx->conf == TLS_SW_TX || if (ctx->rx_conf == TLS_SW) {
ctx->conf == TLS_SW_RX || kfree(ctx->rx.rec_seq);
ctx->conf == TLS_SW_RXTX) { kfree(ctx->rx.iv);
tls_sw_free_resources(sk); tls_sw_free_resources_rx(sk);
} }
skip_tx_cleanup: skip_tx_cleanup:
...@@ -287,7 +286,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -287,7 +286,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
/* free ctx for TLS_HW_RECORD, used by tcp_set_state /* free ctx for TLS_HW_RECORD, used by tcp_set_state
* for sk->sk_prot->unhash [tls_hw_unhash] * for sk->sk_prot->unhash [tls_hw_unhash]
*/ */
if (ctx && ctx->conf == TLS_HW_RECORD) if (ctx && ctx->tx_conf == TLS_HW_RECORD &&
ctx->rx_conf == TLS_HW_RECORD)
kfree(ctx); kfree(ctx);
} }
...@@ -441,25 +441,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, ...@@ -441,25 +441,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto err_crypto_info; goto err_crypto_info;
} }
/* currently SW is default, we will have ethtool in future */
if (tx) { if (tx) {
rc = tls_set_sw_offload(sk, ctx, 1); rc = tls_set_sw_offload(sk, ctx, 1);
if (ctx->conf == TLS_SW_RX) conf = TLS_SW;
conf = TLS_SW_RXTX;
else
conf = TLS_SW_TX;
} else { } else {
rc = tls_set_sw_offload(sk, ctx, 0); rc = tls_set_sw_offload(sk, ctx, 0);
if (ctx->conf == TLS_SW_TX) conf = TLS_SW;
conf = TLS_SW_RXTX;
else
conf = TLS_SW_RX;
} }
if (rc) if (rc)
goto err_crypto_info; goto err_crypto_info;
ctx->conf = conf; if (tx)
ctx->tx_conf = conf;
else
ctx->rx_conf = conf;
update_sk_prot(sk, ctx); update_sk_prot(sk, ctx);
if (tx) { if (tx) {
ctx->sk_write_space = sk->sk_write_space; ctx->sk_write_space = sk->sk_write_space;
...@@ -535,7 +531,8 @@ static int tls_hw_prot(struct sock *sk) ...@@ -535,7 +531,8 @@ static int tls_hw_prot(struct sock *sk)
ctx->hash = sk->sk_prot->hash; ctx->hash = sk->sk_prot->hash;
ctx->unhash = sk->sk_prot->unhash; ctx->unhash = sk->sk_prot->unhash;
ctx->sk_proto_close = sk->sk_prot->close; ctx->sk_proto_close = sk->sk_prot->close;
ctx->conf = TLS_HW_RECORD; ctx->rx_conf = TLS_HW_RECORD;
ctx->tx_conf = TLS_HW_RECORD;
update_sk_prot(sk, ctx); update_sk_prot(sk, ctx);
rc = 1; rc = 1;
break; break;
...@@ -579,29 +576,30 @@ static int tls_hw_hash(struct sock *sk) ...@@ -579,29 +576,30 @@ static int tls_hw_hash(struct sock *sk)
return err; return err;
} }
static void build_protos(struct proto *prot, struct proto *base) static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
struct proto *base)
{ {
prot[TLS_BASE] = *base; prot[TLS_BASE][TLS_BASE] = *base;
prot[TLS_BASE].setsockopt = tls_setsockopt; prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE].getsockopt = tls_getsockopt; prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
prot[TLS_BASE].close = tls_sk_proto_close; prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
prot[TLS_SW_TX] = prot[TLS_BASE]; prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
prot[TLS_SW_TX].sendpage = tls_sw_sendpage; prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;
prot[TLS_SW_RX] = prot[TLS_BASE]; prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg; prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
prot[TLS_SW_RX].close = tls_sk_proto_close; prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg; prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
prot[TLS_SW_RXTX].close = tls_sk_proto_close; prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
prot[TLS_HW_RECORD] = *base; prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
prot[TLS_HW_RECORD].hash = tls_hw_hash; prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
prot[TLS_HW_RECORD].unhash = tls_hw_unhash; prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
prot[TLS_HW_RECORD].close = tls_sk_proto_close; prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
} }
static int tls_init(struct sock *sk) static int tls_init(struct sock *sk)
...@@ -643,7 +641,8 @@ static int tls_init(struct sock *sk) ...@@ -643,7 +641,8 @@ static int tls_init(struct sock *sk)
mutex_unlock(&tcpv6_prot_mutex); mutex_unlock(&tcpv6_prot_mutex);
} }
ctx->conf = TLS_BASE; ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
update_sk_prot(sk, ctx); update_sk_prot(sk, ctx);
out: out:
return rc; return rc;
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment