Commit d19da360 authored by Lorenz Bauer, committed by Daniel Borkmann

bpf: tcp: Move assertions into tcp_bpf_get_proto

We need to ensure that sk->sk_prot uses certain callbacks, so that
code that directly calls e.g. tcp_sendmsg in certain corner cases
works. To avoid spurious asserts, we must do this only if
sk_psock_update_proto has not yet been called. The same invariants
apply for tcp_bpf_check_v6_needs_rebuild, so move the call as well.

Doing so allows us to merge tcp_bpf_init and tcp_bpf_reinit.
Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200309111243.6982-4-lmb@cloudflare.com
parent 1a2e2013
...@@ -2196,7 +2196,6 @@ struct sk_msg; ...@@ -2196,7 +2196,6 @@ struct sk_msg;
struct sk_psock; struct sk_psock;
int tcp_bpf_init(struct sock *sk); int tcp_bpf_init(struct sock *sk);
void tcp_bpf_reinit(struct sock *sk);
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes, int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
int flags); int flags);
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
......
...@@ -145,8 +145,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, ...@@ -145,8 +145,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
struct sock *sk) struct sock *sk)
{ {
struct bpf_prog *msg_parser, *skb_parser, *skb_verdict; struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
bool skb_progs, sk_psock_is_new = false;
struct sk_psock *psock; struct sk_psock *psock;
bool skb_progs;
int ret; int ret;
skb_verdict = READ_ONCE(progs->skb_verdict); skb_verdict = READ_ONCE(progs->skb_verdict);
...@@ -191,18 +191,14 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, ...@@ -191,18 +191,14 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
ret = -ENOMEM; ret = -ENOMEM;
goto out_progs; goto out_progs;
} }
sk_psock_is_new = true;
} }
if (msg_parser) if (msg_parser)
psock_set_prog(&psock->progs.msg_parser, msg_parser); psock_set_prog(&psock->progs.msg_parser, msg_parser);
if (sk_psock_is_new) {
ret = tcp_bpf_init(sk); ret = tcp_bpf_init(sk);
if (ret < 0) if (ret < 0)
goto out_drop; goto out_drop;
} else {
tcp_bpf_reinit(sk);
}
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
if (skb_progs && !psock->parser.enabled) { if (skb_progs && !psock->parser.enabled) {
...@@ -239,15 +235,12 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk) ...@@ -239,15 +235,12 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
if (IS_ERR(psock)) if (IS_ERR(psock))
return PTR_ERR(psock); return PTR_ERR(psock);
if (psock) { if (!psock) {
tcp_bpf_reinit(sk); psock = sk_psock_init(sk, map->numa_node);
return 0; if (!psock)
return -ENOMEM;
} }
psock = sk_psock_init(sk, map->numa_node);
if (!psock)
return -ENOMEM;
ret = tcp_bpf_init(sk); ret = tcp_bpf_init(sk);
if (ret < 0) if (ret < 0)
sk_psock_put(sk, psock); sk_psock_put(sk, psock);
......
...@@ -629,14 +629,6 @@ static int __init tcp_bpf_v4_build_proto(void) ...@@ -629,14 +629,6 @@ static int __init tcp_bpf_v4_build_proto(void)
} }
core_initcall(tcp_bpf_v4_build_proto); core_initcall(tcp_bpf_v4_build_proto);
static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}
static int tcp_bpf_assert_proto_ops(struct proto *ops) static int tcp_bpf_assert_proto_ops(struct proto *ops)
{ {
/* In order to avoid retpoline, we make assumptions when we call /* In order to avoid retpoline, we make assumptions when we call
...@@ -648,34 +640,44 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops) ...@@ -648,34 +640,44 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
} }
void tcp_bpf_reinit(struct sock *sk) static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{ {
struct sk_psock *psock; int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
sock_owned_by_me(sk); if (!psock->sk_proto) {
struct proto *ops = READ_ONCE(sk->sk_prot);
rcu_read_lock(); if (tcp_bpf_assert_proto_ops(ops))
psock = sk_psock(sk); return ERR_PTR(-EINVAL);
tcp_bpf_update_sk_prot(sk, psock);
rcu_read_unlock(); tcp_bpf_check_v6_needs_rebuild(sk, ops);
}
return &tcp_bpf_prots[family][config];
} }
int tcp_bpf_init(struct sock *sk) int tcp_bpf_init(struct sock *sk)
{ {
struct proto *ops = READ_ONCE(sk->sk_prot);
struct sk_psock *psock; struct sk_psock *psock;
struct proto *prot;
sock_owned_by_me(sk); sock_owned_by_me(sk);
rcu_read_lock(); rcu_read_lock();
psock = sk_psock(sk); psock = sk_psock(sk);
if (unlikely(!psock || psock->sk_proto || if (unlikely(!psock)) {
tcp_bpf_assert_proto_ops(ops))) {
rcu_read_unlock(); rcu_read_unlock();
return -EINVAL; return -EINVAL;
} }
tcp_bpf_check_v6_needs_rebuild(sk, ops);
tcp_bpf_update_sk_prot(sk, psock); prot = tcp_bpf_get_proto(sk, psock);
if (IS_ERR(prot)) {
rcu_read_unlock();
return PTR_ERR(prot);
}
sk_psock_update_proto(sk, psock, prot);
rcu_read_unlock(); rcu_read_unlock();
return 0; return 0;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment