Commit c50524ec authored by Alexei Starovoitov

Merge branch 'sockmap: add sockmap support for unix datagram socket'

Cong Wang says:

====================

From: Cong Wang <cong.wang@bytedance.com>

This is the last part of the original large patchset. In the previous
patchset, a new BPF sockmap program type, BPF_SK_SKB_VERDICT, was
introduced and UDP gained support for it as well. In this patchset, we
add BPF_SK_SKB_VERDICT support to Unix datagram sockets, so that we can
finally splice Unix datagram sockets and UDP sockets. Please check each
patch description for more details.
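
For orientation, here is a hedged userspace sketch (not part of this series;
the object, map and program names such as "sock_map" and "prog_skb_verdict"
are illustrative, mirroring the selftest below) of how an AF_UNIX datagram
socket could be inserted into a sockmap once a BPF_SK_SKB_VERDICT program is
attached:

  /* Sketch only: assumes an already-loaded bpf_object with a
   * BPF_MAP_TYPE_SOCKMAP named "sock_map" and an sk_skb verdict program
   * "prog_skb_verdict" (e.g. one calling bpf_sk_redirect_map()).
   */
  #include <sys/socket.h>
  #include <bpf/bpf.h>
  #include <bpf/libbpf.h>

  static int add_unix_dgram_to_sockmap(struct bpf_object *obj)
  {
          int map_fd = bpf_object__find_map_fd_by_name(obj, "sock_map");
          struct bpf_program *prog;
          __u32 key = 0;
          __u64 value;
          int sfd[2];

          prog = bpf_object__find_program_by_name(obj, "prog_skb_verdict");
          if (!prog || map_fd < 0)
                  return -1;

          /* Attach the verdict program to the sockmap. */
          if (bpf_prog_attach(bpf_program__fd(prog), map_fd,
                              BPF_SK_SKB_VERDICT, 0))
                  return -1;

          /* With this series, AF_UNIX datagram sockets are valid map values. */
          if (socketpair(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0, sfd))
                  return -1;

          value = sfd[1];
          return bpf_map_update_elem(map_fd, &key, &value, BPF_NOEXIST);
  }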

To see the big picture, the previous patchsets are available here:
https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next.git/commit/?id=1e0ab70778bd86a90de438cc5e1535c115a7c396
https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next.git/commit/?id=89d69c5d0fbcabd8656459bc8b1a476d6f1efee4

and this patchset is available here:
https://github.com/congwang/linux/tree/sockmap3
Acked-by: John Fastabend <john.fastabend@gmail.com>
---
v5: lift socket state check for dgram
    remove ->unhash() case
    add retries for EAGAIN in all test cases
    remove an unused parameter of __unix_dgram_recvmsg()
    rebase on the latest bpf-next

v4: fix af_unix disconnect case
    add unix_unhash()
    split out two small patches
    reduce u->iolock critical section
    remove an unused parameter of __unix_dgram_recvmsg()

v3: fix Kconfig dependency
    make unix_read_sock() static
    fix a UAF in unix_release()
    add a missing header unix_bpf.c

v2: separate out from the original large patchset
    rebase to the latest bpf-next
    clean up unix_read_sock()
    export sock_map_close()
    factor out some helpers in selftests for code reuse
====================
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parents 1554a080 a2ffda38
@@ -10277,6 +10277,7 @@ F: net/core/skmsg.c
F: net/core/sock_map.c
F: net/ipv4/tcp_bpf.c
F: net/ipv4/udp_bpf.c
F: net/unix/unix_bpf.c
LANDLOCK SECURITY MODULE
M: Mickaël Salaün <mic@digikod.net>
......
@@ -1887,6 +1887,12 @@ void bpf_map_offload_map_free(struct bpf_map *map);
int bpf_prog_test_run_syscall(struct bpf_prog *prog,
const union bpf_attr *kattr,
union bpf_attr __user *uattr);
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype);
int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value, u64 flags);
void sock_map_unhash(struct sock *sk);
void sock_map_close(struct sock *sk, long timeout);
#else
static inline int bpf_prog_offload_init(struct bpf_prog *prog,
union bpf_attr *attr)
@@ -1919,24 +1925,6 @@ static inline int bpf_prog_test_run_syscall(struct bpf_prog *prog,
{
return -ENOTSUPP;
}
#endif /* CONFIG_NET && CONFIG_BPF_SYSCALL */
#if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL)
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype);
int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value, u64 flags);
void sock_map_unhash(struct sock *sk);
void sock_map_close(struct sock *sk, long timeout);
void bpf_sk_reuseport_detach(struct sock *sk);
int bpf_fd_reuseport_array_lookup_elem(struct bpf_map *map, void *key,
void *value);
int bpf_fd_reuseport_array_update_elem(struct bpf_map *map, void *key,
void *value, u64 map_flags);
#else
static inline void bpf_sk_reuseport_detach(struct sock *sk)
{
}
#ifdef CONFIG_BPF_SYSCALL
static inline int sock_map_get_from_fd(const union bpf_attr *attr,
@@ -1956,7 +1944,21 @@ static inline int sock_map_update_elem_sys(struct bpf_map *map, void *key, void
{
return -EOPNOTSUPP;
}
#endif /* CONFIG_BPF_SYSCALL */
#endif /* CONFIG_NET && CONFIG_BPF_SYSCALL */
#if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL)
void bpf_sk_reuseport_detach(struct sock *sk);
int bpf_fd_reuseport_array_lookup_elem(struct bpf_map *map, void *key,
void *value);
int bpf_fd_reuseport_array_update_elem(struct bpf_map *map, void *key,
void *value, u64 map_flags);
#else
static inline void bpf_sk_reuseport_detach(struct sock *sk)
{
}
#ifdef CONFIG_BPF_SYSCALL
static inline int bpf_fd_reuseport_array_lookup_elem(struct bpf_map *map,
void *key, void *value)
{
......
@@ -82,6 +82,8 @@ static inline struct unix_sock *unix_sk(const struct sock *sk)
long unix_inq_len(struct sock *sk);
long unix_outq_len(struct sock *sk);
int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
int flags);
#ifdef CONFIG_SYSCTL
int unix_sysctl_register(struct net *net);
void unix_sysctl_unregister(struct net *net);
@@ -89,4 +91,14 @@ void unix_sysctl_unregister(struct net *net);
static inline int unix_sysctl_register(struct net *net) { return 0; }
static inline void unix_sysctl_unregister(struct net *net) {}
#endif
#ifdef CONFIG_BPF_SYSCALL
extern struct proto unix_proto;
int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
void __init unix_bpf_build_proto(void);
#else
static inline void __init unix_bpf_build_proto(void)
{}
#endif
#endif
@@ -29,7 +29,7 @@ config BPF_SYSCALL
select IRQ_WORK
select TASKS_TRACE_RCU
select BINARY_PRINTF
select NET_SOCK_MSG if INET
select NET_SOCK_MSG if NET
default n
help
Enable the bpf() system call that allows to manipulate BPF programs
......
@@ -33,8 +33,6 @@ obj-$(CONFIG_HWBM) += hwbm.o
obj-$(CONFIG_NET_DEVLINK) += devlink.o
obj-$(CONFIG_GRO_CELLS) += gro_cells.o
obj-$(CONFIG_FAILOVER) += failover.o
ifeq ($(CONFIG_INET),y)
obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o
obj-$(CONFIG_BPF_SYSCALL) += sock_map.o
endif
obj-$(CONFIG_BPF_SYSCALL) += bpf_sk_storage.o
@@ -211,8 +211,6 @@ static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
return psock;
}
static bool sock_map_redirect_allowed(const struct sock *sk);
static int sock_map_link(struct bpf_map *map, struct sock *sk)
{
struct sk_psock_progs *progs = sock_map_progs(map);
@@ -223,13 +221,6 @@ static int sock_map_link(struct bpf_map *map, struct sock *sk)
struct sk_psock *psock;
int ret;
/* Only sockets we can redirect into/from in BPF need to hold
* refs to parser/verdict progs and have their sk_data_ready
* and sk_write_space callbacks overridden.
*/
if (!sock_map_redirect_allowed(sk))
goto no_progs;
stream_verdict = READ_ONCE(progs->stream_verdict);
if (stream_verdict) {
stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
@@ -264,7 +255,6 @@ static int sock_map_link(struct bpf_map *map, struct sock *sk)
}
}
no_progs:
psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) {
ret = PTR_ERR(psock);
@@ -527,12 +517,6 @@ static bool sk_is_tcp(const struct sock *sk)
sk->sk_protocol == IPPROTO_TCP;
}
static bool sk_is_udp(const struct sock *sk)
{
return sk->sk_type == SOCK_DGRAM &&
sk->sk_protocol == IPPROTO_UDP;
}
static bool sock_map_redirect_allowed(const struct sock *sk)
{
if (sk_is_tcp(sk))
@@ -550,10 +534,7 @@ static bool sock_map_sk_state_allowed(const struct sock *sk)
{
if (sk_is_tcp(sk))
return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
else if (sk_is_udp(sk))
return sk_hashed(sk);
return false;
return true;
}
static int sock_hash_update_common(struct bpf_map *map, void *key,
@@ -1536,6 +1517,7 @@ void sock_map_close(struct sock *sk, long timeout)
release_sock(sk);
saved_close(sk, timeout);
}
EXPORT_SYMBOL_GPL(sock_map_close);
static int sock_map_iter_attach_target(struct bpf_prog *prog,
union bpf_iter_link_info *linfo,
......
@@ -112,7 +112,6 @@ static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
*prot = *base;
prot->unhash = sock_map_unhash;
prot->close = sock_map_close;
prot->recvmsg = udp_bpf_recvmsg;
}
......
@@ -7,6 +7,7 @@ obj-$(CONFIG_UNIX) += unix.o
unix-y := af_unix.o garbage.o
unix-$(CONFIG_SYSCTL) += sysctl_net_unix.o
unix-$(CONFIG_BPF_SYSCALL) += unix_bpf.o
obj-$(CONFIG_UNIX_DIAG) += unix_diag.o
unix_diag-y := diag.o
......
@@ -494,6 +494,7 @@ static void unix_dgram_disconnected(struct sock *sk, struct sock *other)
sk_error_report(other);
}
}
sk->sk_state = other->sk_state = TCP_CLOSE;
}
static void unix_sock_destructor(struct sock *sk)
@@ -669,6 +670,8 @@ static ssize_t unix_stream_splice_read(struct socket *, loff_t *ppos,
unsigned int flags);
static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t);
static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int);
static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor);
static int unix_dgram_connect(struct socket *, struct sockaddr *,
int, int);
static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t);
@@ -746,6 +749,7 @@ static const struct proto_ops unix_dgram_ops = {
.listen = sock_no_listen,
.shutdown = unix_shutdown,
.sendmsg = unix_dgram_sendmsg,
.read_sock = unix_read_sock,
.recvmsg = unix_dgram_recvmsg,
.mmap = sock_no_mmap,
.sendpage = sock_no_sendpage,
@@ -777,10 +781,21 @@ static const struct proto_ops unix_seqpacket_ops = {
.show_fdinfo = unix_show_fdinfo,
};
static struct proto unix_proto = {
static void unix_close(struct sock *sk, long timeout)
{
/* Nothing to do here, unix socket does not need a ->close().
* This is merely for sockmap.
*/
}
struct proto unix_proto = {
.name = "UNIX",
.owner = THIS_MODULE,
.obj_size = sizeof(struct unix_sock),
.close = unix_close,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = unix_bpf_update_proto,
#endif
};
static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
@@ -864,6 +879,7 @@ static int unix_release(struct socket *sock)
if (!sk)
return 0;
sk->sk_prot->close(sk, 0);
unix_release_sock(sk, 0);
sock->sk = NULL;
@@ -1199,6 +1215,9 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
unix_peer(sk) = other;
unix_state_double_unlock(sk, other);
}
if (unix_peer(sk))
sk->sk_state = other->sk_state = TCP_ESTABLISHED;
return 0;
out_unlock:
@@ -1431,12 +1450,10 @@ static int unix_socketpair(struct socket *socka, struct socket *sockb)
init_peercred(ska);
init_peercred(skb);
if (ska->sk_type != SOCK_DGRAM) {
ska->sk_state = TCP_ESTABLISHED;
skb->sk_state = TCP_ESTABLISHED;
socka->state = SS_CONNECTED;
sockb->state = SS_CONNECTED;
}
return 0;
}
@@ -2081,11 +2098,11 @@ static void unix_copy_addr(struct msghdr *msg, struct sock *sk)
}
}
static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
size_t size, int flags)
int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
int flags)
{
struct scm_cookie scm;
struct sock *sk = sock->sk;
struct socket *sock = sk->sk_socket;
struct unix_sock *u = unix_sk(sk);
struct sk_buff *skb, *last;
long timeo;
@@ -2188,6 +2205,53 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
return err;
}
static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
int flags)
{
struct sock *sk = sock->sk;
#ifdef CONFIG_BPF_SYSCALL
if (sk->sk_prot != &unix_proto)
return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
flags & ~MSG_DONTWAIT, NULL);
#endif
return __unix_dgram_recvmsg(sk, msg, size, flags);
}
static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor)
{
int copied = 0;
while (1) {
struct unix_sock *u = unix_sk(sk);
struct sk_buff *skb;
int used, err;
mutex_lock(&u->iolock);
skb = skb_recv_datagram(sk, 0, 1, &err);
mutex_unlock(&u->iolock);
if (!skb)
return err;
used = recv_actor(desc, skb, 0, skb->len);
if (used <= 0) {
if (!copied)
copied = used;
kfree_skb(skb);
break;
} else if (used <= skb->len) {
copied += used;
}
kfree_skb(skb);
if (!desc->count)
break;
}
return copied;
}
/*
* Sleep until more data has arrived. But check for races..
*/
@@ -2925,6 +2989,7 @@ static int __init af_unix_init(void)
sock_register(&unix_family_ops);
register_pernet_subsys(&unix_net_ops);
unix_bpf_build_proto();
out:
return rc;
}
......
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */
#include <linux/skmsg.h>
#include <linux/bpf.h>
#include <net/sock.h>
#include <net/af_unix.h>
#define unix_sk_has_data(__sk, __psock) \
({ !skb_queue_empty(&__sk->sk_receive_queue) || \
!skb_queue_empty(&__psock->ingress_skb) || \
!list_empty(&__psock->ingress_msg); \
})
static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock,
long timeo)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
struct unix_sock *u = unix_sk(sk);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
if (!unix_sk_has_data(sk, psock)) {
mutex_unlock(&u->iolock);
wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
mutex_lock(&u->iolock);
ret = unix_sk_has_data(sk, psock);
}
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
size_t len, int nonblock, int flags,
int *addr_len)
{
struct unix_sock *u = unix_sk(sk);
struct sk_psock *psock;
int copied, ret;
psock = sk_psock_get(sk);
if (unlikely(!psock))
return __unix_dgram_recvmsg(sk, msg, len, flags);
mutex_lock(&u->iolock);
if (!skb_queue_empty(&sk->sk_receive_queue) &&
sk_psock_queue_empty(psock)) {
ret = __unix_dgram_recvmsg(sk, msg, len, flags);
goto out;
}
msg_bytes_ready:
copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
if (!copied) {
long timeo;
int data;
timeo = sock_rcvtimeo(sk, nonblock);
data = unix_msg_wait_data(sk, psock, timeo);
if (data) {
if (!sk_psock_queue_empty(psock))
goto msg_bytes_ready;
ret = __unix_dgram_recvmsg(sk, msg, len, flags);
goto out;
}
copied = -EAGAIN;
}
ret = copied;
out:
mutex_unlock(&u->iolock);
sk_psock_put(sk, psock);
return ret;
}
static struct proto *unix_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_prot_lock);
static struct proto unix_bpf_prot;
static void unix_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
*prot = *base;
prot->close = sock_map_close;
prot->recvmsg = unix_dgram_bpf_recvmsg;
}
static void unix_bpf_check_needs_rebuild(struct proto *ops)
{
if (unlikely(ops != smp_load_acquire(&unix_prot_saved))) {
spin_lock_bh(&unix_prot_lock);
if (likely(ops != unix_prot_saved)) {
unix_bpf_rebuild_protos(&unix_bpf_prot, ops);
smp_store_release(&unix_prot_saved, ops);
}
spin_unlock_bh(&unix_prot_lock);
}
}
int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
if (restore) {
sk->sk_write_space = psock->saved_write_space;
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
return 0;
}
unix_bpf_check_needs_rebuild(psock->sk_proto);
WRITE_ONCE(sk->sk_prot, &unix_bpf_prot);
return 0;
}
void __init unix_bpf_build_proto(void)
{
unix_bpf_rebuild_protos(&unix_bpf_prot, &unix_proto);
}
@@ -351,9 +351,11 @@ static void test_insert_opened(int family, int sotype, int mapfd)
errno = 0;
value = s;
err = bpf_map_update_elem(mapfd, &key, &value, BPF_NOEXIST);
if (sotype == SOCK_STREAM) {
if (!err || errno != EOPNOTSUPP)
FAIL_ERRNO("map_update: expected EOPNOTSUPP");
} else if (err)
FAIL_ERRNO("map_update: expected success");
xclose(s);
}
@@ -919,6 +921,23 @@ static const char *redir_mode_str(enum redir_mode mode)
}
}
static int add_to_sockmap(int sock_mapfd, int fd1, int fd2)
{
u64 value;
u32 key;
int err;
key = 0;
value = fd1;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
if (err)
return err;
key = 1;
value = fd2;
return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
}
static void redir_to_connected(int family, int sotype, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
@@ -928,7 +947,6 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd,
unsigned int pass;
socklen_t len;
int err, n;
u64 value;
u32 key;
char b;
@@ -965,15 +983,7 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd,
if (p1 < 0)
goto close_cli1;
key = 0;
value = p0;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
if (err)
goto close_peer1;
key = 1;
value = p1;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
err = add_to_sockmap(sock_mapfd, p0, p1);
if (err)
goto close_peer1;
@@ -1061,7 +1071,6 @@ static void redir_to_listening(int family, int sotype, int sock_mapfd,
int s, c, p, err, n;
unsigned int drop;
socklen_t len;
u64 value;
u32 key;
zero_verdict_count(verd_mapfd);
@@ -1086,15 +1095,7 @@ static void redir_to_listening(int family, int sotype, int sock_mapfd,
if (p < 0)
goto close_cli;
key = 0;
value = s;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
if (err)
goto close_peer;
key = 1;
value = p;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
err = add_to_sockmap(sock_mapfd, s, p);
if (err)
goto close_peer;
@@ -1346,7 +1347,6 @@ static void test_reuseport_mixed_groups(int family, int sotype, int sock_map,
int s1, s2, c, err;
unsigned int drop;
socklen_t len;
u64 value;
u32 key;
zero_verdict_count(verd_map);
@@ -1360,16 +1360,10 @@ static void test_reuseport_mixed_groups(int family, int sotype, int sock_map,
if (s2 < 0)
goto close_srv1;
key = 0;
value = s1;
err = xbpf_map_update_elem(sock_map, &key, &value, BPF_NOEXIST);
err = add_to_sockmap(sock_map, s1, s2);
if (err)
goto close_srv2;
key = 1;
value = s2;
err = xbpf_map_update_elem(sock_map, &key, &value, BPF_NOEXIST);
/* Connect to s2, reuseport BPF selects s1 via sock_map[0] */
len = sizeof(addr);
err = xgetsockname(s2, sockaddr(&addr), &len);
@@ -1441,6 +1435,8 @@ static const char *family_str(sa_family_t family)
return "IPv4";
case AF_INET6:
return "IPv6";
case AF_UNIX:
return "Unix";
default:
return "unknown";
}
@@ -1563,6 +1559,99 @@ static void test_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
}
}
static void unix_redir_to_connected(int sotype, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1;
unsigned int pass;
int retries = 100;
int err, n;
int sfd[2];
u32 key;
char b;
zero_verdict_count(verd_mapfd);
if (socketpair(AF_UNIX, sotype | SOCK_NONBLOCK, 0, sfd))
return;
c0 = sfd[0], p0 = sfd[1];
if (socketpair(AF_UNIX, sotype | SOCK_NONBLOCK, 0, sfd))
goto close0;
c1 = sfd[0], p1 = sfd[1];
err = add_to_sockmap(sock_mapfd, p0, p1);
if (err)
goto close;
n = write(c1, "a", 1);
if (n < 0)
FAIL_ERRNO("%s: write", log_prefix);
if (n == 0)
FAIL("%s: incomplete write", log_prefix);
if (n < 1)
goto close;
key = SK_PASS;
err = xbpf_map_lookup_elem(verd_mapfd, &key, &pass);
if (err)
goto close;
if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass);
again:
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
if (n < 0) {
if (errno == EAGAIN && retries--)
goto again;
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0)
FAIL("%s: incomplete read", log_prefix);
close:
xclose(c1);
xclose(p1);
close0:
xclose(c0);
xclose(p0);
}
static void unix_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int sotype)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
int verdict_map = bpf_map__fd(skel->maps.verdict_map);
int sock_map = bpf_map__fd(inner_map);
int err;
err = xbpf_prog_attach(verdict, sock_map, BPF_SK_SKB_VERDICT, 0);
if (err)
return;
skel->bss->test_ingress = false;
unix_redir_to_connected(sotype, sock_map, verdict_map, REDIR_EGRESS);
skel->bss->test_ingress = true;
unix_redir_to_connected(sotype, sock_map, verdict_map, REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
static void test_unix_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
int sotype)
{
const char *family_name, *map_name;
char s[MAX_TEST_NAME];
family_name = family_str(AF_UNIX);
map_name = map_type_str(map);
snprintf(s, sizeof(s), "%s %s %s", map_name, family_name, __func__);
if (!test__start_subtest(s))
return;
unix_skb_redir_to_connected(skel, map, sotype);
}
static void test_reuseport(struct test_sockmap_listen *skel,
struct bpf_map *map, int family, int sotype)
{
@@ -1603,33 +1692,27 @@ static void test_reuseport(struct test_sockmap_listen *skel,
}
}
static void udp_redir_to_connected(int family, int sotype, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
static int udp_socketpair(int family, int *s, int *c)
{
const char *log_prefix = redir_mode_str(mode);
struct sockaddr_storage addr;
int c0, c1, p0, p1;
unsigned int pass;
int retries = 100;
socklen_t len;
int err, n;
u64 value;
int p0, c0;
int err;
u32 key;
char b;
zero_verdict_count(verd_mapfd);
p0 = socket_loopback(family, sotype | SOCK_NONBLOCK);
p0 = socket_loopback(family, SOCK_DGRAM | SOCK_NONBLOCK);
if (p0 < 0)
return;
return p0;
len = sizeof(addr);
err = xgetsockname(p0, sockaddr(&addr), &len);
if (err)
goto close_peer0;
c0 = xsocket(family, sotype | SOCK_NONBLOCK, 0);
c0 = xsocket(family, SOCK_DGRAM | SOCK_NONBLOCK, 0);
if (c0 < 0)
if (c0 < 0) {
err = c0;
goto close_peer0;
}
err = xconnect(c0, sockaddr(&addr), len);
if (err)
goto close_cli0;
@@ -1640,35 +1723,131 @@ static void udp_redir_to_connected(int family, int sotype, int sock_mapfd,
if (err)
goto close_cli0;
p1 = socket_loopback(family, sotype | SOCK_NONBLOCK);
if (p1 < 0)
goto close_cli0;
*s = p0;
*c = c0;
return 0;
err = xgetsockname(p1, sockaddr(&addr), &len);
close_cli0:
xclose(c0);
close_peer0:
xclose(p0);
return err;
}
static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd,
enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1;
unsigned int pass;
int retries = 100;
int err, n;
u32 key;
char b;
zero_verdict_count(verd_mapfd);
err = udp_socketpair(family, &p0, &c0);
if (err)
return;
err = udp_socketpair(family, &p1, &c1);
if (err)
goto close_cli0;
c1 = xsocket(family, sotype | SOCK_NONBLOCK, 0);
err = add_to_sockmap(sock_mapfd, p0, p1);
if (c1 < 0)
goto close_peer1;
err = xconnect(c1, sockaddr(&addr), len);
if (err)
goto close_cli1;
err = xgetsockname(c1, sockaddr(&addr), &len);
if (err)
n = write(c1, "a", 1);
if (n < 0)
FAIL_ERRNO("%s: write", log_prefix);
if (n == 0)
FAIL("%s: incomplete write", log_prefix);
if (n < 1)
goto close_cli1;
err = xconnect(p1, sockaddr(&addr), len);
key = SK_PASS;
err = xbpf_map_lookup_elem(verd_mapfd, &key, &pass);
if (err)
goto close_cli1;
if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass);
key = 0;
value = p0;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
again:
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
if (n < 0) {
if (errno == EAGAIN && retries--)
goto again;
FAIL_ERRNO("%s: read", log_prefix);
}
if (n == 0)
FAIL("%s: incomplete read", log_prefix);
close_cli1:
xclose(c1);
xclose(p1);
close_cli0:
xclose(c0);
xclose(p0);
}
static void udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int family)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
int verdict_map = bpf_map__fd(skel->maps.verdict_map);
int sock_map = bpf_map__fd(inner_map);
int err;
err = xbpf_prog_attach(verdict, sock_map, BPF_SK_SKB_VERDICT, 0);
if (err)
goto close_cli1;
key = 1;
value = p1;
err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
return;
skel->bss->test_ingress = false;
udp_redir_to_connected(family, sock_map, verdict_map, REDIR_EGRESS);
skel->bss->test_ingress = true;
udp_redir_to_connected(family, sock_map, verdict_map, REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
static void test_udp_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
int family)
{
const char *family_name, *map_name;
char s[MAX_TEST_NAME];
family_name = family_str(family);
map_name = map_type_str(map);
snprintf(s, sizeof(s), "%s %s %s", map_name, family_name, __func__);
if (!test__start_subtest(s))
return;
udp_skb_redir_to_connected(skel, map, family);
}
static void udp_unix_redir_to_connected(int family, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1;
unsigned int pass;
int retries = 100;
int err, n;
int sfd[2];
u32 key;
char b;
zero_verdict_count(verd_mapfd);
if (socketpair(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0, sfd))
return;
c0 = sfd[0], p0 = sfd[1];
err = udp_socketpair(family, &p1, &c1);
if (err)
goto close;
err = add_to_sockmap(sock_mapfd, p0, p1);
if (err)
goto close_cli1;
@@ -1699,15 +1878,88 @@ static void udp_redir_to_connected(int family, int sotype, int sock_mapfd,
close_cli1:
xclose(c1);
close_peer1:
xclose(p1);
close:
xclose(c0);
xclose(p0);
}
static void udp_unix_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int family)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
int verdict_map = bpf_map__fd(skel->maps.verdict_map);
int sock_map = bpf_map__fd(inner_map);
int err;
err = xbpf_prog_attach(verdict, sock_map, BPF_SK_SKB_VERDICT, 0);
if (err)
return;
skel->bss->test_ingress = false;
udp_unix_redir_to_connected(family, sock_map, verdict_map, REDIR_EGRESS);
skel->bss->test_ingress = true;
udp_unix_redir_to_connected(family, sock_map, verdict_map, REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
static void unix_udp_redir_to_connected(int family, int sock_mapfd,
int verd_mapfd, enum redir_mode mode)
{
const char *log_prefix = redir_mode_str(mode);
int c0, c1, p0, p1;
unsigned int pass;
int err, n;
int sfd[2];
u32 key;
char b;
zero_verdict_count(verd_mapfd);
err = udp_socketpair(family, &p0, &c0);
if (err)
return;
if (socketpair(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0, sfd))
goto close_cli0;
c1 = sfd[0], p1 = sfd[1];
err = add_to_sockmap(sock_mapfd, p0, p1);
if (err)
goto close;
n = write(c1, "a", 1);
if (n < 0)
FAIL_ERRNO("%s: write", log_prefix);
if (n == 0)
FAIL("%s: incomplete write", log_prefix);
if (n < 1)
goto close;
key = SK_PASS;
err = xbpf_map_lookup_elem(verd_mapfd, &key, &pass);
if (err)
goto close;
if (pass != 1)
FAIL("%s: want pass count 1, have %d", log_prefix, pass);
n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
if (n < 0)
FAIL_ERRNO("%s: read", log_prefix);
if (n == 0)
FAIL("%s: incomplete read", log_prefix);
close:
xclose(c1);
xclose(p1);
close_cli0:
xclose(c0);
close_peer0:
xclose(p0);
}
static void udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
static void unix_udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
struct bpf_map *inner_map, int family)
{
int verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
@@ -1720,16 +1972,14 @@ static void udp_skb_redir_to_connected(struct test_sockmap_listen *skel,
return;
skel->bss->test_ingress = false;
udp_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_EGRESS);
unix_udp_redir_to_connected(family, sock_map, verdict_map, REDIR_EGRESS);
skel->bss->test_ingress = true;
udp_redir_to_connected(family, SOCK_DGRAM, sock_map, verdict_map,
REDIR_INGRESS);
unix_udp_redir_to_connected(family, sock_map, verdict_map, REDIR_INGRESS);
xbpf_prog_detach2(verdict, sock_map, BPF_SK_SKB_VERDICT);
}
static void test_udp_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
static void test_udp_unix_redir(struct test_sockmap_listen *skel, struct bpf_map *map,
int family)
{
const char *family_name, *map_name;
@@ -1740,7 +1990,8 @@ static void test_udp_redir(struct test_sockmap_listen *skel, struct bpf_map *map
snprintf(s, sizeof(s), "%s %s %s", map_name, family_name, __func__);
if (!test__start_subtest(s))
return;
udp_skb_redir_to_connected(skel, map, family);
udp_unix_skb_redir_to_connected(skel, map, family);
unix_udp_skb_redir_to_connected(skel, map, family);
}
static void run_tests(struct test_sockmap_listen *skel, struct bpf_map *map,
@@ -1752,6 +2003,7 @@ static void run_tests(struct test_sockmap_listen *skel, struct bpf_map *map,
test_reuseport(skel, map, family, SOCK_STREAM);
test_reuseport(skel, map, family, SOCK_DGRAM);
test_udp_redir(skel, map, family);
test_udp_unix_redir(skel, map, family);
}
void test_sockmap_listen(void)
@@ -1767,10 +2019,12 @@ void test_sockmap_listen(void)
skel->bss->test_sockmap = true;
run_tests(skel, skel->maps.sock_map, AF_INET);
run_tests(skel, skel->maps.sock_map, AF_INET6);
test_unix_redir(skel, skel->maps.sock_map, SOCK_DGRAM);
skel->bss->test_sockmap = false;
run_tests(skel, skel->maps.sock_hash, AF_INET);
run_tests(skel, skel->maps.sock_hash, AF_INET6);
test_unix_redir(skel, skel->maps.sock_hash, SOCK_DGRAM);
test_sockmap_listen__destroy(skel);
}