Commit f747632b authored by Lorenz Bauer's avatar Lorenz Bauer Committed by Daniel Borkmann

bpf: sockmap: Move generic sockmap hooks from BPF TCP

The init, close and unhash handlers from TCP sockmap are generic,
and can be reused by UDP sockmap. Move the helpers into the sockmap code
base and expose them. This requires tcp_bpf_get_proto and tcp_bpf_clone to
be conditional on BPF_STREAM_PARSER.

The moved functions are unmodified, except that sk_psock_unlink is
renamed to sock_map_unlink to better match its behaviour.
Signed-off-by: default avatarLorenz Bauer <lmb@cloudflare.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Reviewed-by: default avatarJakub Sitnicki <jakub@cloudflare.com>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200309111243.6982-6-lmb@cloudflare.com
parent 5da00404
...@@ -1419,6 +1419,8 @@ static inline void bpf_map_offload_map_free(struct bpf_map *map) ...@@ -1419,6 +1419,8 @@ static inline void bpf_map_offload_map_free(struct bpf_map *map)
#if defined(CONFIG_BPF_STREAM_PARSER) #if defined(CONFIG_BPF_STREAM_PARSER)
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, u32 which); int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, u32 which);
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog); int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
void sock_map_unhash(struct sock *sk);
void sock_map_close(struct sock *sk, long timeout);
#else #else
static inline int sock_map_prog_update(struct bpf_map *map, static inline int sock_map_prog_update(struct bpf_map *map,
struct bpf_prog *prog, u32 which) struct bpf_prog *prog, u32 which)
...@@ -1431,7 +1433,7 @@ static inline int sock_map_get_from_fd(const union bpf_attr *attr, ...@@ -1431,7 +1433,7 @@ static inline int sock_map_get_from_fd(const union bpf_attr *attr,
{ {
return -EINVAL; return -EINVAL;
} }
#endif #endif /* CONFIG_BPF_STREAM_PARSER */
#if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL) #if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL)
void bpf_sk_reuseport_detach(struct sock *sk); void bpf_sk_reuseport_detach(struct sock *sk);
......
...@@ -323,14 +323,6 @@ static inline void sk_psock_free_link(struct sk_psock_link *link) ...@@ -323,14 +323,6 @@ static inline void sk_psock_free_link(struct sk_psock_link *link)
} }
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock); struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
#if defined(CONFIG_BPF_STREAM_PARSER)
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
#else
static inline void sk_psock_unlink(struct sock *sk,
struct sk_psock_link *link)
{
}
#endif
void __sk_psock_purge_ingress_msg(struct sk_psock *psock); void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
...@@ -399,26 +391,6 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock, ...@@ -399,26 +391,6 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock,
return test_bit(bit, &psock->state); return test_bit(bit, &psock->state);
} }
static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
{
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (psock) {
if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
psock = ERR_PTR(-EBUSY);
goto out;
}
if (!refcount_inc_not_zero(&psock->refcnt))
psock = ERR_PTR(-EBUSY);
}
out:
rcu_read_unlock();
return psock;
}
static inline struct sk_psock *sk_psock_get(struct sock *sk) static inline struct sk_psock *sk_psock_get(struct sock *sk)
{ {
struct sk_psock *psock; struct sk_psock *psock;
......
...@@ -2195,19 +2195,22 @@ void tcp_update_ulp(struct sock *sk, struct proto *p, ...@@ -2195,19 +2195,22 @@ void tcp_update_ulp(struct sock *sk, struct proto *p,
struct sk_msg; struct sk_msg;
struct sk_psock; struct sk_psock;
#ifdef CONFIG_BPF_STREAM_PARSER
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#else
static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
}
#endif /* CONFIG_BPF_STREAM_PARSER */
#ifdef CONFIG_NET_SOCK_MSG #ifdef CONFIG_NET_SOCK_MSG
int tcp_bpf_init(struct sock *sk);
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes, int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
int flags); int flags);
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len); int nonblock, int flags, int *addr_len);
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len, int flags); struct msghdr *msg, int len, int flags);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#else
static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
}
#endif /* CONFIG_NET_SOCK_MSG */ #endif /* CONFIG_NET_SOCK_MSG */
/* Call BPF_SOCK_OPS program that returns an int. If the return value /* Call BPF_SOCK_OPS program that returns an int. If the return value
......
...@@ -141,6 +141,51 @@ static void sock_map_unref(struct sock *sk, void *link_raw) ...@@ -141,6 +141,51 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
} }
} }
static int sock_map_init_proto(struct sock *sk)
{
struct sk_psock *psock;
struct proto *prot;
sock_owned_by_me(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
return -EINVAL;
}
prot = tcp_bpf_get_proto(sk, psock);
if (IS_ERR(prot)) {
rcu_read_unlock();
return PTR_ERR(prot);
}
sk_psock_update_proto(sk, psock, prot);
rcu_read_unlock();
return 0;
}
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (psock) {
if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
psock = ERR_PTR(-EBUSY);
goto out;
}
if (!refcount_inc_not_zero(&psock->refcnt))
psock = ERR_PTR(-EBUSY);
}
out:
rcu_read_unlock();
return psock;
}
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
struct sock *sk) struct sock *sk)
{ {
...@@ -172,7 +217,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, ...@@ -172,7 +217,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
} }
} }
psock = sk_psock_get_checked(sk); psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) { if (IS_ERR(psock)) {
ret = PTR_ERR(psock); ret = PTR_ERR(psock);
goto out_progs; goto out_progs;
...@@ -196,7 +241,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, ...@@ -196,7 +241,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
if (msg_parser) if (msg_parser)
psock_set_prog(&psock->progs.msg_parser, msg_parser); psock_set_prog(&psock->progs.msg_parser, msg_parser);
ret = tcp_bpf_init(sk); ret = sock_map_init_proto(sk);
if (ret < 0) if (ret < 0)
goto out_drop; goto out_drop;
...@@ -231,7 +276,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk) ...@@ -231,7 +276,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
struct sk_psock *psock; struct sk_psock *psock;
int ret; int ret;
psock = sk_psock_get_checked(sk); psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) if (IS_ERR(psock))
return PTR_ERR(psock); return PTR_ERR(psock);
...@@ -241,7 +286,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk) ...@@ -241,7 +286,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
return -ENOMEM; return -ENOMEM;
} }
ret = tcp_bpf_init(sk); ret = sock_map_init_proto(sk);
if (ret < 0) if (ret < 0)
sk_psock_put(sk, psock); sk_psock_put(sk, psock);
return ret; return ret;
...@@ -1120,7 +1165,7 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, ...@@ -1120,7 +1165,7 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
return 0; return 0;
} }
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link) static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{ {
switch (link->map->map_type) { switch (link->map->map_type) {
case BPF_MAP_TYPE_SOCKMAP: case BPF_MAP_TYPE_SOCKMAP:
...@@ -1133,3 +1178,54 @@ void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link) ...@@ -1133,3 +1178,54 @@ void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
break; break;
} }
} }
static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
struct sk_psock_link *link;
while ((link = sk_psock_link_pop(psock))) {
sock_map_unlink(sk, link);
sk_psock_free_link(link);
}
}
void sock_map_unhash(struct sock *sk)
{
void (*saved_unhash)(struct sock *sk);
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
if (sk->sk_prot->unhash)
sk->sk_prot->unhash(sk);
return;
}
saved_unhash = psock->saved_unhash;
sock_map_remove_links(sk, psock);
rcu_read_unlock();
saved_unhash(sk);
}
void sock_map_close(struct sock *sk, long timeout)
{
void (*saved_close)(struct sock *sk, long timeout);
struct sk_psock *psock;
lock_sock(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
release_sock(sk);
return sk->sk_prot->close(sk, timeout);
}
saved_close = psock->saved_close;
sock_map_remove_links(sk, psock);
rcu_read_unlock();
release_sock(sk);
saved_close(sk, timeout);
}
...@@ -528,57 +528,7 @@ static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset, ...@@ -528,57 +528,7 @@ static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
return copied ? copied : err; return copied ? copied : err;
} }
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock) #ifdef CONFIG_BPF_STREAM_PARSER
{
struct sk_psock_link *link;
while ((link = sk_psock_link_pop(psock))) {
sk_psock_unlink(sk, link);
sk_psock_free_link(link);
}
}
static void tcp_bpf_unhash(struct sock *sk)
{
void (*saved_unhash)(struct sock *sk);
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
if (sk->sk_prot->unhash)
sk->sk_prot->unhash(sk);
return;
}
saved_unhash = psock->saved_unhash;
tcp_bpf_remove(sk, psock);
rcu_read_unlock();
saved_unhash(sk);
}
static void tcp_bpf_close(struct sock *sk, long timeout)
{
void (*saved_close)(struct sock *sk, long timeout);
struct sk_psock *psock;
lock_sock(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
release_sock(sk);
return sk->sk_prot->close(sk, timeout);
}
saved_close = psock->saved_close;
tcp_bpf_remove(sk, psock);
rcu_read_unlock();
release_sock(sk);
saved_close(sk, timeout);
}
enum { enum {
TCP_BPF_IPV4, TCP_BPF_IPV4,
TCP_BPF_IPV6, TCP_BPF_IPV6,
...@@ -599,8 +549,8 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], ...@@ -599,8 +549,8 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
struct proto *base) struct proto *base)
{ {
prot[TCP_BPF_BASE] = *base; prot[TCP_BPF_BASE] = *base;
prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash; prot[TCP_BPF_BASE].unhash = sock_map_unhash;
prot[TCP_BPF_BASE].close = tcp_bpf_close; prot[TCP_BPF_BASE].close = sock_map_close;
prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read; prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;
...@@ -640,7 +590,7 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops) ...@@ -640,7 +590,7 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
} }
static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{ {
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
...@@ -657,31 +607,6 @@ static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock) ...@@ -657,31 +607,6 @@ static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
return &tcp_bpf_prots[family][config]; return &tcp_bpf_prots[family][config];
} }
int tcp_bpf_init(struct sock *sk)
{
struct sk_psock *psock;
struct proto *prot;
sock_owned_by_me(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
return -EINVAL;
}
prot = tcp_bpf_get_proto(sk, psock);
if (IS_ERR(prot)) {
rcu_read_unlock();
return PTR_ERR(prot);
}
sk_psock_update_proto(sk, psock, prot);
rcu_read_unlock();
return 0;
}
/* If a child got cloned from a listening socket that had tcp_bpf /* If a child got cloned from a listening socket that had tcp_bpf
* protocol callbacks installed, we need to restore the callbacks to * protocol callbacks installed, we need to restore the callbacks to
* the default ones because the child does not inherit the psock state * the default ones because the child does not inherit the psock state
...@@ -695,3 +620,4 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk) ...@@ -695,3 +620,4 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE]) if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
newsk->sk_prot = sk->sk_prot_creator; newsk->sk_prot = sk->sk_prot_creator;
} }
#endif /* CONFIG_BPF_STREAM_PARSER */
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment