Commit e97e68b5 authored by David S. Miller's avatar David S. Miller

Merge branch 'sk_bound_dev_if-annotations'

Eric Dumazet says:

====================
net: add annotations for sk->sk_bound_dev_if

While writes on sk->sk_bound_dev_if are protected by socket lock,
we have many lockless reads all over the places.

This is based on syzbot report found in the first patch changelog.

v2: inline ipv6 function only defined if IS_ENABLED(CONFIG_IPV6) (kernel bots)
    Change the INET6_MATCH() to inet6_match(), this is no longer a macro.
    Change INET_MATCH() to inet_match() (Olivier Hartkopp & Jakub Kicinski)

====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 7fa2e481 eda090c3
......@@ -103,15 +103,25 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
const int dif);
int inet6_hash(struct sock *sk);
#endif /* IS_ENABLED(CONFIG_IPV6) */
#define INET6_MATCH(__sk, __net, __saddr, __daddr, __ports, __dif, __sdif) \
(((__sk)->sk_portpair == (__ports)) && \
((__sk)->sk_family == AF_INET6) && \
ipv6_addr_equal(&(__sk)->sk_v6_daddr, (__saddr)) && \
ipv6_addr_equal(&(__sk)->sk_v6_rcv_saddr, (__daddr)) && \
(((__sk)->sk_bound_dev_if == (__dif)) || \
((__sk)->sk_bound_dev_if == (__sdif))) && \
net_eq(sock_net(__sk), (__net)))
static inline bool inet6_match(struct net *net, const struct sock *sk,
const struct in6_addr *saddr,
const struct in6_addr *daddr,
const __portpair ports,
const int dif, const int sdif)
{
int bound_dev_if;
if (!net_eq(sock_net(sk), net) ||
sk->sk_family != AF_INET6 ||
sk->sk_portpair != ports ||
!ipv6_addr_equal(&sk->sk_v6_daddr, saddr) ||
!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
return false;
bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
return bound_dev_if == dif || bound_dev_if == sdif;
}
#endif /* IS_ENABLED(CONFIG_IPV6) */
#endif /* _INET6_HASHTABLES_H */
......@@ -267,7 +267,7 @@ static inline struct sock *inet_lookup_listener(struct net *net,
((__force __u64)(__be32)(__saddr)))
#endif /* __BIG_ENDIAN */
static inline bool INET_MATCH(struct net *net, const struct sock *sk,
static inline bool inet_match(struct net *net, const struct sock *sk,
const __addrpair cookie, const __portpair ports,
int dif, int sdif)
{
......
......@@ -116,14 +116,15 @@ static inline u32 inet_request_mark(const struct sock *sk, struct sk_buff *skb)
static inline int inet_request_bound_dev_if(const struct sock *sk,
struct sk_buff *skb)
{
int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
#ifdef CONFIG_NET_L3_MASTER_DEV
struct net *net = sock_net(sk);
if (!sk->sk_bound_dev_if && net->ipv4.sysctl_tcp_l3mdev_accept)
if (!bound_dev_if && net->ipv4.sysctl_tcp_l3mdev_accept)
return l3mdev_master_ifindex_by_index(net, skb->skb_iif);
#endif
return sk->sk_bound_dev_if;
return bound_dev_if;
}
static inline int inet_sk_bound_l3mdev(const struct sock *sk)
......
......@@ -93,7 +93,7 @@ static inline void ipcm_init_sk(struct ipcm_cookie *ipcm,
ipcm->sockc.mark = inet->sk.sk_mark;
ipcm->sockc.tsflags = inet->sk.sk_tsflags;
ipcm->oif = inet->sk.sk_bound_dev_if;
ipcm->oif = READ_ONCE(inet->sk.sk_bound_dev_if);
ipcm->addr = inet->inet_saddr;
}
......
......@@ -2875,13 +2875,14 @@ static inline void sk_pacing_shift_update(struct sock *sk, int val)
*/
static inline bool sk_dev_equal_l3scope(struct sock *sk, int dif)
{
int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
int mdif;
if (!sk->sk_bound_dev_if || sk->sk_bound_dev_if == dif)
if (!bound_dev_if || bound_dev_if == dif)
return true;
mdif = l3mdev_master_ifindex_by_index(sock_net(sk), dif);
if (mdif && mdif == sk->sk_bound_dev_if)
if (mdif && mdif == bound_dev_if)
return true;
return false;
......
......@@ -635,7 +635,9 @@ static int sock_bindtoindex_locked(struct sock *sk, int ifindex)
if (ifindex < 0)
goto out;
sk->sk_bound_dev_if = ifindex;
/* Paired with all READ_ONCE() done locklessly. */
WRITE_ONCE(sk->sk_bound_dev_if, ifindex);
if (sk->sk_prot->rehash)
sk->sk_prot->rehash(sk);
sk_dst_reset(sk);
......@@ -713,10 +715,11 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval,
{
int ret = -ENOPROTOOPT;
#ifdef CONFIG_NETDEVICES
int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
struct net *net = sock_net(sk);
char devname[IFNAMSIZ];
if (sk->sk_bound_dev_if == 0) {
if (bound_dev_if == 0) {
len = 0;
goto zero;
}
......@@ -725,7 +728,7 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval,
if (len < IFNAMSIZ)
goto out;
ret = netdev_get_name(net, devname, sk->sk_bound_dev_if);
ret = netdev_get_name(net, devname, bound_dev_if);
if (ret)
goto out;
......@@ -1861,7 +1864,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
break;
case SO_BINDTOIFINDEX:
v.val = sk->sk_bound_dev_if;
v.val = READ_ONCE(sk->sk_bound_dev_if);
break;
case SO_NETNS_COOKIE:
......
......@@ -628,7 +628,7 @@ int dccp_v4_conn_request(struct sock *sk, struct sk_buff *skb)
sk_daddr_set(req_to_sk(req), ip_hdr(skb)->saddr);
ireq->ir_mark = inet_request_mark(sk, skb);
ireq->ireq_family = AF_INET;
ireq->ir_iif = sk->sk_bound_dev_if;
ireq->ir_iif = READ_ONCE(sk->sk_bound_dev_if);
/*
* Step 3: Process LISTEN state
......
......@@ -374,10 +374,10 @@ static int dccp_v6_conn_request(struct sock *sk, struct sk_buff *skb)
refcount_inc(&skb->users);
ireq->pktopts = skb;
}
ireq->ir_iif = sk->sk_bound_dev_if;
ireq->ir_iif = READ_ONCE(sk->sk_bound_dev_if);
/* So that link locals have meaning */
if (!sk->sk_bound_dev_if &&
if (!ireq->ir_iif &&
ipv6_addr_type(&ireq->ir_v6_rmt_addr) & IPV6_ADDR_LINKLOCAL)
ireq->ir_iif = inet6_iif(skb);
......
......@@ -155,10 +155,14 @@ static int inet_csk_bind_conflict(const struct sock *sk,
*/
sk_for_each_bound(sk2, &tb->owners) {
if (sk != sk2 &&
(!sk->sk_bound_dev_if ||
!sk2->sk_bound_dev_if ||
sk->sk_bound_dev_if == sk2->sk_bound_dev_if)) {
int bound_dev_if2;
if (sk == sk2)
continue;
bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
if ((!sk->sk_bound_dev_if ||
!bound_dev_if2 ||
sk->sk_bound_dev_if == bound_dev_if2)) {
if (reuse && sk2->sk_reuse &&
sk2->sk_state != TCP_LISTEN) {
if ((!relax ||
......
......@@ -373,10 +373,10 @@ struct sock *__inet_lookup_established(struct net *net,
sk_nulls_for_each_rcu(sk, node, &head->chain) {
if (sk->sk_hash != hash)
continue;
if (likely(INET_MATCH(net, sk, acookie, ports, dif, sdif))) {
if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) {
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out;
if (unlikely(!INET_MATCH(net, sk, acookie,
if (unlikely(!inet_match(net, sk, acookie,
ports, dif, sdif))) {
sock_gen_put(sk);
goto begin;
......@@ -426,7 +426,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
if (sk2->sk_hash != hash)
continue;
if (likely(INET_MATCH(net, sk2, acookie, ports, dif, sdif))) {
if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) {
if (sk2->sk_state == TCP_TIME_WAIT) {
tw = inet_twsk(sk2);
if (twsk_unique(sk, sk2, twp))
......@@ -492,14 +492,14 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
if (esk->sk_hash != sk->sk_hash)
continue;
if (sk->sk_family == AF_INET) {
if (unlikely(INET_MATCH(net, esk, acookie,
if (unlikely(inet_match(net, esk, acookie,
ports, dif, sdif))) {
return true;
}
}
#if IS_ENABLED(CONFIG_IPV6)
else if (sk->sk_family == AF_INET6) {
if (unlikely(INET6_MATCH(esk, net,
if (unlikely(inet6_match(net, esk,
&sk->sk_v6_daddr,
&sk->sk_v6_rcv_saddr,
ports, dif, sdif))) {
......
......@@ -2563,7 +2563,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
struct sock *sk;
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
if (INET_MATCH(net, sk, acookie, ports, dif, sdif))
if (inet_match(net, sk, acookie, ports, dif, sdif))
return sk;
/* Only check first socket in chain */
break;
......
......@@ -218,11 +218,11 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr,
err = -EINVAL;
goto out;
}
sk->sk_bound_dev_if = usin->sin6_scope_id;
WRITE_ONCE(sk->sk_bound_dev_if, usin->sin6_scope_id);
}
if (!sk->sk_bound_dev_if && (addr_type & IPV6_ADDR_MULTICAST))
sk->sk_bound_dev_if = np->mcast_oif;
WRITE_ONCE(sk->sk_bound_dev_if, np->mcast_oif);
/* Connect to link-local address requires an interface */
if (!sk->sk_bound_dev_if) {
......@@ -798,7 +798,7 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk,
if (src_idx) {
if (fl6->flowi6_oif &&
src_idx != fl6->flowi6_oif &&
(sk->sk_bound_dev_if != fl6->flowi6_oif ||
(READ_ONCE(sk->sk_bound_dev_if) != fl6->flowi6_oif ||
!sk_dev_equal_l3scope(sk, src_idx)))
return -EINVAL;
fl6->flowi6_oif = src_idx;
......
......@@ -71,12 +71,12 @@ struct sock *__inet6_lookup_established(struct net *net,
sk_nulls_for_each_rcu(sk, node, &head->chain) {
if (sk->sk_hash != hash)
continue;
if (!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))
if (!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))
continue;
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out;
if (unlikely(!INET6_MATCH(sk, net, saddr, daddr, ports, dif, sdif))) {
if (unlikely(!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))) {
sock_gen_put(sk);
goto begin;
}
......@@ -268,7 +268,7 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
if (sk2->sk_hash != hash)
continue;
if (likely(INET6_MATCH(sk2, net, saddr, daddr, ports,
if (likely(inet6_match(net, sk2, saddr, daddr, ports,
dif, sdif))) {
if (sk2->sk_state == TCP_TIME_WAIT) {
tw = inet_twsk(sk2);
......
......@@ -105,7 +105,7 @@ static int compute_score(struct sock *sk, struct net *net,
const struct in6_addr *daddr, unsigned short hnum,
int dif, int sdif)
{
int score;
int bound_dev_if, score;
struct inet_sock *inet;
bool dev_match;
......@@ -132,10 +132,11 @@ static int compute_score(struct sock *sk, struct net *net,
score++;
}
dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif);
bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
dev_match = udp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif);
if (!dev_match)
return -1;
if (sk->sk_bound_dev_if)
if (bound_dev_if)
score++;
if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
......@@ -789,7 +790,7 @@ static bool __udp_v6_is_mcast_sock(struct net *net, struct sock *sk,
(inet->inet_dport && inet->inet_dport != rmt_port) ||
(!ipv6_addr_any(&sk->sk_v6_daddr) &&
!ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ||
!udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif) ||
!udp_sk_bound_dev_eq(net, READ_ONCE(sk->sk_bound_dev_if), dif, sdif) ||
(!ipv6_addr_any(&sk->sk_v6_rcv_saddr) &&
!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr)))
return false;
......@@ -1043,7 +1044,7 @@ static struct sock *__udp6_lib_demux_lookup(struct net *net,
udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
if (sk->sk_state == TCP_ESTABLISHED &&
INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif, sdif))
inet6_match(net, sk, rmt_addr, loc_addr, ports, dif, sdif))
return sk;
/* Only check first socket in chain */
break;
......@@ -1433,7 +1434,7 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
}
if (!fl6->flowi6_oif)
fl6->flowi6_oif = sk->sk_bound_dev_if;
fl6->flowi6_oif = READ_ONCE(sk->sk_bound_dev_if);
if (!fl6->flowi6_oif)
fl6->flowi6_oif = np->sticky_pktinfo.ipi6_ifindex;
......
......@@ -50,11 +50,13 @@ static struct sock *__l2tp_ip_bind_lookup(const struct net *net, __be32 laddr,
sk_for_each_bound(sk, &l2tp_ip_bind_table) {
const struct l2tp_ip_sock *l2tp = l2tp_ip_sk(sk);
const struct inet_sock *inet = inet_sk(sk);
int bound_dev_if;
if (!net_eq(sock_net(sk), net))
continue;
if (sk->sk_bound_dev_if && dif && sk->sk_bound_dev_if != dif)
bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
if (bound_dev_if && dif && bound_dev_if != dif)
continue;
if (inet->inet_rcv_saddr && laddr &&
......
......@@ -62,11 +62,13 @@ static struct sock *__l2tp_ip6_bind_lookup(const struct net *net,
const struct in6_addr *sk_laddr = inet6_rcv_saddr(sk);
const struct in6_addr *sk_raddr = &sk->sk_v6_daddr;
const struct l2tp_ip6_sock *l2tp = l2tp_ip6_sk(sk);
int bound_dev_if;
if (!net_eq(sock_net(sk), net))
continue;
if (sk->sk_bound_dev_if && dif && sk->sk_bound_dev_if != dif)
bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
if (bound_dev_if && dif && bound_dev_if != dif)
continue;
if (sk_laddr && !ipv6_addr_any(sk_laddr) &&
......@@ -445,7 +447,7 @@ static int l2tp_ip6_getname(struct socket *sock, struct sockaddr *uaddr,
lsa->l2tp_conn_id = lsk->conn_id;
}
if (ipv6_addr_type(&lsa->l2tp_addr) & IPV6_ADDR_LINKLOCAL)
lsa->l2tp_scope_id = sk->sk_bound_dev_if;
lsa->l2tp_scope_id = READ_ONCE(sk->sk_bound_dev_if);
return sizeof(*lsa);
}
......@@ -560,7 +562,7 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
}
if (fl6.flowi6_oif == 0)
fl6.flowi6_oif = sk->sk_bound_dev_if;
fl6.flowi6_oif = READ_ONCE(sk->sk_bound_dev_if);
if (msg->msg_controllen) {
opt = &opt_space;
......
......@@ -311,12 +311,15 @@ META_COLLECTOR(int_sk_bound_if)
META_COLLECTOR(var_sk_bound_if)
{
int bound_dev_if;
if (skip_nonlocal(skb)) {
*err = -1;
return;
}
if (skb->sk->sk_bound_dev_if == 0) {
bound_dev_if = READ_ONCE(skb->sk->sk_bound_dev_if);
if (bound_dev_if == 0) {
dst->value = (unsigned long) "any";
dst->len = 3;
} else {
......@@ -324,7 +327,7 @@ META_COLLECTOR(var_sk_bound_if)
rcu_read_lock();
dev = dev_get_by_index_rcu(sock_net(skb->sk),
skb->sk->sk_bound_dev_if);
bound_dev_if);
*err = var_dev(dev, dst);
rcu_read_unlock();
}
......
......@@ -92,6 +92,7 @@ int sctp_rcv(struct sk_buff *skb)
struct sctp_chunk *chunk;
union sctp_addr src;
union sctp_addr dest;
int bound_dev_if;
int family;
struct sctp_af *af;
struct net *net = dev_net(skb->dev);
......@@ -169,7 +170,8 @@ int sctp_rcv(struct sk_buff *skb)
* If a frame arrives on an interface and the receiving socket is
* bound to another interface, via SO_BINDTODEVICE, treat it as OOTB
*/
if (sk->sk_bound_dev_if && (sk->sk_bound_dev_if != af->skb_iif(skb))) {
bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) {
if (transport) {
sctp_transport_put(transport);
asoc = NULL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment