Commit c274af22 authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

inet: introduce inet->inet_flags

Various inet fields are currently racy.

do_ip_setsockopt() and do_ip_getsockopt() are mostly holding
the socket lock, but some (fast) paths do not.

Use a new inet->inet_flags to hold atomic bits in the series.

Remove inet->cmsg_flags, and use instead 9 bits from inet_flags.
Signed-off-by: default avatarEric Dumazet <edumazet@google.com>
Acked-by: default avatarSoheil Hassas Yeganeh <soheil@google.com>
Reviewed-by: default avatarSimon Horman <horms@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 936db833
...@@ -194,6 +194,7 @@ struct rtable; ...@@ -194,6 +194,7 @@ struct rtable;
* @inet_rcv_saddr - Bound local IPv4 addr * @inet_rcv_saddr - Bound local IPv4 addr
* @inet_dport - Destination port * @inet_dport - Destination port
* @inet_num - Local port * @inet_num - Local port
* @inet_flags - various atomic flags
* @inet_saddr - Sending source * @inet_saddr - Sending source
* @uc_ttl - Unicast TTL * @uc_ttl - Unicast TTL
* @inet_sport - Source port * @inet_sport - Source port
...@@ -218,11 +219,11 @@ struct inet_sock { ...@@ -218,11 +219,11 @@ struct inet_sock {
#define inet_dport sk.__sk_common.skc_dport #define inet_dport sk.__sk_common.skc_dport
#define inet_num sk.__sk_common.skc_num #define inet_num sk.__sk_common.skc_num
unsigned long inet_flags;
__be32 inet_saddr; __be32 inet_saddr;
__s16 uc_ttl; __s16 uc_ttl;
__u16 cmsg_flags;
struct ip_options_rcu __rcu *inet_opt;
__be16 inet_sport; __be16 inet_sport;
struct ip_options_rcu __rcu *inet_opt;
__u16 inet_id; __u16 inet_id;
__u8 tos; __u8 tos;
...@@ -259,16 +260,48 @@ struct inet_sock { ...@@ -259,16 +260,48 @@ struct inet_sock {
#define IPCORK_OPT 1 /* ip-options has been held in ipcork.opt */ #define IPCORK_OPT 1 /* ip-options has been held in ipcork.opt */
#define IPCORK_ALLFRAG 2 /* always fragment (for ipv6 for now) */ #define IPCORK_ALLFRAG 2 /* always fragment (for ipv6 for now) */
enum {
INET_FLAGS_PKTINFO = 0,
INET_FLAGS_TTL = 1,
INET_FLAGS_TOS = 2,
INET_FLAGS_RECVOPTS = 3,
INET_FLAGS_RETOPTS = 4,
INET_FLAGS_PASSSEC = 5,
INET_FLAGS_ORIGDSTADDR = 6,
INET_FLAGS_CHECKSUM = 7,
INET_FLAGS_RECVFRAGSIZE = 8,
};
/* cmsg flags for inet */ /* cmsg flags for inet */
#define IP_CMSG_PKTINFO BIT(0) #define IP_CMSG_PKTINFO BIT(INET_FLAGS_PKTINFO)
#define IP_CMSG_TTL BIT(1) #define IP_CMSG_TTL BIT(INET_FLAGS_TTL)
#define IP_CMSG_TOS BIT(2) #define IP_CMSG_TOS BIT(INET_FLAGS_TOS)
#define IP_CMSG_RECVOPTS BIT(3) #define IP_CMSG_RECVOPTS BIT(INET_FLAGS_RECVOPTS)
#define IP_CMSG_RETOPTS BIT(4) #define IP_CMSG_RETOPTS BIT(INET_FLAGS_RETOPTS)
#define IP_CMSG_PASSSEC BIT(5) #define IP_CMSG_PASSSEC BIT(INET_FLAGS_PASSSEC)
#define IP_CMSG_ORIGDSTADDR BIT(6) #define IP_CMSG_ORIGDSTADDR BIT(INET_FLAGS_ORIGDSTADDR)
#define IP_CMSG_CHECKSUM BIT(7) #define IP_CMSG_CHECKSUM BIT(INET_FLAGS_CHECKSUM)
#define IP_CMSG_RECVFRAGSIZE BIT(8) #define IP_CMSG_RECVFRAGSIZE BIT(INET_FLAGS_RECVFRAGSIZE)
#define IP_CMSG_ALL (IP_CMSG_PKTINFO | IP_CMSG_TTL | \
IP_CMSG_TOS | IP_CMSG_RECVOPTS | \
IP_CMSG_RETOPTS | IP_CMSG_PASSSEC | \
IP_CMSG_ORIGDSTADDR | IP_CMSG_CHECKSUM | \
IP_CMSG_RECVFRAGSIZE)
static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
{
return READ_ONCE(inet->inet_flags) & IP_CMSG_ALL;
}
#define inet_test_bit(nr, sk) \
test_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
#define inet_set_bit(nr, sk) \
set_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
#define inet_clear_bit(nr, sk) \
clear_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
#define inet_assign_bit(nr, sk, val) \
assign_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags, val)
static inline bool sk_is_inet(struct sock *sk) static inline bool sk_is_inet(struct sock *sk)
{ {
......
...@@ -171,8 +171,10 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) ...@@ -171,8 +171,10 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb)
void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk, void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk,
struct sk_buff *skb, int tlen, int offset) struct sk_buff *skb, int tlen, int offset)
{ {
struct inet_sock *inet = inet_sk(sk); unsigned long flags = inet_cmsg_flags(inet_sk(sk));
unsigned int flags = inet->cmsg_flags;
if (!flags)
return;
/* Ordered by supposed usage frequency */ /* Ordered by supposed usage frequency */
if (flags & IP_CMSG_PKTINFO) { if (flags & IP_CMSG_PKTINFO) {
...@@ -568,7 +570,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) ...@@ -568,7 +570,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) { if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) {
sin->sin_family = AF_INET; sin->sin_family = AF_INET;
sin->sin_addr.s_addr = ip_hdr(skb)->saddr; sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
if (inet_sk(sk)->cmsg_flags) if (inet_cmsg_flags(inet_sk(sk)))
ip_cmsg_recv(msg, skb); ip_cmsg_recv(msg, skb);
} }
...@@ -635,7 +637,7 @@ EXPORT_SYMBOL(ip_sock_set_mtu_discover); ...@@ -635,7 +637,7 @@ EXPORT_SYMBOL(ip_sock_set_mtu_discover);
void ip_sock_set_pktinfo(struct sock *sk) void ip_sock_set_pktinfo(struct sock *sk)
{ {
lock_sock(sk); lock_sock(sk);
inet_sk(sk)->cmsg_flags |= IP_CMSG_PKTINFO; inet_set_bit(PKTINFO, sk);
release_sock(sk); release_sock(sk);
} }
EXPORT_SYMBOL(ip_sock_set_pktinfo); EXPORT_SYMBOL(ip_sock_set_pktinfo);
...@@ -990,67 +992,43 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname, ...@@ -990,67 +992,43 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
break; break;
} }
case IP_PKTINFO: case IP_PKTINFO:
if (val) inet_assign_bit(PKTINFO, sk, val);
inet->cmsg_flags |= IP_CMSG_PKTINFO;
else
inet->cmsg_flags &= ~IP_CMSG_PKTINFO;
break; break;
case IP_RECVTTL: case IP_RECVTTL:
if (val) inet_assign_bit(TTL, sk, val);
inet->cmsg_flags |= IP_CMSG_TTL;
else
inet->cmsg_flags &= ~IP_CMSG_TTL;
break; break;
case IP_RECVTOS: case IP_RECVTOS:
if (val) inet_assign_bit(TOS, sk, val);
inet->cmsg_flags |= IP_CMSG_TOS;
else
inet->cmsg_flags &= ~IP_CMSG_TOS;
break; break;
case IP_RECVOPTS: case IP_RECVOPTS:
if (val) inet_assign_bit(RECVOPTS, sk, val);
inet->cmsg_flags |= IP_CMSG_RECVOPTS;
else
inet->cmsg_flags &= ~IP_CMSG_RECVOPTS;
break; break;
case IP_RETOPTS: case IP_RETOPTS:
if (val) inet_assign_bit(RETOPTS, sk, val);
inet->cmsg_flags |= IP_CMSG_RETOPTS;
else
inet->cmsg_flags &= ~IP_CMSG_RETOPTS;
break; break;
case IP_PASSSEC: case IP_PASSSEC:
if (val) inet_assign_bit(PASSSEC, sk, val);
inet->cmsg_flags |= IP_CMSG_PASSSEC;
else
inet->cmsg_flags &= ~IP_CMSG_PASSSEC;
break; break;
case IP_RECVORIGDSTADDR: case IP_RECVORIGDSTADDR:
if (val) inet_assign_bit(ORIGDSTADDR, sk, val);
inet->cmsg_flags |= IP_CMSG_ORIGDSTADDR;
else
inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR;
break; break;
case IP_CHECKSUM: case IP_CHECKSUM:
if (val) { if (val) {
if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) { if (!(inet_test_bit(CHECKSUM, sk))) {
inet_inc_convert_csum(sk); inet_inc_convert_csum(sk);
inet->cmsg_flags |= IP_CMSG_CHECKSUM; inet_set_bit(CHECKSUM, sk);
} }
} else { } else {
if (inet->cmsg_flags & IP_CMSG_CHECKSUM) { if (inet_test_bit(CHECKSUM, sk)) {
inet_dec_convert_csum(sk); inet_dec_convert_csum(sk);
inet->cmsg_flags &= ~IP_CMSG_CHECKSUM; inet_clear_bit(CHECKSUM, sk);
} }
} }
break; break;
case IP_RECVFRAGSIZE: case IP_RECVFRAGSIZE:
if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM) if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM)
goto e_inval; goto e_inval;
if (val) inet_assign_bit(RECVFRAGSIZE, sk, val);
inet->cmsg_flags |= IP_CMSG_RECVFRAGSIZE;
else
inet->cmsg_flags &= ~IP_CMSG_RECVFRAGSIZE;
break; break;
case IP_TOS: /* This sets both TOS and Precedence */ case IP_TOS: /* This sets both TOS and Precedence */
__ip_sock_set_tos(sk, val); __ip_sock_set_tos(sk, val);
...@@ -1415,7 +1393,7 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname, ...@@ -1415,7 +1393,7 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb) void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb)
{ {
struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb); struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb);
bool prepare = (inet_sk(sk)->cmsg_flags & IP_CMSG_PKTINFO) || bool prepare = inet_test_bit(PKTINFO, sk) ||
ipv6_sk_rxinfo(sk); ipv6_sk_rxinfo(sk);
if (prepare && skb_rtable(skb)) { if (prepare && skb_rtable(skb)) {
...@@ -1601,31 +1579,31 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1601,31 +1579,31 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
return 0; return 0;
} }
case IP_PKTINFO: case IP_PKTINFO:
val = (inet->cmsg_flags & IP_CMSG_PKTINFO) != 0; val = inet_test_bit(PKTINFO, sk);
break; break;
case IP_RECVTTL: case IP_RECVTTL:
val = (inet->cmsg_flags & IP_CMSG_TTL) != 0; val = inet_test_bit(TTL, sk);
break; break;
case IP_RECVTOS: case IP_RECVTOS:
val = (inet->cmsg_flags & IP_CMSG_TOS) != 0; val = inet_test_bit(TOS, sk);
break; break;
case IP_RECVOPTS: case IP_RECVOPTS:
val = (inet->cmsg_flags & IP_CMSG_RECVOPTS) != 0; val = inet_test_bit(RECVOPTS, sk);
break; break;
case IP_RETOPTS: case IP_RETOPTS:
val = (inet->cmsg_flags & IP_CMSG_RETOPTS) != 0; val = inet_test_bit(RETOPTS, sk);
break; break;
case IP_PASSSEC: case IP_PASSSEC:
val = (inet->cmsg_flags & IP_CMSG_PASSSEC) != 0; val = inet_test_bit(PASSSEC, sk);
break; break;
case IP_RECVORIGDSTADDR: case IP_RECVORIGDSTADDR:
val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0; val = inet_test_bit(ORIGDSTADDR, sk);
break; break;
case IP_CHECKSUM: case IP_CHECKSUM:
val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0; val = inet_test_bit(CHECKSUM, sk);
break; break;
case IP_RECVFRAGSIZE: case IP_RECVFRAGSIZE:
val = (inet->cmsg_flags & IP_CMSG_RECVFRAGSIZE) != 0; val = inet_test_bit(RECVFRAGSIZE, sk);
break; break;
case IP_TOS: case IP_TOS:
val = inet->tos; val = inet->tos;
...@@ -1737,7 +1715,7 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1737,7 +1715,7 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
msg.msg_controllen = len; msg.msg_controllen = len;
msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0; msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0;
if (inet->cmsg_flags & IP_CMSG_PKTINFO) { if (inet_test_bit(PKTINFO, sk)) {
struct in_pktinfo info; struct in_pktinfo info;
info.ipi_addr.s_addr = inet->inet_rcv_saddr; info.ipi_addr.s_addr = inet->inet_rcv_saddr;
...@@ -1745,11 +1723,11 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1745,11 +1723,11 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
info.ipi_ifindex = inet->mc_index; info.ipi_ifindex = inet->mc_index;
put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info); put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info);
} }
if (inet->cmsg_flags & IP_CMSG_TTL) { if (inet_test_bit(TTL, sk)) {
int hlim = inet->mc_ttl; int hlim = inet->mc_ttl;
put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim); put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim);
} }
if (inet->cmsg_flags & IP_CMSG_TOS) { if (inet_test_bit(TOS, sk)) {
int tos = inet->rcv_tos; int tos = inet->rcv_tos;
put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos); put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos);
} }
......
...@@ -894,7 +894,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, ...@@ -894,7 +894,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
*addr_len = sizeof(*sin); *addr_len = sizeof(*sin);
} }
if (isk->cmsg_flags) if (inet_cmsg_flags(isk))
ip_cmsg_recv(msg, skb); ip_cmsg_recv(msg, skb);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
...@@ -921,7 +921,8 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, ...@@ -921,7 +921,8 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
if (skb->protocol == htons(ETH_P_IPV6) && if (skb->protocol == htons(ETH_P_IPV6) &&
inet6_sk(sk)->rxopt.all) inet6_sk(sk)->rxopt.all)
pingv6_ops.ip6_datagram_recv_specific_ctl(sk, msg, skb); pingv6_ops.ip6_datagram_recv_specific_ctl(sk, msg, skb);
else if (skb->protocol == htons(ETH_P_IP) && isk->cmsg_flags) else if (skb->protocol == htons(ETH_P_IP) &&
inet_cmsg_flags(isk))
ip_cmsg_recv(msg, skb); ip_cmsg_recv(msg, skb);
#endif #endif
} else { } else {
......
...@@ -767,7 +767,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -767,7 +767,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
memset(&sin->sin_zero, 0, sizeof(sin->sin_zero)); memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
*addr_len = sizeof(*sin); *addr_len = sizeof(*sin);
} }
if (inet->cmsg_flags) if (inet_cmsg_flags(inet))
ip_cmsg_recv(msg, skb); ip_cmsg_recv(msg, skb);
if (flags & MSG_TRUNC) if (flags & MSG_TRUNC)
copied = skb->len; copied = skb->len;
......
...@@ -1870,7 +1870,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, ...@@ -1870,7 +1870,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
if (udp_sk(sk)->gro_enabled) if (udp_sk(sk)->gro_enabled)
udp_cmsg_recv(msg, sk, skb); udp_cmsg_recv(msg, sk, skb);
if (inet->cmsg_flags) if (inet_cmsg_flags(inet))
ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off); ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off);
err = copied; err = copied;
......
...@@ -524,7 +524,7 @@ int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) ...@@ -524,7 +524,7 @@ int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
} else { } else {
ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr,
&sin->sin6_addr); &sin->sin6_addr);
if (inet_sk(sk)->cmsg_flags) if (inet_cmsg_flags(inet_sk(sk)))
ip_cmsg_recv(msg, skb); ip_cmsg_recv(msg, skb);
} }
} }
......
...@@ -420,7 +420,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -420,7 +420,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
ip6_datagram_recv_common_ctl(sk, msg, skb); ip6_datagram_recv_common_ctl(sk, msg, skb);
if (is_udp4) { if (is_udp4) {
if (inet->cmsg_flags) if (inet_cmsg_flags(inet))
ip_cmsg_recv_offset(msg, sk, skb, ip_cmsg_recv_offset(msg, sk, skb,
sizeof(struct udphdr), off); sizeof(struct udphdr), off);
} else { } else {
......
...@@ -552,7 +552,7 @@ static int l2tp_ip_recvmsg(struct sock *sk, struct msghdr *msg, ...@@ -552,7 +552,7 @@ static int l2tp_ip_recvmsg(struct sock *sk, struct msghdr *msg,
memset(&sin->sin_zero, 0, sizeof(sin->sin_zero)); memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
*addr_len = sizeof(*sin); *addr_len = sizeof(*sin);
} }
if (inet->cmsg_flags) if (inet_cmsg_flags(inet))
ip_cmsg_recv(msg, skb); ip_cmsg_recv(msg, skb);
if (flags & MSG_TRUNC) if (flags & MSG_TRUNC)
copied = skb->len; copied = skb->len;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment