Commit 75179e2b authored by Alexei Starovoitov's avatar Alexei Starovoitov

Merge branch 'bpf: net: Remove duplicated code from bpf_setsockopt()'

Martin KaFai Lau says:

====================

The code in bpf_setsockopt() is mostly a copy-and-paste from
the sock_setsockopt(), do_tcp_setsockopt(), do_ipv6_setsockopt(),
and do_ip_setsockopt().  As the allowed optnames in bpf_setsockopt()
grows, so are the duplicated code.  The code between the copies
also slowly drifted.

This set is an effort to clean this up and reuse the existing
{sock,do_tcp,do_ipv6,do_ip}_setsockopt() as much as possible.

After the clean up, this set also adds a few allowed optnames
that we need to the bpf_setsockopt().

The initial attempt was to clean up both bpf_setsockopt() and
bpf_getsockopt() together.  However, the patch set was getting
too long.  It is beneficial to leave the bpf_getsockopt()
out for another patch set.  Thus, this set is focusing
on the bpf_setsockopt().

v4:
- This set now depends on the commit f574f7f8 ("net: bpf: Use the protocol's set_rcvlowat behavior if there is one")
  in the net-next tree.  The commit calls a specific protocol's
  set_rcvlowat and it changed the bpf_setsockopt
  which this set has also changed.

  Because of this, patch 9 of this set has also adjusted
  and a 'sock' NULL check is added to the sk_setsockopt()
  because some of the bpf hooks have a NULL sk->sk_socket.
  This removes more dup code from the bpf_setsockopt() side.
- Avoid mentioning specific prog types in the comment of
  the has_current_bpf_ctx(). (Andrii)
- Replace signed with unsigned int bitfield in the
  patch 15 selftest. (Daniel)

v3:
- s/in_bpf/has_current_bpf_ctx/ (Andrii)
- Add comment to has_current_bpf_ctx() and sockopt_lock_sock()
  (Stanislav)
- Use vmlinux.h in selftest and add defines to bpf_tracing_net.h
  (Stanislav)
- Use bpf_getsockopt(SO_MARK) in selftest (Stanislav)
- Use BPF_CORE_READ_BITFIELD in selftest (Yonghong)

v2:
- A major change is to use in_bpf() to test if a setsockopt()
  is called by a bpf prog and use in_bpf() to skip capable
  check.  Suggested by Stanislav.
- Instead of passing is_locked through sockptr_t or through an extra
  argument to sk_setsockopt, v2 uses in_bpf() to skip the lock_sock()
  also because bpf prog has the lock acquired.
- No change to the current sockptr_t in this revision
- s/codes/code/
====================
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents fb8d784b 31123c03
......@@ -1966,6 +1966,15 @@ static inline bool unprivileged_ebpf_enabled(void)
return !sysctl_unprivileged_bpf_disabled;
}
/* Not all bpf prog type has the bpf_ctx.
* For the bpf prog type that has initialized the bpf_ctx,
* this function can be used to decide if a kernel function
* is called by a bpf program.
*/
static inline bool has_current_bpf_ctx(void)
{
return !!current->bpf_ctx;
}
#else /* !CONFIG_BPF_SYSCALL */
static inline struct bpf_prog *bpf_prog_get(u32 ufd)
{
......@@ -2175,6 +2184,10 @@ static inline bool unprivileged_ebpf_enabled(void)
return false;
}
static inline bool has_current_bpf_ctx(void)
{
return false;
}
#endif /* CONFIG_BPF_SYSCALL */
void __bpf_free_used_btfs(struct bpf_prog_aux *aux,
......
......@@ -743,6 +743,8 @@ void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk,
int ip_cmsg_send(struct sock *sk, struct msghdr *msg,
struct ipcm_cookie *ipc, bool allow_ipv6);
DECLARE_STATIC_KEY_FALSE(ip4_min_ttl);
int do_ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int ip_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
......
......@@ -1156,6 +1156,8 @@ struct in6_addr *fl6_update_dst(struct flowi6 *fl6,
*/
DECLARE_STATIC_KEY_FALSE(ip6_min_hopcount);
int do_ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
int ipv6_getsockopt(struct sock *sk, int level, int optname,
......
......@@ -81,6 +81,8 @@ struct ipv6_bpf_stub {
const struct in6_addr *daddr, __be16 dport,
int dif, int sdif, struct udp_table *tbl,
struct sk_buff *skb);
int (*ipv6_setsockopt)(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen);
};
extern const struct ipv6_bpf_stub *ipv6_bpf_stub __read_mostly;
......
......@@ -1749,6 +1749,11 @@ static inline void unlock_sock_fast(struct sock *sk, bool slow)
}
}
void sockopt_lock_sock(struct sock *sk);
void sockopt_release_sock(struct sock *sk);
bool sockopt_ns_capable(struct user_namespace *ns, int cap);
bool sockopt_capable(int cap);
/* Used by processes to "lock" a socket state, so that
* interrupts and bottom half handlers won't change it
* from under us. It essentially blocks any incoming
......@@ -1823,6 +1828,8 @@ void sock_pfree(struct sk_buff *skb);
#define sock_edemux sock_efree
#endif
int sk_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen);
int sock_setsockopt(struct socket *sock, int level, int op,
sockptr_t optval, unsigned int optlen);
......
......@@ -405,6 +405,8 @@ __poll_t tcp_poll(struct file *file, struct socket *sock,
int tcp_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen);
bool tcp_bpf_bypass_getsockopt(int level, int optname);
int do_tcp_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen);
int tcp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
unsigned int optlen);
void tcp_set_keepalive(struct sock *sk, int val);
......
......@@ -694,19 +694,24 @@ struct bpf_prog *bpf_iter_get_info(struct bpf_iter_meta *meta, bool in_stop)
int bpf_iter_run_prog(struct bpf_prog *prog, void *ctx)
{
struct bpf_run_ctx run_ctx, *old_run_ctx;
int ret;
if (prog->aux->sleepable) {
rcu_read_lock_trace();
migrate_disable();
might_fault();
old_run_ctx = bpf_set_run_ctx(&run_ctx);
ret = bpf_prog_run(prog, ctx);
bpf_reset_run_ctx(old_run_ctx);
migrate_enable();
rcu_read_unlock_trace();
} else {
rcu_read_lock();
migrate_disable();
old_run_ctx = bpf_set_run_ctx(&run_ctx);
ret = bpf_prog_run(prog, ctx);
bpf_reset_run_ctx(old_run_ctx);
migrate_enable();
rcu_read_unlock();
}
......
......@@ -5013,251 +5013,166 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
.arg1_type = ARG_PTR_TO_CTX,
};
static int __bpf_setsockopt(struct sock *sk, int level, int optname,
char *optval, int optlen)
static int sol_socket_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
{
switch (optname) {
case SO_REUSEADDR:
case SO_SNDBUF:
case SO_RCVBUF:
case SO_KEEPALIVE:
case SO_PRIORITY:
case SO_REUSEPORT:
case SO_RCVLOWAT:
case SO_MARK:
case SO_MAX_PACING_RATE:
case SO_BINDTOIFINDEX:
case SO_TXREHASH:
if (optlen != sizeof(int))
return -EINVAL;
break;
case SO_BINDTODEVICE:
break;
default:
return -EINVAL;
}
return sk_setsockopt(sk, SOL_SOCKET, optname,
KERNEL_SOCKPTR(optval), optlen);
}
static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
{
char devname[IFNAMSIZ];
int val, valbool;
struct net *net;
int ifindex;
int ret = 0;
struct tcp_sock *tp = tcp_sk(sk);
unsigned long timeout;
int val;
if (!sk_fullsock(sk))
if (optlen != sizeof(int))
return -EINVAL;
if (level == SOL_SOCKET) {
if (optlen != sizeof(int) && optname != SO_BINDTODEVICE)
val = *(int *)optval;
/* Only some options are supported */
switch (optname) {
case TCP_BPF_IW:
if (val <= 0 || tp->data_segs_out > tp->syn_data)
return -EINVAL;
val = *((int *)optval);
valbool = val ? 1 : 0;
tcp_snd_cwnd_set(tp, val);
break;
case TCP_BPF_SNDCWND_CLAMP:
if (val <= 0)
return -EINVAL;
tp->snd_cwnd_clamp = val;
tp->snd_ssthresh = val;
break;
case TCP_BPF_DELACK_MAX:
timeout = usecs_to_jiffies(val);
if (timeout > TCP_DELACK_MAX ||
timeout < TCP_TIMEOUT_MIN)
return -EINVAL;
inet_csk(sk)->icsk_delack_max = timeout;
break;
case TCP_BPF_RTO_MIN:
timeout = usecs_to_jiffies(val);
if (timeout > TCP_RTO_MIN ||
timeout < TCP_TIMEOUT_MIN)
return -EINVAL;
inet_csk(sk)->icsk_rto_min = timeout;
break;
default:
return -EINVAL;
}
/* Only some socketops are supported */
switch (optname) {
case SO_RCVBUF:
val = min_t(u32, val, sysctl_rmem_max);
val = min_t(int, val, INT_MAX / 2);
sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
WRITE_ONCE(sk->sk_rcvbuf,
max_t(int, val * 2, SOCK_MIN_RCVBUF));
break;
case SO_SNDBUF:
val = min_t(u32, val, sysctl_wmem_max);
val = min_t(int, val, INT_MAX / 2);
sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
WRITE_ONCE(sk->sk_sndbuf,
max_t(int, val * 2, SOCK_MIN_SNDBUF));
break;
case SO_MAX_PACING_RATE: /* 32bit version */
if (val != ~0U)
cmpxchg(&sk->sk_pacing_status,
SK_PACING_NONE,
SK_PACING_NEEDED);
sk->sk_max_pacing_rate = (val == ~0U) ?
~0UL : (unsigned int)val;
sk->sk_pacing_rate = min(sk->sk_pacing_rate,
sk->sk_max_pacing_rate);
break;
case SO_PRIORITY:
sk->sk_priority = val;
break;
case SO_RCVLOWAT:
if (val < 0)
val = INT_MAX;
if (sk->sk_socket && sk->sk_socket->ops->set_rcvlowat)
ret = sk->sk_socket->ops->set_rcvlowat(sk, val);
else
WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
break;
case SO_MARK:
if (sk->sk_mark != val) {
sk->sk_mark = val;
sk_dst_reset(sk);
}
break;
case SO_BINDTODEVICE:
optlen = min_t(long, optlen, IFNAMSIZ - 1);
strncpy(devname, optval, optlen);
devname[optlen] = 0;
return 0;
}
ifindex = 0;
if (devname[0] != '\0') {
struct net_device *dev;
static int sol_tcp_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
{
if (sk->sk_prot->setsockopt != tcp_setsockopt)
return -EINVAL;
ret = -ENODEV;
switch (optname) {
case TCP_NODELAY:
case TCP_MAXSEG:
case TCP_KEEPIDLE:
case TCP_KEEPINTVL:
case TCP_KEEPCNT:
case TCP_SYNCNT:
case TCP_WINDOW_CLAMP:
case TCP_THIN_LINEAR_TIMEOUTS:
case TCP_USER_TIMEOUT:
case TCP_NOTSENT_LOWAT:
case TCP_SAVE_SYN:
if (optlen != sizeof(int))
return -EINVAL;
break;
case TCP_CONGESTION:
break;
default:
return bpf_sol_tcp_setsockopt(sk, optname, optval, optlen);
}
net = sock_net(sk);
dev = dev_get_by_name(net, devname);
if (!dev)
break;
ifindex = dev->ifindex;
dev_put(dev);
}
fallthrough;
case SO_BINDTOIFINDEX:
if (optname == SO_BINDTOIFINDEX)
ifindex = val;
ret = sock_bindtoindex(sk, ifindex, false);
break;
case SO_KEEPALIVE:
if (sk->sk_prot->keepalive)
sk->sk_prot->keepalive(sk, valbool);
sock_valbool_flag(sk, SOCK_KEEPOPEN, valbool);
break;
case SO_REUSEPORT:
sk->sk_reuseport = valbool;
break;
case SO_TXREHASH:
if (val < -1 || val > 1) {
ret = -EINVAL;
break;
}
sk->sk_txrehash = (u8)val;
break;
default:
ret = -EINVAL;
}
#ifdef CONFIG_INET
} else if (level == SOL_IP) {
if (optlen != sizeof(int) || sk->sk_family != AF_INET)
return do_tcp_setsockopt(sk, SOL_TCP, optname,
KERNEL_SOCKPTR(optval), optlen);
}
static int sol_ip_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
{
if (sk->sk_family != AF_INET)
return -EINVAL;
switch (optname) {
case IP_TOS:
if (optlen != sizeof(int))
return -EINVAL;
break;
default:
return -EINVAL;
}
val = *((int *)optval);
/* Only some options are supported */
switch (optname) {
case IP_TOS:
if (val < -1 || val > 0xff) {
ret = -EINVAL;
} else {
struct inet_sock *inet = inet_sk(sk);
return do_ip_setsockopt(sk, SOL_IP, optname,
KERNEL_SOCKPTR(optval), optlen);
}
if (val == -1)
val = 0;
inet->tos = val;
}
break;
default:
ret = -EINVAL;
}
#if IS_ENABLED(CONFIG_IPV6)
} else if (level == SOL_IPV6) {
if (optlen != sizeof(int) || sk->sk_family != AF_INET6)
static int sol_ipv6_setsockopt(struct sock *sk, int optname,
char *optval, int optlen)
{
if (sk->sk_family != AF_INET6)
return -EINVAL;
switch (optname) {
case IPV6_TCLASS:
case IPV6_AUTOFLOWLABEL:
if (optlen != sizeof(int))
return -EINVAL;
break;
default:
return -EINVAL;
}
val = *((int *)optval);
/* Only some options are supported */
switch (optname) {
case IPV6_TCLASS:
if (val < -1 || val > 0xff) {
ret = -EINVAL;
} else {
struct ipv6_pinfo *np = inet6_sk(sk);
return ipv6_bpf_stub->ipv6_setsockopt(sk, SOL_IPV6, optname,
KERNEL_SOCKPTR(optval), optlen);
}
if (val == -1)
val = 0;
np->tclass = val;
}
break;
default:
ret = -EINVAL;
}
#endif
} else if (level == SOL_TCP &&
sk->sk_prot->setsockopt == tcp_setsockopt) {
if (optname == TCP_CONGESTION) {
char name[TCP_CA_NAME_MAX];
strncpy(name, optval, min_t(long, optlen,
TCP_CA_NAME_MAX-1));
name[TCP_CA_NAME_MAX-1] = 0;
ret = tcp_set_congestion_control(sk, name, false, true);
} else {
struct inet_connection_sock *icsk = inet_csk(sk);
struct tcp_sock *tp = tcp_sk(sk);
unsigned long timeout;
static int __bpf_setsockopt(struct sock *sk, int level, int optname,
char *optval, int optlen)
{
if (!sk_fullsock(sk))
return -EINVAL;
if (optlen != sizeof(int))
return -EINVAL;
if (level == SOL_SOCKET)
return sol_socket_setsockopt(sk, optname, optval, optlen);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP)
return sol_ip_setsockopt(sk, optname, optval, optlen);
else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6)
return sol_ipv6_setsockopt(sk, optname, optval, optlen);
else if (IS_ENABLED(CONFIG_INET) && level == SOL_TCP)
return sol_tcp_setsockopt(sk, optname, optval, optlen);
val = *((int *)optval);
/* Only some options are supported */
switch (optname) {
case TCP_BPF_IW:
if (val <= 0 || tp->data_segs_out > tp->syn_data)
ret = -EINVAL;
else
tcp_snd_cwnd_set(tp, val);
break;
case TCP_BPF_SNDCWND_CLAMP:
if (val <= 0) {
ret = -EINVAL;
} else {
tp->snd_cwnd_clamp = val;
tp->snd_ssthresh = val;
}
break;
case TCP_BPF_DELACK_MAX:
timeout = usecs_to_jiffies(val);
if (timeout > TCP_DELACK_MAX ||
timeout < TCP_TIMEOUT_MIN)
return -EINVAL;
inet_csk(sk)->icsk_delack_max = timeout;
break;
case TCP_BPF_RTO_MIN:
timeout = usecs_to_jiffies(val);
if (timeout > TCP_RTO_MIN ||
timeout < TCP_TIMEOUT_MIN)
return -EINVAL;
inet_csk(sk)->icsk_rto_min = timeout;
break;
case TCP_SAVE_SYN:
if (val < 0 || val > 1)
ret = -EINVAL;
else
tp->save_syn = val;
break;
case TCP_KEEPIDLE:
ret = tcp_sock_set_keepidle_locked(sk, val);
break;
case TCP_KEEPINTVL:
if (val < 1 || val > MAX_TCP_KEEPINTVL)
ret = -EINVAL;
else
tp->keepalive_intvl = val * HZ;
break;
case TCP_KEEPCNT:
if (val < 1 || val > MAX_TCP_KEEPCNT)
ret = -EINVAL;
else
tp->keepalive_probes = val;
break;
case TCP_SYNCNT:
if (val < 1 || val > MAX_TCP_SYNCNT)
ret = -EINVAL;
else
icsk->icsk_syn_retries = val;
break;
case TCP_USER_TIMEOUT:
if (val < 0)
ret = -EINVAL;
else
icsk->icsk_user_timeout = val;
break;
case TCP_NOTSENT_LOWAT:
tp->notsent_lowat = val;
sk->sk_write_space(sk);
break;
case TCP_WINDOW_CLAMP:
ret = tcp_set_window_clamp(sk, val);
break;
default:
ret = -EINVAL;
}
}
#endif
} else {
ret = -EINVAL;
}
return ret;
return -EINVAL;
}
static int _bpf_setsockopt(struct sock *sk, int level, int optname,
......
......@@ -703,7 +703,9 @@ static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen)
goto out;
}
return sock_bindtoindex(sk, index, true);
sockopt_lock_sock(sk);
ret = sock_bindtoindex_locked(sk, index);
sockopt_release_sock(sk);
out:
#endif
......@@ -1036,17 +1038,51 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
return 0;
}
void sockopt_lock_sock(struct sock *sk)
{
/* When current->bpf_ctx is set, the setsockopt is called from
* a bpf prog. bpf has ensured the sk lock has been
* acquired before calling setsockopt().
*/
if (has_current_bpf_ctx())
return;
lock_sock(sk);
}
EXPORT_SYMBOL(sockopt_lock_sock);
void sockopt_release_sock(struct sock *sk)
{
if (has_current_bpf_ctx())
return;
release_sock(sk);
}
EXPORT_SYMBOL(sockopt_release_sock);
bool sockopt_ns_capable(struct user_namespace *ns, int cap)
{
return has_current_bpf_ctx() || ns_capable(ns, cap);
}
EXPORT_SYMBOL(sockopt_ns_capable);
bool sockopt_capable(int cap)
{
return has_current_bpf_ctx() || capable(cap);
}
EXPORT_SYMBOL(sockopt_capable);
/*
* This is meant for all protocols to use and covers goings on
* at the socket level. Everything here is generic.
*/
int sock_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen)
int sk_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
struct so_timestamping timestamping;
struct socket *sock = sk->sk_socket;
struct sock_txtime sk_txtime;
struct sock *sk = sock->sk;
int val;
int valbool;
struct linger ling;
......@@ -1067,11 +1103,11 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
valbool = val ? 1 : 0;
lock_sock(sk);
sockopt_lock_sock(sk);
switch (optname) {
case SO_DEBUG:
if (val && !capable(CAP_NET_ADMIN))
if (val && !sockopt_capable(CAP_NET_ADMIN))
ret = -EACCES;
else
sock_valbool_flag(sk, SOCK_DBG, valbool);
......@@ -1115,7 +1151,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
break;
case SO_SNDBUFFORCE:
if (!capable(CAP_NET_ADMIN)) {
if (!sockopt_capable(CAP_NET_ADMIN)) {
ret = -EPERM;
break;
}
......@@ -1137,7 +1173,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
break;
case SO_RCVBUFFORCE:
if (!capable(CAP_NET_ADMIN)) {
if (!sockopt_capable(CAP_NET_ADMIN)) {
ret = -EPERM;
break;
}
......@@ -1164,8 +1200,8 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
case SO_PRIORITY:
if ((val >= 0 && val <= 6) ||
ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) ||
ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) ||
sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
sk->sk_priority = val;
else
ret = -EPERM;
......@@ -1228,7 +1264,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
case SO_RCVLOWAT:
if (val < 0)
val = INT_MAX;
if (sock->ops->set_rcvlowat)
if (sock && sock->ops->set_rcvlowat)
ret = sock->ops->set_rcvlowat(sk, val);
else
WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
......@@ -1310,8 +1346,8 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
clear_bit(SOCK_PASSSEC, &sock->flags);
break;
case SO_MARK:
if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
ret = -EPERM;
break;
}
......@@ -1319,8 +1355,8 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
__sock_set_mark(sk, val);
break;
case SO_RCVMARK:
if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
ret = -EPERM;
break;
}
......@@ -1354,7 +1390,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
#ifdef CONFIG_NET_RX_BUSY_POLL
case SO_BUSY_POLL:
/* allow unprivileged users to decrease the value */
if ((val > sk->sk_ll_usec) && !capable(CAP_NET_ADMIN))
if ((val > sk->sk_ll_usec) && !sockopt_capable(CAP_NET_ADMIN))
ret = -EPERM;
else {
if (val < 0)
......@@ -1364,13 +1400,13 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
}
break;
case SO_PREFER_BUSY_POLL:
if (valbool && !capable(CAP_NET_ADMIN))
if (valbool && !sockopt_capable(CAP_NET_ADMIN))
ret = -EPERM;
else
WRITE_ONCE(sk->sk_prefer_busy_poll, valbool);
break;
case SO_BUSY_POLL_BUDGET:
if (val > READ_ONCE(sk->sk_busy_poll_budget) && !capable(CAP_NET_ADMIN)) {
if (val > READ_ONCE(sk->sk_busy_poll_budget) && !sockopt_capable(CAP_NET_ADMIN)) {
ret = -EPERM;
} else {
if (val < 0 || val > U16_MAX)
......@@ -1441,7 +1477,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
* scheduler has enough safe guards.
*/
if (sk_txtime.clockid != CLOCK_MONOTONIC &&
!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
ret = -EPERM;
break;
}
......@@ -1496,9 +1532,16 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
ret = -ENOPROTOOPT;
break;
}
release_sock(sk);
sockopt_release_sock(sk);
return ret;
}
int sock_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
return sk_setsockopt(sock->sk, level, optname,
optval, optlen);
}
EXPORT_SYMBOL(sock_setsockopt);
static const struct cred *sk_get_peer_cred(struct sock *sk)
......
......@@ -888,8 +888,8 @@ static int compat_ip_mcast_join_leave(struct sock *sk, int optname,
DEFINE_STATIC_KEY_FALSE(ip4_min_ttl);
static int do_ip_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
int do_ip_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
struct inet_sock *inet = inet_sk(sk);
struct net *net = sock_net(sk);
......@@ -944,7 +944,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname,
err = 0;
if (needs_rtnl)
rtnl_lock();
lock_sock(sk);
sockopt_lock_sock(sk);
switch (optname) {
case IP_OPTIONS:
......@@ -1333,14 +1333,14 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname,
case IP_IPSEC_POLICY:
case IP_XFRM_POLICY:
err = -EPERM;
if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))
break;
err = xfrm_user_policy(sk, optname, optval, optlen);
break;
case IP_TRANSPARENT:
if (!!val && !ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
if (!!val && !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) &&
!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {
err = -EPERM;
break;
}
......@@ -1368,13 +1368,13 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname,
err = -ENOPROTOOPT;
break;
}
release_sock(sk);
sockopt_release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return err;
e_inval:
release_sock(sk);
sockopt_release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return -EINVAL;
......
......@@ -3202,7 +3202,7 @@ EXPORT_SYMBOL(tcp_disconnect);
static inline bool tcp_can_repair_sock(const struct sock *sk)
{
return ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) &&
return sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) &&
(sk->sk_state != TCP_LISTEN);
}
......@@ -3479,8 +3479,8 @@ int tcp_set_window_clamp(struct sock *sk, int val)
/*
* Socket option code for TCP.
*/
static int do_tcp_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
int do_tcp_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
struct tcp_sock *tp = tcp_sk(sk);
struct inet_connection_sock *icsk = inet_csk(sk);
......@@ -3502,11 +3502,11 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname,
return -EFAULT;
name[val] = 0;
lock_sock(sk);
sockopt_lock_sock(sk);
err = tcp_set_congestion_control(sk, name, true,
ns_capable(sock_net(sk)->user_ns,
CAP_NET_ADMIN));
release_sock(sk);
sockopt_ns_capable(sock_net(sk)->user_ns,
CAP_NET_ADMIN));
sockopt_release_sock(sk);
return err;
}
case TCP_ULP: {
......@@ -3522,9 +3522,9 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname,
return -EFAULT;
name[val] = 0;
lock_sock(sk);
sockopt_lock_sock(sk);
err = tcp_set_ulp(sk, name);
release_sock(sk);
sockopt_release_sock(sk);
return err;
}
case TCP_FASTOPEN_KEY: {
......@@ -3557,7 +3557,7 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname,
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
lock_sock(sk);
sockopt_lock_sock(sk);
switch (optname) {
case TCP_MAXSEG:
......@@ -3779,7 +3779,7 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname,
break;
}
release_sock(sk);
sockopt_release_sock(sk);
return err;
}
......
......@@ -1057,6 +1057,7 @@ static const struct ipv6_stub ipv6_stub_impl = {
static const struct ipv6_bpf_stub ipv6_bpf_stub_impl = {
.inet6_bind = __inet6_bind,
.udp6_lib_lookup = __udp6_lib_lookup,
.ipv6_setsockopt = do_ipv6_setsockopt,
};
static int __init inet6_init(void)
......
......@@ -327,7 +327,7 @@ static int ipv6_set_opt_hdr(struct sock *sk, int optname, sockptr_t optval,
int err;
/* hop-by-hop / destination options are privileged option */
if (optname != IPV6_RTHDR && !ns_capable(net->user_ns, CAP_NET_RAW))
if (optname != IPV6_RTHDR && !sockopt_ns_capable(net->user_ns, CAP_NET_RAW))
return -EPERM;
/* remove any sticky options header with a zero option
......@@ -391,8 +391,8 @@ static int ipv6_set_opt_hdr(struct sock *sk, int optname, sockptr_t optval,
return err;
}
static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
struct ipv6_pinfo *np = inet6_sk(sk);
struct net *net = sock_net(sk);
......@@ -417,7 +417,7 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
if (needs_rtnl)
rtnl_lock();
lock_sock(sk);
sockopt_lock_sock(sk);
switch (optname) {
......@@ -634,8 +634,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
break;
case IPV6_TRANSPARENT:
if (valbool && !ns_capable(net->user_ns, CAP_NET_RAW) &&
!ns_capable(net->user_ns, CAP_NET_ADMIN)) {
if (valbool && !sockopt_ns_capable(net->user_ns, CAP_NET_RAW) &&
!sockopt_ns_capable(net->user_ns, CAP_NET_ADMIN)) {
retv = -EPERM;
break;
}
......@@ -946,7 +946,7 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
case IPV6_IPSEC_POLICY:
case IPV6_XFRM_POLICY:
retv = -EPERM;
if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
if (!sockopt_ns_capable(net->user_ns, CAP_NET_ADMIN))
break;
retv = xfrm_user_policy(sk, optname, optval, optlen);
break;
......@@ -994,14 +994,14 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
break;
}
release_sock(sk);
sockopt_release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return retv;
e_inval:
release_sock(sk);
sockopt_release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return -EINVAL;
......
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) Meta Platforms, Inc. and affiliates. */
#define _GNU_SOURCE
#include <sched.h>
#include <linux/socket.h>
#include <net/if.h>
#include "test_progs.h"
#include "cgroup_helpers.h"
#include "network_helpers.h"
#include "setget_sockopt.skel.h"
#define CG_NAME "/setget-sockopt-test"
static const char addr4_str[] = "127.0.0.1";
static const char addr6_str[] = "::1";
static struct setget_sockopt *skel;
static int cg_fd;
static int create_netns(void)
{
if (!ASSERT_OK(unshare(CLONE_NEWNET), "create netns"))
return -1;
if (!ASSERT_OK(system("ip link set dev lo up"), "set lo up"))
return -1;
if (!ASSERT_OK(system("ip link add dev binddevtest1 type veth peer name binddevtest2"),
"add veth"))
return -1;
if (!ASSERT_OK(system("ip link set dev binddevtest1 up"),
"bring veth up"))
return -1;
return 0;
}
static void test_tcp(int family)
{
struct setget_sockopt__bss *bss = skel->bss;
int sfd, cfd;
memset(bss, 0, sizeof(*bss));
sfd = start_server(family, SOCK_STREAM,
family == AF_INET6 ? addr6_str : addr4_str, 0, 0);
if (!ASSERT_GE(sfd, 0, "start_server"))
return;
cfd = connect_to_fd(sfd, 0);
if (!ASSERT_GE(cfd, 0, "connect_to_fd_server")) {
close(sfd);
return;
}
close(sfd);
close(cfd);
ASSERT_EQ(bss->nr_listen, 1, "nr_listen");
ASSERT_EQ(bss->nr_connect, 1, "nr_connect");
ASSERT_EQ(bss->nr_active, 1, "nr_active");
ASSERT_EQ(bss->nr_passive, 1, "nr_passive");
ASSERT_EQ(bss->nr_socket_post_create, 2, "nr_socket_post_create");
ASSERT_EQ(bss->nr_binddev, 2, "nr_bind");
}
static void test_udp(int family)
{
struct setget_sockopt__bss *bss = skel->bss;
int sfd;
memset(bss, 0, sizeof(*bss));
sfd = start_server(family, SOCK_DGRAM,
family == AF_INET6 ? addr6_str : addr4_str, 0, 0);
if (!ASSERT_GE(sfd, 0, "start_server"))
return;
close(sfd);
ASSERT_GE(bss->nr_socket_post_create, 1, "nr_socket_post_create");
ASSERT_EQ(bss->nr_binddev, 1, "nr_bind");
}
void test_setget_sockopt(void)
{
cg_fd = test__join_cgroup(CG_NAME);
if (cg_fd < 0)
return;
if (create_netns())
goto done;
skel = setget_sockopt__open();
if (!ASSERT_OK_PTR(skel, "open skel"))
goto done;
strcpy(skel->rodata->veth, "binddevtest1");
skel->rodata->veth_ifindex = if_nametoindex("binddevtest1");
if (!ASSERT_GT(skel->rodata->veth_ifindex, 0, "if_nametoindex"))
goto done;
if (!ASSERT_OK(setget_sockopt__load(skel), "load skel"))
goto done;
skel->links.skops_sockopt =
bpf_program__attach_cgroup(skel->progs.skops_sockopt, cg_fd);
if (!ASSERT_OK_PTR(skel->links.skops_sockopt, "attach cgroup"))
goto done;
skel->links.socket_post_create =
bpf_program__attach_cgroup(skel->progs.socket_post_create, cg_fd);
if (!ASSERT_OK_PTR(skel->links.socket_post_create, "attach_cgroup"))
goto done;
test_tcp(AF_INET6);
test_tcp(AF_INET);
test_udp(AF_INET6);
test_udp(AF_INET);
done:
setget_sockopt__destroy(skel);
close(cg_fd);
}
......@@ -6,13 +6,40 @@
#define AF_INET6 10
#define SOL_SOCKET 1
#define SO_REUSEADDR 2
#define SO_SNDBUF 7
#define __SO_ACCEPTCON (1 << 16)
#define SO_RCVBUF 8
#define SO_KEEPALIVE 9
#define SO_PRIORITY 12
#define SO_REUSEPORT 15
#define SO_RCVLOWAT 18
#define SO_BINDTODEVICE 25
#define SO_MARK 36
#define SO_MAX_PACING_RATE 47
#define SO_BINDTOIFINDEX 62
#define SO_TXREHASH 74
#define __SO_ACCEPTCON (1 << 16)
#define IP_TOS 1
#define IPV6_TCLASS 67
#define IPV6_AUTOFLOWLABEL 70
#define SOL_TCP 6
#define TCP_NODELAY 1
#define TCP_MAXSEG 2
#define TCP_KEEPIDLE 4
#define TCP_KEEPINTVL 5
#define TCP_KEEPCNT 6
#define TCP_SYNCNT 7
#define TCP_WINDOW_CLAMP 10
#define TCP_CONGESTION 13
#define TCP_THIN_LINEAR_TIMEOUTS 16
#define TCP_USER_TIMEOUT 18
#define TCP_NOTSENT_LOWAT 25
#define TCP_SAVE_SYN 27
#define TCP_CA_NAME_MAX 16
#define TCP_NAGLE_OFF 1
#define ICSK_TIME_RETRANS 1
#define ICSK_TIME_PROBE0 3
......@@ -49,6 +76,8 @@
#define sk_state __sk_common.skc_state
#define sk_v6_daddr __sk_common.skc_v6_daddr
#define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr
#define sk_flags __sk_common.skc_flags
#define sk_reuse __sk_common.skc_reuse
#define s6_addr32 in6_u.u6_addr32
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment