Commit f747632b authored by Lorenz Bauer's avatar Lorenz Bauer Committed by Daniel Borkmann

bpf: sockmap: Move generic sockmap hooks from BPF TCP

The init, close and unhash handlers from TCP sockmap are generic,
and can be reused by UDP sockmap. Move the helpers into the sockmap code
base and expose them. This requires tcp_bpf_get_proto and tcp_bpf_clone to
be conditional on BPF_STREAM_PARSER.

The moved functions are unmodified, except that sk_psock_unlink is
renamed to sock_map_unlink to better match its behaviour.
Signed-off-by: default avatarLorenz Bauer <lmb@cloudflare.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Reviewed-by: default avatarJakub Sitnicki <jakub@cloudflare.com>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200309111243.6982-6-lmb@cloudflare.com
parent 5da00404
......@@ -1419,6 +1419,8 @@ static inline void bpf_map_offload_map_free(struct bpf_map *map)
#if defined(CONFIG_BPF_STREAM_PARSER)
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, u32 which);
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
void sock_map_unhash(struct sock *sk);
void sock_map_close(struct sock *sk, long timeout);
#else
static inline int sock_map_prog_update(struct bpf_map *map,
struct bpf_prog *prog, u32 which)
......@@ -1431,7 +1433,7 @@ static inline int sock_map_get_from_fd(const union bpf_attr *attr,
{
return -EINVAL;
}
#endif
#endif /* CONFIG_BPF_STREAM_PARSER */
#if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL)
void bpf_sk_reuseport_detach(struct sock *sk);
......
......@@ -323,14 +323,6 @@ static inline void sk_psock_free_link(struct sk_psock_link *link)
}
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
#if defined(CONFIG_BPF_STREAM_PARSER)
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
#else
static inline void sk_psock_unlink(struct sock *sk,
struct sk_psock_link *link)
{
}
#endif
void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
......@@ -399,26 +391,6 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock,
return test_bit(bit, &psock->state);
}
static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
{
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (psock) {
if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
psock = ERR_PTR(-EBUSY);
goto out;
}
if (!refcount_inc_not_zero(&psock->refcnt))
psock = ERR_PTR(-EBUSY);
}
out:
rcu_read_unlock();
return psock;
}
static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
struct sk_psock *psock;
......
......@@ -2195,19 +2195,22 @@ void tcp_update_ulp(struct sock *sk, struct proto *p,
struct sk_msg;
struct sk_psock;
#ifdef CONFIG_BPF_STREAM_PARSER
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#else
static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
}
#endif /* CONFIG_BPF_STREAM_PARSER */
#ifdef CONFIG_NET_SOCK_MSG
int tcp_bpf_init(struct sock *sk);
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
int flags);
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len);
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len, int flags);
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
#else
static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
}
#endif /* CONFIG_NET_SOCK_MSG */
/* Call BPF_SOCK_OPS program that returns an int. If the return value
......
......@@ -141,6 +141,51 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
}
}
static int sock_map_init_proto(struct sock *sk)
{
struct sk_psock *psock;
struct proto *prot;
sock_owned_by_me(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
return -EINVAL;
}
prot = tcp_bpf_get_proto(sk, psock);
if (IS_ERR(prot)) {
rcu_read_unlock();
return PTR_ERR(prot);
}
sk_psock_update_proto(sk, psock, prot);
rcu_read_unlock();
return 0;
}
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (psock) {
if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
psock = ERR_PTR(-EBUSY);
goto out;
}
if (!refcount_inc_not_zero(&psock->refcnt))
psock = ERR_PTR(-EBUSY);
}
out:
rcu_read_unlock();
return psock;
}
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
struct sock *sk)
{
......@@ -172,7 +217,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
}
}
psock = sk_psock_get_checked(sk);
psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) {
ret = PTR_ERR(psock);
goto out_progs;
......@@ -196,7 +241,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
if (msg_parser)
psock_set_prog(&psock->progs.msg_parser, msg_parser);
ret = tcp_bpf_init(sk);
ret = sock_map_init_proto(sk);
if (ret < 0)
goto out_drop;
......@@ -231,7 +276,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
struct sk_psock *psock;
int ret;
psock = sk_psock_get_checked(sk);
psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock))
return PTR_ERR(psock);
......@@ -241,7 +286,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
return -ENOMEM;
}
ret = tcp_bpf_init(sk);
ret = sock_map_init_proto(sk);
if (ret < 0)
sk_psock_put(sk, psock);
return ret;
......@@ -1120,7 +1165,7 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
return 0;
}
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
switch (link->map->map_type) {
case BPF_MAP_TYPE_SOCKMAP:
......@@ -1133,3 +1178,54 @@ void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
break;
}
}
static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
struct sk_psock_link *link;
while ((link = sk_psock_link_pop(psock))) {
sock_map_unlink(sk, link);
sk_psock_free_link(link);
}
}
void sock_map_unhash(struct sock *sk)
{
void (*saved_unhash)(struct sock *sk);
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
if (sk->sk_prot->unhash)
sk->sk_prot->unhash(sk);
return;
}
saved_unhash = psock->saved_unhash;
sock_map_remove_links(sk, psock);
rcu_read_unlock();
saved_unhash(sk);
}
void sock_map_close(struct sock *sk, long timeout)
{
void (*saved_close)(struct sock *sk, long timeout);
struct sk_psock *psock;
lock_sock(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
release_sock(sk);
return sk->sk_prot->close(sk, timeout);
}
saved_close = psock->saved_close;
sock_map_remove_links(sk, psock);
rcu_read_unlock();
release_sock(sk);
saved_close(sk, timeout);
}
......@@ -528,57 +528,7 @@ static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
return copied ? copied : err;
}
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
struct sk_psock_link *link;
while ((link = sk_psock_link_pop(psock))) {
sk_psock_unlink(sk, link);
sk_psock_free_link(link);
}
}
static void tcp_bpf_unhash(struct sock *sk)
{
void (*saved_unhash)(struct sock *sk);
struct sk_psock *psock;
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
if (sk->sk_prot->unhash)
sk->sk_prot->unhash(sk);
return;
}
saved_unhash = psock->saved_unhash;
tcp_bpf_remove(sk, psock);
rcu_read_unlock();
saved_unhash(sk);
}
static void tcp_bpf_close(struct sock *sk, long timeout)
{
void (*saved_close)(struct sock *sk, long timeout);
struct sk_psock *psock;
lock_sock(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
release_sock(sk);
return sk->sk_prot->close(sk, timeout);
}
saved_close = psock->saved_close;
tcp_bpf_remove(sk, psock);
rcu_read_unlock();
release_sock(sk);
saved_close(sk, timeout);
}
#ifdef CONFIG_BPF_STREAM_PARSER
enum {
TCP_BPF_IPV4,
TCP_BPF_IPV6,
......@@ -599,8 +549,8 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
struct proto *base)
{
prot[TCP_BPF_BASE] = *base;
prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
prot[TCP_BPF_BASE].close = tcp_bpf_close;
prot[TCP_BPF_BASE].unhash = sock_map_unhash;
prot[TCP_BPF_BASE].close = sock_map_close;
prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;
......@@ -640,7 +590,7 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{
int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
......@@ -657,31 +607,6 @@ static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
return &tcp_bpf_prots[family][config];
}
int tcp_bpf_init(struct sock *sk)
{
struct sk_psock *psock;
struct proto *prot;
sock_owned_by_me(sk);
rcu_read_lock();
psock = sk_psock(sk);
if (unlikely(!psock)) {
rcu_read_unlock();
return -EINVAL;
}
prot = tcp_bpf_get_proto(sk, psock);
if (IS_ERR(prot)) {
rcu_read_unlock();
return PTR_ERR(prot);
}
sk_psock_update_proto(sk, psock, prot);
rcu_read_unlock();
return 0;
}
/* If a child got cloned from a listening socket that had tcp_bpf
* protocol callbacks installed, we need to restore the callbacks to
* the default ones because the child does not inherit the psock state
......@@ -695,3 +620,4 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_STREAM_PARSER */
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment