Commit 19a9fbc0 authored by David S. Miller's avatar David S. Miller

Merge branch 'net-packet-KCSAN'

Eric Dumazet says:

====================
net/packet: KCSAN awareness

This series is based on one syzbot report [1]

Seven 'flags/booleans' are converted to atomic bit variant.

po->xmit and po->tp_tstamp accesses get annotations.

[1]
BUG: KCSAN: data-race in packet_rcv / packet_setsockopt

read-write to 0xffff88813dbe84e4 of 1 bytes by task 12312 on cpu 0:
packet_setsockopt+0xb77/0xe60 net/packet/af_packet.c:3900
__sys_setsockopt+0x212/0x2b0 net/socket.c:2252
__do_sys_setsockopt net/socket.c:2263 [inline]
__se_sys_setsockopt net/socket.c:2260 [inline]
__x64_sys_setsockopt+0x62/0x70 net/socket.c:2260
do_syscall_x64 arch/x86/entry/common.c:50 [inline]
do_syscall_64+0x2b/0x70 arch/x86/entry/common.c:80
entry_SYSCALL_64_after_hwframe+0x63/0xcd

read to 0xffff88813dbe84e4 of 1 bytes by task 1911 on cpu 1:
packet_rcv+0x4b1/0xa40 net/packet/af_packet.c:2187
deliver_skb net/core/dev.c:2189 [inline]
dev_queue_xmit_nit+0x3a9/0x620 net/core/dev.c:2259
xmit_one+0x71/0x2a0 net/core/dev.c:3586
dev_hard_start_xmit+0x72/0x120 net/core/dev.c:3606
__dev_queue_xmit+0x91c/0x11c0 net/core/dev.c:4256
dev_queue_xmit include/linux/netdevice.h:3008 [inline]
neigh_hh_output include/net/neighbour.h:530 [inline]
neigh_output include/net/neighbour.h:544 [inline]
ip6_finish_output2+0x9e9/0xc30 net/ipv6/ip6_output.c:134
__ip6_finish_output net/ipv6/ip6_output.c:195 [inline]
ip6_finish_output+0x395/0x4f0 net/ipv6/ip6_output.c:206
NF_HOOK_COND include/linux/netfilter.h:291 [inline]
ip6_output+0x10e/0x210 net/ipv6/ip6_output.c:227
dst_output include/net/dst.h:445 [inline]
ip6_local_out+0x60/0x80 net/ipv6/output_core.c:161
ip6tunnel_xmit include/net/ip6_tunnel.h:161 [inline]
udp_tunnel6_xmit_skb+0x321/0x4a0 net/ipv6/ip6_udp_tunnel.c:109
send6+0x2ed/0x3b0 drivers/net/wireguard/socket.c:152
wg_socket_send_skb_to_peer+0xbb/0x120 drivers/net/wireguard/socket.c:178
wg_packet_create_data_done drivers/net/wireguard/send.c:251 [inline]
wg_packet_tx_worker+0x142/0x360 drivers/net/wireguard/send.c:276
process_one_work+0x3d3/0x720 kernel/workqueue.c:2289
worker_thread+0x618/0xa70 kernel/workqueue.c:2436
kthread+0x1a9/0x1e0 kernel/kthread.c:376
ret_from_fork+0x1f/0x30 arch/x86/entry/entry_64.S:306
====================
Reviewed-by: default avatarWillem de Bruijn <willemb@google.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents dc021e6c 791a3e9f
...@@ -307,7 +307,8 @@ static void packet_cached_dev_reset(struct packet_sock *po) ...@@ -307,7 +307,8 @@ static void packet_cached_dev_reset(struct packet_sock *po)
static bool packet_use_direct_xmit(const struct packet_sock *po) static bool packet_use_direct_xmit(const struct packet_sock *po)
{ {
return po->xmit == packet_direct_xmit; /* Paired with WRITE_ONCE() in packet_setsockopt() */
return READ_ONCE(po->xmit) == packet_direct_xmit;
} }
static u16 packet_pick_tx_queue(struct sk_buff *skb) static u16 packet_pick_tx_queue(struct sk_buff *skb)
...@@ -339,14 +340,14 @@ static void __register_prot_hook(struct sock *sk) ...@@ -339,14 +340,14 @@ static void __register_prot_hook(struct sock *sk)
{ {
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
if (!po->running) { if (!packet_sock_flag(po, PACKET_SOCK_RUNNING)) {
if (po->fanout) if (po->fanout)
__fanout_link(sk, po); __fanout_link(sk, po);
else else
dev_add_pack(&po->prot_hook); dev_add_pack(&po->prot_hook);
sock_hold(sk); sock_hold(sk);
po->running = 1; packet_sock_flag_set(po, PACKET_SOCK_RUNNING, 1);
} }
} }
...@@ -368,7 +369,7 @@ static void __unregister_prot_hook(struct sock *sk, bool sync) ...@@ -368,7 +369,7 @@ static void __unregister_prot_hook(struct sock *sk, bool sync)
lockdep_assert_held_once(&po->bind_lock); lockdep_assert_held_once(&po->bind_lock);
po->running = 0; packet_sock_flag_set(po, PACKET_SOCK_RUNNING, 0);
if (po->fanout) if (po->fanout)
__fanout_unlink(sk, po); __fanout_unlink(sk, po);
...@@ -388,7 +389,7 @@ static void unregister_prot_hook(struct sock *sk, bool sync) ...@@ -388,7 +389,7 @@ static void unregister_prot_hook(struct sock *sk, bool sync)
{ {
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
if (po->running) if (packet_sock_flag(po, PACKET_SOCK_RUNNING))
__unregister_prot_hook(sk, sync); __unregister_prot_hook(sk, sync);
} }
...@@ -473,7 +474,7 @@ static __u32 __packet_set_timestamp(struct packet_sock *po, void *frame, ...@@ -473,7 +474,7 @@ static __u32 __packet_set_timestamp(struct packet_sock *po, void *frame,
struct timespec64 ts; struct timespec64 ts;
__u32 ts_status; __u32 ts_status;
if (!(ts_status = tpacket_get_timestamp(skb, &ts, po->tp_tstamp))) if (!(ts_status = tpacket_get_timestamp(skb, &ts, READ_ONCE(po->tp_tstamp))))
return 0; return 0;
h.raw = frame; h.raw = frame;
...@@ -1306,22 +1307,23 @@ static int __packet_rcv_has_room(const struct packet_sock *po, ...@@ -1306,22 +1307,23 @@ static int __packet_rcv_has_room(const struct packet_sock *po,
static int packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb) static int packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb)
{ {
int pressure, ret; bool pressure;
int ret;
ret = __packet_rcv_has_room(po, skb); ret = __packet_rcv_has_room(po, skb);
pressure = ret != ROOM_NORMAL; pressure = ret != ROOM_NORMAL;
if (READ_ONCE(po->pressure) != pressure) if (packet_sock_flag(po, PACKET_SOCK_PRESSURE) != pressure)
WRITE_ONCE(po->pressure, pressure); packet_sock_flag_set(po, PACKET_SOCK_PRESSURE, pressure);
return ret; return ret;
} }
static void packet_rcv_try_clear_pressure(struct packet_sock *po) static void packet_rcv_try_clear_pressure(struct packet_sock *po)
{ {
if (READ_ONCE(po->pressure) && if (packet_sock_flag(po, PACKET_SOCK_PRESSURE) &&
__packet_rcv_has_room(po, NULL) == ROOM_NORMAL) __packet_rcv_has_room(po, NULL) == ROOM_NORMAL)
WRITE_ONCE(po->pressure, 0); packet_sock_flag_set(po, PACKET_SOCK_PRESSURE, false);
} }
static void packet_sock_destruct(struct sock *sk) static void packet_sock_destruct(struct sock *sk)
...@@ -1408,7 +1410,8 @@ static unsigned int fanout_demux_rollover(struct packet_fanout *f, ...@@ -1408,7 +1410,8 @@ static unsigned int fanout_demux_rollover(struct packet_fanout *f,
i = j = min_t(int, po->rollover->sock, num - 1); i = j = min_t(int, po->rollover->sock, num - 1);
do { do {
po_next = pkt_sk(rcu_dereference(f->arr[i])); po_next = pkt_sk(rcu_dereference(f->arr[i]));
if (po_next != po_skip && !READ_ONCE(po_next->pressure) && if (po_next != po_skip &&
!packet_sock_flag(po_next, PACKET_SOCK_PRESSURE) &&
packet_rcv_has_room(po_next, skb) == ROOM_NORMAL) { packet_rcv_has_room(po_next, skb) == ROOM_NORMAL) {
if (i != j) if (i != j)
po->rollover->sock = i; po->rollover->sock = i;
...@@ -1781,7 +1784,7 @@ static int fanout_add(struct sock *sk, struct fanout_args *args) ...@@ -1781,7 +1784,7 @@ static int fanout_add(struct sock *sk, struct fanout_args *args)
err = -EINVAL; err = -EINVAL;
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (po->running && if (packet_sock_flag(po, PACKET_SOCK_RUNNING) &&
match->type == type && match->type == type &&
match->prot_hook.type == po->prot_hook.type && match->prot_hook.type == po->prot_hook.type &&
match->prot_hook.dev == po->prot_hook.dev) { match->prot_hook.dev == po->prot_hook.dev) {
...@@ -2183,7 +2186,7 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev, ...@@ -2183,7 +2186,7 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev,
sll = &PACKET_SKB_CB(skb)->sa.ll; sll = &PACKET_SKB_CB(skb)->sa.ll;
sll->sll_hatype = dev->type; sll->sll_hatype = dev->type;
sll->sll_pkttype = skb->pkt_type; sll->sll_pkttype = skb->pkt_type;
if (unlikely(po->origdev)) if (unlikely(packet_sock_flag(po, PACKET_SOCK_ORIGDEV)))
sll->sll_ifindex = orig_dev->ifindex; sll->sll_ifindex = orig_dev->ifindex;
else else
sll->sll_ifindex = dev->ifindex; sll->sll_ifindex = dev->ifindex;
...@@ -2308,7 +2311,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, ...@@ -2308,7 +2311,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
netoff = TPACKET_ALIGN(po->tp_hdrlen + netoff = TPACKET_ALIGN(po->tp_hdrlen +
(maclen < 16 ? 16 : maclen)) + (maclen < 16 ? 16 : maclen)) +
po->tp_reserve; po->tp_reserve;
if (po->has_vnet_hdr) { if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
netoff += sizeof(struct virtio_net_hdr); netoff += sizeof(struct virtio_net_hdr);
do_vnet = true; do_vnet = true;
} }
...@@ -2402,7 +2405,8 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, ...@@ -2402,7 +2405,8 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
* closer to the time of capture. * closer to the time of capture.
*/ */
ts_status = tpacket_get_timestamp(skb, &ts, ts_status = tpacket_get_timestamp(skb, &ts,
po->tp_tstamp | SOF_TIMESTAMPING_SOFTWARE); READ_ONCE(po->tp_tstamp) |
SOF_TIMESTAMPING_SOFTWARE);
if (!ts_status) if (!ts_status)
ktime_get_real_ts64(&ts); ktime_get_real_ts64(&ts);
...@@ -2460,7 +2464,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, ...@@ -2460,7 +2464,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
sll->sll_hatype = dev->type; sll->sll_hatype = dev->type;
sll->sll_protocol = skb->protocol; sll->sll_protocol = skb->protocol;
sll->sll_pkttype = skb->pkt_type; sll->sll_pkttype = skb->pkt_type;
if (unlikely(po->origdev)) if (unlikely(packet_sock_flag(po, PACKET_SOCK_ORIGDEV)))
sll->sll_ifindex = orig_dev->ifindex; sll->sll_ifindex = orig_dev->ifindex;
else else
sll->sll_ifindex = dev->ifindex; sll->sll_ifindex = dev->ifindex;
...@@ -2670,7 +2674,7 @@ static int tpacket_parse_header(struct packet_sock *po, void *frame, ...@@ -2670,7 +2674,7 @@ static int tpacket_parse_header(struct packet_sock *po, void *frame,
return -EMSGSIZE; return -EMSGSIZE;
} }
if (unlikely(po->tp_tx_has_off)) { if (unlikely(packet_sock_flag(po, PACKET_SOCK_TX_HAS_OFF))) {
int off_min, off_max; int off_min, off_max;
off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll); off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll);
...@@ -2778,7 +2782,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2778,7 +2782,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
size_max = po->tx_ring.frame_size size_max = po->tx_ring.frame_size
- (po->tp_hdrlen - sizeof(struct sockaddr_ll)); - (po->tp_hdrlen - sizeof(struct sockaddr_ll));
if ((size_max > dev->mtu + reserve + VLAN_HLEN) && !po->has_vnet_hdr) if ((size_max > dev->mtu + reserve + VLAN_HLEN) &&
!packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
size_max = dev->mtu + reserve + VLAN_HLEN; size_max = dev->mtu + reserve + VLAN_HLEN;
reinit_completion(&po->skb_completion); reinit_completion(&po->skb_completion);
...@@ -2807,7 +2812,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2807,7 +2812,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
status = TP_STATUS_SEND_REQUEST; status = TP_STATUS_SEND_REQUEST;
hlen = LL_RESERVED_SPACE(dev); hlen = LL_RESERVED_SPACE(dev);
tlen = dev->needed_tailroom; tlen = dev->needed_tailroom;
if (po->has_vnet_hdr) { if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
vnet_hdr = data; vnet_hdr = data;
data += sizeof(*vnet_hdr); data += sizeof(*vnet_hdr);
tp_len -= sizeof(*vnet_hdr); tp_len -= sizeof(*vnet_hdr);
...@@ -2835,13 +2840,13 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2835,13 +2840,13 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
addr, hlen, copylen, &sockc); addr, hlen, copylen, &sockc);
if (likely(tp_len >= 0) && if (likely(tp_len >= 0) &&
tp_len > dev->mtu + reserve && tp_len > dev->mtu + reserve &&
!po->has_vnet_hdr && !packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR) &&
!packet_extra_vlan_len_allowed(dev, skb)) !packet_extra_vlan_len_allowed(dev, skb))
tp_len = -EMSGSIZE; tp_len = -EMSGSIZE;
if (unlikely(tp_len < 0)) { if (unlikely(tp_len < 0)) {
tpacket_error: tpacket_error:
if (po->tp_loss) { if (packet_sock_flag(po, PACKET_SOCK_TP_LOSS)) {
__packet_set_status(po, ph, __packet_set_status(po, ph,
TP_STATUS_AVAILABLE); TP_STATUS_AVAILABLE);
packet_increment_head(&po->tx_ring); packet_increment_head(&po->tx_ring);
...@@ -2854,7 +2859,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2854,7 +2859,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
} }
} }
if (po->has_vnet_hdr) { if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) { if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) {
tp_len = -EINVAL; tp_len = -EINVAL;
goto tpacket_error; goto tpacket_error;
...@@ -2867,7 +2872,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2867,7 +2872,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
packet_inc_pending(&po->tx_ring); packet_inc_pending(&po->tx_ring);
status = TP_STATUS_SEND_REQUEST; status = TP_STATUS_SEND_REQUEST;
err = po->xmit(skb); /* Paired with WRITE_ONCE() in packet_setsockopt() */
err = READ_ONCE(po->xmit)(skb);
if (unlikely(err != 0)) { if (unlikely(err != 0)) {
if (err > 0) if (err > 0)
err = net_xmit_errno(err); err = net_xmit_errno(err);
...@@ -2988,7 +2994,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len) ...@@ -2988,7 +2994,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
if (sock->type == SOCK_RAW) if (sock->type == SOCK_RAW)
reserve = dev->hard_header_len; reserve = dev->hard_header_len;
if (po->has_vnet_hdr) { if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
err = packet_snd_vnet_parse(msg, &len, &vnet_hdr); err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);
if (err) if (err)
goto out_unlock; goto out_unlock;
...@@ -3070,7 +3076,8 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len) ...@@ -3070,7 +3076,8 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
virtio_net_hdr_set_proto(skb, &vnet_hdr); virtio_net_hdr_set_proto(skb, &vnet_hdr);
} }
err = po->xmit(skb); /* Paired with WRITE_ONCE() in packet_setsockopt() */
err = READ_ONCE(po->xmit)(skb);
if (unlikely(err != 0)) { if (unlikely(err != 0)) {
if (err > 0) if (err > 0)
err = net_xmit_errno(err); err = net_xmit_errno(err);
...@@ -3217,7 +3224,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex, ...@@ -3217,7 +3224,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,
if (need_rehook) { if (need_rehook) {
dev_hold(dev); dev_hold(dev);
if (po->running) { if (packet_sock_flag(po, PACKET_SOCK_RUNNING)) {
rcu_read_unlock(); rcu_read_unlock();
/* prevents packet_notifier() from calling /* prevents packet_notifier() from calling
* register_prot_hook() * register_prot_hook()
...@@ -3230,7 +3237,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex, ...@@ -3230,7 +3237,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,
dev->ifindex); dev->ifindex);
} }
BUG_ON(po->running); BUG_ON(packet_sock_flag(po, PACKET_SOCK_RUNNING));
WRITE_ONCE(po->num, proto); WRITE_ONCE(po->num, proto);
po->prot_hook.type = proto; po->prot_hook.type = proto;
...@@ -3447,7 +3454,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -3447,7 +3454,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
packet_rcv_try_clear_pressure(pkt_sk(sk)); packet_rcv_try_clear_pressure(pkt_sk(sk));
if (pkt_sk(sk)->has_vnet_hdr) { if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_HAS_VNET_HDR)) {
err = packet_rcv_vnet(msg, skb, &len); err = packet_rcv_vnet(msg, skb, &len);
if (err) if (err)
goto out_free; goto out_free;
...@@ -3511,7 +3518,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -3511,7 +3518,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
memcpy(msg->msg_name, &PACKET_SKB_CB(skb)->sa, copy_len); memcpy(msg->msg_name, &PACKET_SKB_CB(skb)->sa, copy_len);
} }
if (pkt_sk(sk)->auxdata) { if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_AUXDATA)) {
struct tpacket_auxdata aux; struct tpacket_auxdata aux;
aux.tp_status = TP_STATUS_USER; aux.tp_status = TP_STATUS_USER;
...@@ -3882,7 +3889,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3882,7 +3889,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) { if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
ret = -EBUSY; ret = -EBUSY;
} else { } else {
po->tp_loss = !!val; packet_sock_flag_set(po, PACKET_SOCK_TP_LOSS, val);
ret = 0; ret = 0;
} }
release_sock(sk); release_sock(sk);
...@@ -3897,9 +3904,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3897,9 +3904,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val))) if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT; return -EFAULT;
lock_sock(sk); packet_sock_flag_set(po, PACKET_SOCK_AUXDATA, val);
po->auxdata = !!val;
release_sock(sk);
return 0; return 0;
} }
case PACKET_ORIGDEV: case PACKET_ORIGDEV:
...@@ -3911,9 +3916,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3911,9 +3916,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val))) if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT; return -EFAULT;
lock_sock(sk); packet_sock_flag_set(po, PACKET_SOCK_ORIGDEV, val);
po->origdev = !!val;
release_sock(sk);
return 0; return 0;
} }
case PACKET_VNET_HDR: case PACKET_VNET_HDR:
...@@ -3931,7 +3934,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3931,7 +3934,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) { if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
ret = -EBUSY; ret = -EBUSY;
} else { } else {
po->has_vnet_hdr = !!val; packet_sock_flag_set(po, PACKET_SOCK_HAS_VNET_HDR, val);
ret = 0; ret = 0;
} }
release_sock(sk); release_sock(sk);
...@@ -3946,7 +3949,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3946,7 +3949,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val))) if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT; return -EFAULT;
po->tp_tstamp = val; WRITE_ONCE(po->tp_tstamp, val);
return 0; return 0;
} }
case PACKET_FANOUT: case PACKET_FANOUT:
...@@ -3993,7 +3996,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -3993,7 +3996,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
lock_sock(sk); lock_sock(sk);
if (!po->rx_ring.pg_vec && !po->tx_ring.pg_vec) if (!po->rx_ring.pg_vec && !po->tx_ring.pg_vec)
po->tp_tx_has_off = !!val; packet_sock_flag_set(po, PACKET_SOCK_TX_HAS_OFF, val);
release_sock(sk); release_sock(sk);
return 0; return 0;
...@@ -4007,7 +4010,8 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, ...@@ -4007,7 +4010,8 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val))) if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT; return -EFAULT;
po->xmit = val ? packet_direct_xmit : dev_queue_xmit; /* Paired with all lockless reads of po->xmit */
WRITE_ONCE(po->xmit, val ? packet_direct_xmit : dev_queue_xmit);
return 0; return 0;
} }
default: default:
...@@ -4058,13 +4062,13 @@ static int packet_getsockopt(struct socket *sock, int level, int optname, ...@@ -4058,13 +4062,13 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
break; break;
case PACKET_AUXDATA: case PACKET_AUXDATA:
val = po->auxdata; val = packet_sock_flag(po, PACKET_SOCK_AUXDATA);
break; break;
case PACKET_ORIGDEV: case PACKET_ORIGDEV:
val = po->origdev; val = packet_sock_flag(po, PACKET_SOCK_ORIGDEV);
break; break;
case PACKET_VNET_HDR: case PACKET_VNET_HDR:
val = po->has_vnet_hdr; val = packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR);
break; break;
case PACKET_VERSION: case PACKET_VERSION:
val = po->tp_version; val = po->tp_version;
...@@ -4094,10 +4098,10 @@ static int packet_getsockopt(struct socket *sock, int level, int optname, ...@@ -4094,10 +4098,10 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
val = po->tp_reserve; val = po->tp_reserve;
break; break;
case PACKET_LOSS: case PACKET_LOSS:
val = po->tp_loss; val = packet_sock_flag(po, PACKET_SOCK_TP_LOSS);
break; break;
case PACKET_TIMESTAMP: case PACKET_TIMESTAMP:
val = po->tp_tstamp; val = READ_ONCE(po->tp_tstamp);
break; break;
case PACKET_FANOUT: case PACKET_FANOUT:
val = (po->fanout ? val = (po->fanout ?
...@@ -4119,7 +4123,7 @@ static int packet_getsockopt(struct socket *sock, int level, int optname, ...@@ -4119,7 +4123,7 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
lv = sizeof(rstats); lv = sizeof(rstats);
break; break;
case PACKET_TX_HAS_OFF: case PACKET_TX_HAS_OFF:
val = po->tp_tx_has_off; val = packet_sock_flag(po, PACKET_SOCK_TX_HAS_OFF);
break; break;
case PACKET_QDISC_BYPASS: case PACKET_QDISC_BYPASS:
val = packet_use_direct_xmit(po); val = packet_use_direct_xmit(po);
...@@ -4157,7 +4161,7 @@ static int packet_notifier(struct notifier_block *this, ...@@ -4157,7 +4161,7 @@ static int packet_notifier(struct notifier_block *this,
case NETDEV_DOWN: case NETDEV_DOWN:
if (dev->ifindex == po->ifindex) { if (dev->ifindex == po->ifindex) {
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (po->running) { if (packet_sock_flag(po, PACKET_SOCK_RUNNING)) {
__unregister_prot_hook(sk, false); __unregister_prot_hook(sk, false);
sk->sk_err = ENETDOWN; sk->sk_err = ENETDOWN;
if (!sock_flag(sk, SOCK_DEAD)) if (!sock_flag(sk, SOCK_DEAD))
...@@ -4468,7 +4472,7 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u, ...@@ -4468,7 +4472,7 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
/* Detach socket from network */ /* Detach socket from network */
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
was_running = po->running; was_running = packet_sock_flag(po, PACKET_SOCK_RUNNING);
num = po->num; num = po->num;
if (was_running) { if (was_running) {
WRITE_ONCE(po->num, 0); WRITE_ONCE(po->num, 0);
...@@ -4679,7 +4683,7 @@ static int packet_seq_show(struct seq_file *seq, void *v) ...@@ -4679,7 +4683,7 @@ static int packet_seq_show(struct seq_file *seq, void *v)
s->sk_type, s->sk_type,
ntohs(READ_ONCE(po->num)), ntohs(READ_ONCE(po->num)),
READ_ONCE(po->ifindex), READ_ONCE(po->ifindex),
po->running, packet_sock_flag(po, PACKET_SOCK_RUNNING),
atomic_read(&s->sk_rmem_alloc), atomic_read(&s->sk_rmem_alloc),
from_kuid_munged(seq_user_ns(seq), sock_i_uid(s)), from_kuid_munged(seq_user_ns(seq), sock_i_uid(s)),
sock_i_ino(s)); sock_i_ino(s));
......
...@@ -18,18 +18,18 @@ static int pdiag_put_info(const struct packet_sock *po, struct sk_buff *nlskb) ...@@ -18,18 +18,18 @@ static int pdiag_put_info(const struct packet_sock *po, struct sk_buff *nlskb)
pinfo.pdi_version = po->tp_version; pinfo.pdi_version = po->tp_version;
pinfo.pdi_reserve = po->tp_reserve; pinfo.pdi_reserve = po->tp_reserve;
pinfo.pdi_copy_thresh = po->copy_thresh; pinfo.pdi_copy_thresh = po->copy_thresh;
pinfo.pdi_tstamp = po->tp_tstamp; pinfo.pdi_tstamp = READ_ONCE(po->tp_tstamp);
pinfo.pdi_flags = 0; pinfo.pdi_flags = 0;
if (po->running) if (packet_sock_flag(po, PACKET_SOCK_RUNNING))
pinfo.pdi_flags |= PDI_RUNNING; pinfo.pdi_flags |= PDI_RUNNING;
if (po->auxdata) if (packet_sock_flag(po, PACKET_SOCK_AUXDATA))
pinfo.pdi_flags |= PDI_AUXDATA; pinfo.pdi_flags |= PDI_AUXDATA;
if (po->origdev) if (packet_sock_flag(po, PACKET_SOCK_ORIGDEV))
pinfo.pdi_flags |= PDI_ORIGDEV; pinfo.pdi_flags |= PDI_ORIGDEV;
if (po->has_vnet_hdr) if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
pinfo.pdi_flags |= PDI_VNETHDR; pinfo.pdi_flags |= PDI_VNETHDR;
if (po->tp_loss) if (packet_sock_flag(po, PACKET_SOCK_TP_LOSS))
pinfo.pdi_flags |= PDI_LOSS; pinfo.pdi_flags |= PDI_LOSS;
return nla_put(nlskb, PACKET_DIAG_INFO, sizeof(pinfo), &pinfo); return nla_put(nlskb, PACKET_DIAG_INFO, sizeof(pinfo), &pinfo);
......
...@@ -116,13 +116,7 @@ struct packet_sock { ...@@ -116,13 +116,7 @@ struct packet_sock {
int copy_thresh; int copy_thresh;
spinlock_t bind_lock; spinlock_t bind_lock;
struct mutex pg_vec_lock; struct mutex pg_vec_lock;
unsigned int running; /* bind_lock must be held */ unsigned long flags;
unsigned int auxdata:1, /* writer must hold sock lock */
origdev:1,
has_vnet_hdr:1,
tp_loss:1,
tp_tx_has_off:1;
int pressure;
int ifindex; /* bound device */ int ifindex; /* bound device */
__be16 num; __be16 num;
struct packet_rollover *rollover; struct packet_rollover *rollover;
...@@ -144,4 +138,30 @@ static inline struct packet_sock *pkt_sk(struct sock *sk) ...@@ -144,4 +138,30 @@ static inline struct packet_sock *pkt_sk(struct sock *sk)
return (struct packet_sock *)sk; return (struct packet_sock *)sk;
} }
enum packet_sock_flags {
PACKET_SOCK_ORIGDEV,
PACKET_SOCK_AUXDATA,
PACKET_SOCK_TX_HAS_OFF,
PACKET_SOCK_TP_LOSS,
PACKET_SOCK_HAS_VNET_HDR,
PACKET_SOCK_RUNNING,
PACKET_SOCK_PRESSURE,
};
static inline void packet_sock_flag_set(struct packet_sock *po,
enum packet_sock_flags flag,
bool val)
{
if (val)
set_bit(flag, &po->flags);
else
clear_bit(flag, &po->flags);
}
static inline bool packet_sock_flag(const struct packet_sock *po,
enum packet_sock_flags flag)
{
return test_bit(flag, &po->flags);
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment