Commit 8fe08d70 authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

netlink: convert nlk->flags to atomic flags

sk_diag_put_flags(), netlink_setsockopt(), netlink_getsockopt()
and others use nlk->flags without correct locking.

Use set_bit(), clear_bit(), test_bit(), assign_bit() to remove
data-races.
Reported-by: default avatarsyzbot <syzkaller@googlegroups.com>
Signed-off-by: default avatarEric Dumazet <edumazet@google.com>
Reviewed-by: default avatarSimon Horman <horms@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 86f03776
...@@ -84,7 +84,7 @@ struct listeners { ...@@ -84,7 +84,7 @@ struct listeners {
static inline int netlink_is_kernel(struct sock *sk) static inline int netlink_is_kernel(struct sock *sk)
{ {
return nlk_sk(sk)->flags & NETLINK_F_KERNEL_SOCKET; return nlk_test_bit(KERNEL_SOCKET, sk);
} }
struct netlink_table *nl_table __read_mostly; struct netlink_table *nl_table __read_mostly;
...@@ -349,9 +349,7 @@ static void netlink_deliver_tap_kernel(struct sock *dst, struct sock *src, ...@@ -349,9 +349,7 @@ static void netlink_deliver_tap_kernel(struct sock *dst, struct sock *src,
static void netlink_overrun(struct sock *sk) static void netlink_overrun(struct sock *sk)
{ {
struct netlink_sock *nlk = nlk_sk(sk); if (!nlk_test_bit(RECV_NO_ENOBUFS, sk)) {
if (!(nlk->flags & NETLINK_F_RECV_NO_ENOBUFS)) {
if (!test_and_set_bit(NETLINK_S_CONGESTED, if (!test_and_set_bit(NETLINK_S_CONGESTED,
&nlk_sk(sk)->state)) { &nlk_sk(sk)->state)) {
sk->sk_err = ENOBUFS; sk->sk_err = ENOBUFS;
...@@ -1407,9 +1405,7 @@ EXPORT_SYMBOL_GPL(netlink_has_listeners); ...@@ -1407,9 +1405,7 @@ EXPORT_SYMBOL_GPL(netlink_has_listeners);
bool netlink_strict_get_check(struct sk_buff *skb) bool netlink_strict_get_check(struct sk_buff *skb)
{ {
const struct netlink_sock *nlk = nlk_sk(NETLINK_CB(skb).sk); return nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk);
return nlk->flags & NETLINK_F_STRICT_CHK;
} }
EXPORT_SYMBOL_GPL(netlink_strict_get_check); EXPORT_SYMBOL_GPL(netlink_strict_get_check);
...@@ -1455,7 +1451,7 @@ static void do_one_broadcast(struct sock *sk, ...@@ -1455,7 +1451,7 @@ static void do_one_broadcast(struct sock *sk,
return; return;
if (!net_eq(sock_net(sk), p->net)) { if (!net_eq(sock_net(sk), p->net)) {
if (!(nlk->flags & NETLINK_F_LISTEN_ALL_NSID)) if (!nlk_test_bit(LISTEN_ALL_NSID, sk))
return; return;
if (!peernet_has_id(sock_net(sk), p->net)) if (!peernet_has_id(sock_net(sk), p->net))
...@@ -1488,7 +1484,7 @@ static void do_one_broadcast(struct sock *sk, ...@@ -1488,7 +1484,7 @@ static void do_one_broadcast(struct sock *sk,
netlink_overrun(sk); netlink_overrun(sk);
/* Clone failed. Notify ALL listeners. */ /* Clone failed. Notify ALL listeners. */
p->failure = 1; p->failure = 1;
if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR) if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
p->delivery_failure = 1; p->delivery_failure = 1;
goto out; goto out;
} }
...@@ -1510,7 +1506,7 @@ static void do_one_broadcast(struct sock *sk, ...@@ -1510,7 +1506,7 @@ static void do_one_broadcast(struct sock *sk,
val = netlink_broadcast_deliver(sk, p->skb2); val = netlink_broadcast_deliver(sk, p->skb2);
if (val < 0) { if (val < 0) {
netlink_overrun(sk); netlink_overrun(sk);
if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR) if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
p->delivery_failure = 1; p->delivery_failure = 1;
} else { } else {
p->congested |= val; p->congested |= val;
...@@ -1604,7 +1600,7 @@ static int do_one_set_err(struct sock *sk, struct netlink_set_err_data *p) ...@@ -1604,7 +1600,7 @@ static int do_one_set_err(struct sock *sk, struct netlink_set_err_data *p)
!test_bit(p->group - 1, nlk->groups)) !test_bit(p->group - 1, nlk->groups))
goto out; goto out;
if (p->code == ENOBUFS && nlk->flags & NETLINK_F_RECV_NO_ENOBUFS) { if (p->code == ENOBUFS && nlk_test_bit(RECV_NO_ENOBUFS, sk)) {
ret = 1; ret = 1;
goto out; goto out;
} }
...@@ -1668,7 +1664,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, ...@@ -1668,7 +1664,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct netlink_sock *nlk = nlk_sk(sk); struct netlink_sock *nlk = nlk_sk(sk);
unsigned int val = 0; unsigned int val = 0;
int err; int nr = -1;
if (level != SOL_NETLINK) if (level != SOL_NETLINK)
return -ENOPROTOOPT; return -ENOPROTOOPT;
...@@ -1679,14 +1675,12 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, ...@@ -1679,14 +1675,12 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
switch (optname) { switch (optname) {
case NETLINK_PKTINFO: case NETLINK_PKTINFO:
if (val) nr = NETLINK_F_RECV_PKTINFO;
nlk->flags |= NETLINK_F_RECV_PKTINFO;
else
nlk->flags &= ~NETLINK_F_RECV_PKTINFO;
err = 0;
break; break;
case NETLINK_ADD_MEMBERSHIP: case NETLINK_ADD_MEMBERSHIP:
case NETLINK_DROP_MEMBERSHIP: { case NETLINK_DROP_MEMBERSHIP: {
int err;
if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV)) if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV))
return -EPERM; return -EPERM;
err = netlink_realloc_groups(sk); err = netlink_realloc_groups(sk);
...@@ -1706,61 +1700,38 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, ...@@ -1706,61 +1700,38 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind) if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind)
nlk->netlink_unbind(sock_net(sk), val); nlk->netlink_unbind(sock_net(sk), val);
err = 0;
break; break;
} }
case NETLINK_BROADCAST_ERROR: case NETLINK_BROADCAST_ERROR:
if (val) nr = NETLINK_F_BROADCAST_SEND_ERROR;
nlk->flags |= NETLINK_F_BROADCAST_SEND_ERROR;
else
nlk->flags &= ~NETLINK_F_BROADCAST_SEND_ERROR;
err = 0;
break; break;
case NETLINK_NO_ENOBUFS: case NETLINK_NO_ENOBUFS:
assign_bit(NETLINK_F_RECV_NO_ENOBUFS, &nlk->flags, val);
if (val) { if (val) {
nlk->flags |= NETLINK_F_RECV_NO_ENOBUFS;
clear_bit(NETLINK_S_CONGESTED, &nlk->state); clear_bit(NETLINK_S_CONGESTED, &nlk->state);
wake_up_interruptible(&nlk->wait); wake_up_interruptible(&nlk->wait);
} else {
nlk->flags &= ~NETLINK_F_RECV_NO_ENOBUFS;
} }
err = 0;
break; break;
case NETLINK_LISTEN_ALL_NSID: case NETLINK_LISTEN_ALL_NSID:
if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_BROADCAST)) if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_BROADCAST))
return -EPERM; return -EPERM;
nr = NETLINK_F_LISTEN_ALL_NSID;
if (val)
nlk->flags |= NETLINK_F_LISTEN_ALL_NSID;
else
nlk->flags &= ~NETLINK_F_LISTEN_ALL_NSID;
err = 0;
break; break;
case NETLINK_CAP_ACK: case NETLINK_CAP_ACK:
if (val) nr = NETLINK_F_CAP_ACK;
nlk->flags |= NETLINK_F_CAP_ACK;
else
nlk->flags &= ~NETLINK_F_CAP_ACK;
err = 0;
break; break;
case NETLINK_EXT_ACK: case NETLINK_EXT_ACK:
if (val) nr = NETLINK_F_EXT_ACK;
nlk->flags |= NETLINK_F_EXT_ACK;
else
nlk->flags &= ~NETLINK_F_EXT_ACK;
err = 0;
break; break;
case NETLINK_GET_STRICT_CHK: case NETLINK_GET_STRICT_CHK:
if (val) nr = NETLINK_F_STRICT_CHK;
nlk->flags |= NETLINK_F_STRICT_CHK;
else
nlk->flags &= ~NETLINK_F_STRICT_CHK;
err = 0;
break; break;
default: default:
err = -ENOPROTOOPT; return -ENOPROTOOPT;
} }
return err; if (nr >= 0)
assign_bit(nr, &nlk->flags, val);
return 0;
} }
static int netlink_getsockopt(struct socket *sock, int level, int optname, static int netlink_getsockopt(struct socket *sock, int level, int optname,
...@@ -1827,7 +1798,7 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname, ...@@ -1827,7 +1798,7 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
return -EINVAL; return -EINVAL;
len = sizeof(int); len = sizeof(int);
val = nlk->flags & flag ? 1 : 0; val = test_bit(flag, &nlk->flags);
if (put_user(len, optlen) || if (put_user(len, optlen) ||
copy_to_user(optval, &val, len)) copy_to_user(optval, &val, len))
...@@ -2004,9 +1975,9 @@ static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -2004,9 +1975,9 @@ static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
msg->msg_namelen = sizeof(*addr); msg->msg_namelen = sizeof(*addr);
} }
if (nlk->flags & NETLINK_F_RECV_PKTINFO) if (nlk_test_bit(RECV_PKTINFO, sk))
netlink_cmsg_recv_pktinfo(msg, skb); netlink_cmsg_recv_pktinfo(msg, skb);
if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID) if (nlk_test_bit(LISTEN_ALL_NSID, sk))
netlink_cmsg_listen_all_nsid(sk, msg, skb); netlink_cmsg_listen_all_nsid(sk, msg, skb);
memset(&scm, 0, sizeof(scm)); memset(&scm, 0, sizeof(scm));
...@@ -2083,7 +2054,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module, ...@@ -2083,7 +2054,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
goto out_sock_release; goto out_sock_release;
nlk = nlk_sk(sk); nlk = nlk_sk(sk);
nlk->flags |= NETLINK_F_KERNEL_SOCKET; set_bit(NETLINK_F_KERNEL_SOCKET, &nlk->flags);
netlink_table_grab(); netlink_table_grab();
if (!nl_table[unit].registered) { if (!nl_table[unit].registered) {
...@@ -2218,7 +2189,7 @@ static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb, ...@@ -2218,7 +2189,7 @@ static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb,
nl_dump_check_consistent(cb, nlh); nl_dump_check_consistent(cb, nlh);
memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno)); memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno));
if (extack->_msg && nlk->flags & NETLINK_F_EXT_ACK) { if (extack->_msg && test_bit(NETLINK_F_EXT_ACK, &nlk->flags)) {
nlh->nlmsg_flags |= NLM_F_ACK_TLVS; nlh->nlmsg_flags |= NLM_F_ACK_TLVS;
if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg)) if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg))
nlmsg_end(skb, nlh); nlmsg_end(skb, nlh);
...@@ -2347,8 +2318,8 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb, ...@@ -2347,8 +2318,8 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
const struct nlmsghdr *nlh, const struct nlmsghdr *nlh,
struct netlink_dump_control *control) struct netlink_dump_control *control)
{ {
struct netlink_sock *nlk, *nlk2;
struct netlink_callback *cb; struct netlink_callback *cb;
struct netlink_sock *nlk;
struct sock *sk; struct sock *sk;
int ret; int ret;
...@@ -2383,8 +2354,7 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb, ...@@ -2383,8 +2354,7 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
cb->min_dump_alloc = control->min_dump_alloc; cb->min_dump_alloc = control->min_dump_alloc;
cb->skb = skb; cb->skb = skb;
nlk2 = nlk_sk(NETLINK_CB(skb).sk); cb->strict_check = nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk);
cb->strict_check = !!(nlk2->flags & NETLINK_F_STRICT_CHK);
if (control->start) { if (control->start) {
cb->extack = control->extack; cb->extack = control->extack;
...@@ -2428,7 +2398,7 @@ netlink_ack_tlv_len(struct netlink_sock *nlk, int err, ...@@ -2428,7 +2398,7 @@ netlink_ack_tlv_len(struct netlink_sock *nlk, int err,
{ {
size_t tlvlen; size_t tlvlen;
if (!extack || !(nlk->flags & NETLINK_F_EXT_ACK)) if (!extack || !test_bit(NETLINK_F_EXT_ACK, &nlk->flags))
return 0; return 0;
tlvlen = 0; tlvlen = 0;
...@@ -2500,7 +2470,7 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err, ...@@ -2500,7 +2470,7 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
* requests to cap the error message, and get extra error data if * requests to cap the error message, and get extra error data if
* requested. * requested.
*/ */
if (err && !(nlk->flags & NETLINK_F_CAP_ACK)) if (err && !test_bit(NETLINK_F_CAP_ACK, &nlk->flags))
payload += nlmsg_len(nlh); payload += nlmsg_len(nlh);
else else
flags |= NLM_F_CAPPED; flags |= NLM_F_CAPPED;
......
...@@ -8,14 +8,16 @@ ...@@ -8,14 +8,16 @@
#include <net/sock.h> #include <net/sock.h>
/* flags */ /* flags */
#define NETLINK_F_KERNEL_SOCKET 0x1 enum {
#define NETLINK_F_RECV_PKTINFO 0x2 NETLINK_F_KERNEL_SOCKET,
#define NETLINK_F_BROADCAST_SEND_ERROR 0x4 NETLINK_F_RECV_PKTINFO,
#define NETLINK_F_RECV_NO_ENOBUFS 0x8 NETLINK_F_BROADCAST_SEND_ERROR,
#define NETLINK_F_LISTEN_ALL_NSID 0x10 NETLINK_F_RECV_NO_ENOBUFS,
#define NETLINK_F_CAP_ACK 0x20 NETLINK_F_LISTEN_ALL_NSID,
#define NETLINK_F_EXT_ACK 0x40 NETLINK_F_CAP_ACK,
#define NETLINK_F_STRICT_CHK 0x80 NETLINK_F_EXT_ACK,
NETLINK_F_STRICT_CHK,
};
#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
#define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long)) #define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long))
...@@ -23,10 +25,10 @@ ...@@ -23,10 +25,10 @@
struct netlink_sock { struct netlink_sock {
/* struct sock has to be the first member of netlink_sock */ /* struct sock has to be the first member of netlink_sock */
struct sock sk; struct sock sk;
unsigned long flags;
u32 portid; u32 portid;
u32 dst_portid; u32 dst_portid;
u32 dst_group; u32 dst_group;
u32 flags;
u32 subscriptions; u32 subscriptions;
u32 ngroups; u32 ngroups;
unsigned long *groups; unsigned long *groups;
...@@ -56,6 +58,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk) ...@@ -56,6 +58,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk)
return container_of(sk, struct netlink_sock, sk); return container_of(sk, struct netlink_sock, sk);
} }
#define nlk_test_bit(nr, sk) test_bit(NETLINK_F_##nr, &nlk_sk(sk)->flags)
struct netlink_table { struct netlink_table {
struct rhashtable hash; struct rhashtable hash;
struct hlist_head mc_list; struct hlist_head mc_list;
......
...@@ -27,15 +27,15 @@ static int sk_diag_put_flags(struct sock *sk, struct sk_buff *skb) ...@@ -27,15 +27,15 @@ static int sk_diag_put_flags(struct sock *sk, struct sk_buff *skb)
if (nlk->cb_running) if (nlk->cb_running)
flags |= NDIAG_FLAG_CB_RUNNING; flags |= NDIAG_FLAG_CB_RUNNING;
if (nlk->flags & NETLINK_F_RECV_PKTINFO) if (nlk_test_bit(RECV_PKTINFO, sk))
flags |= NDIAG_FLAG_PKTINFO; flags |= NDIAG_FLAG_PKTINFO;
if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR) if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
flags |= NDIAG_FLAG_BROADCAST_ERROR; flags |= NDIAG_FLAG_BROADCAST_ERROR;
if (nlk->flags & NETLINK_F_RECV_NO_ENOBUFS) if (nlk_test_bit(RECV_NO_ENOBUFS, sk))
flags |= NDIAG_FLAG_NO_ENOBUFS; flags |= NDIAG_FLAG_NO_ENOBUFS;
if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID) if (nlk_test_bit(LISTEN_ALL_NSID, sk))
flags |= NDIAG_FLAG_LISTEN_ALL_NSID; flags |= NDIAG_FLAG_LISTEN_ALL_NSID;
if (nlk->flags & NETLINK_F_CAP_ACK) if (nlk_test_bit(CAP_ACK, sk))
flags |= NDIAG_FLAG_CAP_ACK; flags |= NDIAG_FLAG_CAP_ACK;
return nla_put_u32(skb, NETLINK_DIAG_FLAGS, flags); return nla_put_u32(skb, NETLINK_DIAG_FLAGS, flags);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment