Commit 15a7dea7 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller

net/tls: use RCU protection on icsk->icsk_ulp_data

We need to make sure context does not get freed while diag
code is interrogating it. Free struct tls_context with
kfree_rcu().

We add the __rcu annotation directly in icsk, and cast it
away in the datapath accessor. Presumably all ULPs will
do a similar thing.
Signed-off-by: default avatarJakub Kicinski <jakub.kicinski@netronome.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ed6e8103
...@@ -97,7 +97,7 @@ struct inet_connection_sock { ...@@ -97,7 +97,7 @@ struct inet_connection_sock {
const struct tcp_congestion_ops *icsk_ca_ops; const struct tcp_congestion_ops *icsk_ca_ops;
const struct inet_connection_sock_af_ops *icsk_af_ops; const struct inet_connection_sock_af_ops *icsk_af_ops;
const struct tcp_ulp_ops *icsk_ulp_ops; const struct tcp_ulp_ops *icsk_ulp_ops;
void *icsk_ulp_data; void __rcu *icsk_ulp_data;
void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq); void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq);
struct hlist_node icsk_listen_portaddr_node; struct hlist_node icsk_listen_portaddr_node;
unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu); unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu);
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include <linux/tcp.h> #include <linux/tcp.h>
#include <linux/skmsg.h> #include <linux/skmsg.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <linux/rcupdate.h>
#include <net/tcp.h> #include <net/tcp.h>
#include <net/strparser.h> #include <net/strparser.h>
...@@ -290,6 +291,7 @@ struct tls_context { ...@@ -290,6 +291,7 @@ struct tls_context {
struct list_head list; struct list_head list;
refcount_t refcount; refcount_t refcount;
struct rcu_head rcu;
}; };
enum tls_offload_ctx_dir { enum tls_offload_ctx_dir {
...@@ -348,7 +350,7 @@ struct tls_offload_context_rx { ...@@ -348,7 +350,7 @@ struct tls_offload_context_rx {
#define TLS_OFFLOAD_CONTEXT_SIZE_RX \ #define TLS_OFFLOAD_CONTEXT_SIZE_RX \
(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX) (sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)
void tls_ctx_free(struct tls_context *ctx); void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo); int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval, int tls_sk_query(struct sock *sk, int optname, char __user *optval,
int __user *optlen); int __user *optlen);
...@@ -467,7 +469,10 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk) ...@@ -467,7 +469,10 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{ {
struct inet_connection_sock *icsk = inet_csk(sk); struct inet_connection_sock *icsk = inet_csk(sk);
return icsk->icsk_ulp_data; /* Use RCU on icsk_ulp_data only for sock diag code,
* TLS data path doesn't need rcu_dereference().
*/
return (__force void *)icsk->icsk_ulp_data;
} }
static inline void tls_advance_record_sn(struct sock *sk, static inline void tls_advance_record_sn(struct sock *sk,
......
...@@ -345,7 +345,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx, ...@@ -345,7 +345,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
return -EINVAL; return -EINVAL;
if (unlikely(idx >= map->max_entries)) if (unlikely(idx >= map->max_entries))
return -E2BIG; return -E2BIG;
if (unlikely(icsk->icsk_ulp_data)) if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
return -EINVAL; return -EINVAL;
link = sk_psock_init_link(); link = sk_psock_init_link();
......
...@@ -61,7 +61,7 @@ static void tls_device_free_ctx(struct tls_context *ctx) ...@@ -61,7 +61,7 @@ static void tls_device_free_ctx(struct tls_context *ctx)
if (ctx->rx_conf == TLS_HW) if (ctx->rx_conf == TLS_HW)
kfree(tls_offload_ctx_rx(ctx)); kfree(tls_offload_ctx_rx(ctx));
tls_ctx_free(ctx); tls_ctx_free(NULL, ctx);
} }
static void tls_device_gc_task(struct work_struct *work) static void tls_device_gc_task(struct work_struct *work)
......
...@@ -251,14 +251,26 @@ static void tls_write_space(struct sock *sk) ...@@ -251,14 +251,26 @@ static void tls_write_space(struct sock *sk)
ctx->sk_write_space(sk); ctx->sk_write_space(sk);
} }
void tls_ctx_free(struct tls_context *ctx) /**
* tls_ctx_free() - free TLS ULP context
* @sk: socket to with @ctx is attached
* @ctx: TLS context structure
*
* Free TLS context. If @sk is %NULL caller guarantees that the socket
* to which @ctx was attached has no outstanding references.
*/
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{ {
if (!ctx) if (!ctx)
return; return;
memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
kfree(ctx);
if (sk)
kfree_rcu(ctx, rcu);
else
kfree(ctx);
} }
static void tls_sk_proto_cleanup(struct sock *sk, static void tls_sk_proto_cleanup(struct sock *sk,
...@@ -306,7 +318,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -306,7 +318,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
if (free_ctx) if (free_ctx)
icsk->icsk_ulp_data = NULL; rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
sk->sk_prot = ctx->sk_proto; sk->sk_prot = ctx->sk_proto;
if (sk->sk_write_space == tls_write_space) if (sk->sk_write_space == tls_write_space)
sk->sk_write_space = ctx->sk_write_space; sk->sk_write_space = ctx->sk_write_space;
...@@ -321,7 +333,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -321,7 +333,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
ctx->sk_proto_close(sk, timeout); ctx->sk_proto_close(sk, timeout);
if (free_ctx) if (free_ctx)
tls_ctx_free(ctx); tls_ctx_free(sk, ctx);
} }
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
...@@ -610,7 +622,7 @@ static struct tls_context *create_ctx(struct sock *sk) ...@@ -610,7 +622,7 @@ static struct tls_context *create_ctx(struct sock *sk)
if (!ctx) if (!ctx)
return NULL; return NULL;
icsk->icsk_ulp_data = ctx; rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->setsockopt = sk->sk_prot->setsockopt; ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt; ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close; ctx->sk_proto_close = sk->sk_prot->close;
...@@ -651,8 +663,8 @@ static void tls_hw_sk_destruct(struct sock *sk) ...@@ -651,8 +663,8 @@ static void tls_hw_sk_destruct(struct sock *sk)
ctx->sk_destruct(sk); ctx->sk_destruct(sk);
/* Free ctx */ /* Free ctx */
tls_ctx_free(ctx); rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
icsk->icsk_ulp_data = NULL; tls_ctx_free(sk, ctx);
} }
static int tls_hw_prot(struct sock *sk) static int tls_hw_prot(struct sock *sk)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment