Commit 7b219da4 authored by Lorenz Bauer's avatar Lorenz Bauer Committed by Alexei Starovoitov

net: sk_msg: Simplify sk_psock initialization

Initializing psock->sk_proto and other saved callbacks is only
done in sk_psock_update_proto, after sk_psock_init has returned.
The logic for this is difficult to follow, and needlessly complex.

Instead, initialize psock->sk_proto whenever we allocate a new
psock. Additionally, assert the following invariants:

* The SK has no ULP: ULP does it's own finagling of sk->sk_prot
* sk_user_data is unused: we need it to store sk_psock

Protect our access to sk_user_data with sk_callback_lock, which
is what other users like reuseport arrays, etc. do.

The result is that an sk_psock is always fully initialized, and
that psock->sk_proto is always the "original" struct proto.
The latter allows us to use psock->sk_proto when initializing
IPv6 TCP / UDP callbacks for sockmap.
Signed-off-by: default avatarLorenz Bauer <lmb@cloudflare.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200821102948.21918-2-lmb@cloudflare.com
parent dca5612f
...@@ -340,23 +340,6 @@ static inline void sk_psock_update_proto(struct sock *sk, ...@@ -340,23 +340,6 @@ static inline void sk_psock_update_proto(struct sock *sk,
struct sk_psock *psock, struct sk_psock *psock,
struct proto *ops) struct proto *ops)
{ {
/* Initialize saved callbacks and original proto only once, since this
* function may be called multiple times for a psock, e.g. when
* psock->progs.msg_parser is updated.
*
* Since we've not installed the new proto, psock is not yet in use and
* we can initialize it without synchronization.
*/
if (!psock->sk_proto) {
struct proto *orig = READ_ONCE(sk->sk_prot);
psock->saved_unhash = orig->unhash;
psock->saved_close = orig->close;
psock->saved_write_space = sk->sk_write_space;
psock->sk_proto = orig;
}
/* Pairs with lockless read in sk_clone_lock() */ /* Pairs with lockless read in sk_clone_lock() */
WRITE_ONCE(sk->sk_prot, ops); WRITE_ONCE(sk->sk_prot, ops);
} }
......
...@@ -494,14 +494,34 @@ static void sk_psock_backlog(struct work_struct *work) ...@@ -494,14 +494,34 @@ static void sk_psock_backlog(struct work_struct *work)
struct sk_psock *sk_psock_init(struct sock *sk, int node) struct sk_psock *sk_psock_init(struct sock *sk, int node)
{ {
struct sk_psock *psock = kzalloc_node(sizeof(*psock), struct sk_psock *psock;
GFP_ATOMIC | __GFP_NOWARN, struct proto *prot;
node);
if (!psock) write_lock_bh(&sk->sk_callback_lock);
return NULL;
if (inet_csk_has_ulp(sk)) {
psock = ERR_PTR(-EINVAL);
goto out;
}
if (sk->sk_user_data) {
psock = ERR_PTR(-EBUSY);
goto out;
}
psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
if (!psock) {
psock = ERR_PTR(-ENOMEM);
goto out;
}
prot = READ_ONCE(sk->sk_prot);
psock->sk = sk; psock->sk = sk;
psock->eval = __SK_NONE; psock->eval = __SK_NONE;
psock->sk_proto = prot;
psock->saved_unhash = prot->unhash;
psock->saved_close = prot->close;
psock->saved_write_space = sk->sk_write_space;
INIT_LIST_HEAD(&psock->link); INIT_LIST_HEAD(&psock->link);
spin_lock_init(&psock->link_lock); spin_lock_init(&psock->link_lock);
...@@ -516,6 +536,8 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node) ...@@ -516,6 +536,8 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
rcu_assign_sk_user_data_nocopy(sk, psock); rcu_assign_sk_user_data_nocopy(sk, psock);
sock_hold(sk); sock_hold(sk);
out:
write_unlock_bh(&sk->sk_callback_lock);
return psock; return psock;
} }
EXPORT_SYMBOL_GPL(sk_psock_init); EXPORT_SYMBOL_GPL(sk_psock_init);
......
...@@ -184,8 +184,6 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock) ...@@ -184,8 +184,6 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{ {
struct proto *prot; struct proto *prot;
sock_owned_by_me(sk);
switch (sk->sk_type) { switch (sk->sk_type) {
case SOCK_STREAM: case SOCK_STREAM:
prot = tcp_bpf_get_proto(sk, psock); prot = tcp_bpf_get_proto(sk, psock);
...@@ -272,8 +270,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, ...@@ -272,8 +270,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
} }
} else { } else {
psock = sk_psock_init(sk, map->numa_node); psock = sk_psock_init(sk, map->numa_node);
if (!psock) { if (IS_ERR(psock)) {
ret = -ENOMEM; ret = PTR_ERR(psock);
goto out_progs; goto out_progs;
} }
} }
...@@ -322,8 +320,8 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk) ...@@ -322,8 +320,8 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
if (!psock) { if (!psock) {
psock = sk_psock_init(sk, map->numa_node); psock = sk_psock_init(sk, map->numa_node);
if (!psock) if (IS_ERR(psock))
return -ENOMEM; return PTR_ERR(psock);
} }
ret = sock_map_init_proto(sk, psock); ret = sock_map_init_proto(sk, psock);
...@@ -478,8 +476,6 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx, ...@@ -478,8 +476,6 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
return -EINVAL; return -EINVAL;
if (unlikely(idx >= map->max_entries)) if (unlikely(idx >= map->max_entries))
return -E2BIG; return -E2BIG;
if (inet_csk_has_ulp(sk))
return -EINVAL;
link = sk_psock_init_link(); link = sk_psock_init_link();
if (!link) if (!link)
...@@ -855,8 +851,6 @@ static int sock_hash_update_common(struct bpf_map *map, void *key, ...@@ -855,8 +851,6 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
WARN_ON_ONCE(!rcu_read_lock_held()); WARN_ON_ONCE(!rcu_read_lock_held());
if (unlikely(flags > BPF_EXIST)) if (unlikely(flags > BPF_EXIST))
return -EINVAL; return -EINVAL;
if (inet_csk_has_ulp(sk))
return -EINVAL;
link = sk_psock_init_link(); link = sk_psock_init_link();
if (!link) if (!link)
......
...@@ -567,10 +567,9 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], ...@@ -567,10 +567,9 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage; prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
} }
static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops) static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{ {
if (sk->sk_family == AF_INET6 && if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
spin_lock_bh(&tcpv6_prot_lock); spin_lock_bh(&tcpv6_prot_lock);
if (likely(ops != tcpv6_prot_saved)) { if (likely(ops != tcpv6_prot_saved)) {
tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops); tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
...@@ -603,13 +602,11 @@ struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) ...@@ -603,13 +602,11 @@ struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
if (!psock->sk_proto) { if (sk->sk_family == AF_INET6) {
struct proto *ops = READ_ONCE(sk->sk_prot); if (tcp_bpf_assert_proto_ops(psock->sk_proto))
if (tcp_bpf_assert_proto_ops(ops))
return ERR_PTR(-EINVAL); return ERR_PTR(-EINVAL);
tcp_bpf_check_v6_needs_rebuild(sk, ops); tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
} }
return &tcp_bpf_prots[family][config]; return &tcp_bpf_prots[family][config];
......
...@@ -22,10 +22,9 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) ...@@ -22,10 +22,9 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
prot->close = sock_map_close; prot->close = sock_map_close;
} }
static void udp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops) static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
{ {
if (sk->sk_family == AF_INET6 && if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
spin_lock_bh(&udpv6_prot_lock); spin_lock_bh(&udpv6_prot_lock);
if (likely(ops != udpv6_prot_saved)) { if (likely(ops != udpv6_prot_saved)) {
udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
...@@ -46,8 +45,8 @@ struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) ...@@ -46,8 +45,8 @@ struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{ {
int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
if (!psock->sk_proto) if (sk->sk_family == AF_INET6)
udp_bpf_check_v6_needs_rebuild(sk, READ_ONCE(sk->sk_prot)); udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
return &udp_bpf_prots[family]; return &udp_bpf_prots[family];
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment