Commit cafbe182 authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

inet: move inet->hdrincl to inet->inet_flags

IP_HDRINCL socket option can now be set/read
without locking the socket.
Signed-off-by: default avatarEric Dumazet <edumazet@google.com>
Acked-by: default avatarSoheil Hassas Yeganeh <soheil@google.com>
Reviewed-by: default avatarSimon Horman <horms@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 3f7e7532
...@@ -231,7 +231,6 @@ struct inet_sock { ...@@ -231,7 +231,6 @@ struct inet_sock {
__u8 mc_ttl; __u8 mc_ttl;
__u8 pmtudisc; __u8 pmtudisc;
__u8 is_icsk:1, __u8 is_icsk:1,
hdrincl:1,
mc_loop:1, mc_loop:1,
transparent:1, transparent:1,
mc_all:1, mc_all:1,
...@@ -271,6 +270,7 @@ enum { ...@@ -271,6 +270,7 @@ enum {
INET_FLAGS_RECVERR = 9, INET_FLAGS_RECVERR = 9,
INET_FLAGS_RECVERR_RFC4884 = 10, INET_FLAGS_RECVERR_RFC4884 = 10,
INET_FLAGS_FREEBIND = 11, INET_FLAGS_FREEBIND = 11,
INET_FLAGS_HDRINCL = 12,
}; };
/* cmsg flags for inet */ /* cmsg flags for inet */
...@@ -397,7 +397,7 @@ static inline __u8 inet_sk_flowi_flags(const struct sock *sk) ...@@ -397,7 +397,7 @@ static inline __u8 inet_sk_flowi_flags(const struct sock *sk)
{ {
__u8 flags = 0; __u8 flags = 0;
if (inet_sk(sk)->transparent || inet_sk(sk)->hdrincl) if (inet_sk(sk)->transparent || inet_test_bit(HDRINCL, sk))
flags |= FLOWI_FLAG_ANYSRC; flags |= FLOWI_FLAG_ANYSRC;
return flags; return flags;
} }
......
...@@ -338,7 +338,7 @@ static int inet_create(struct net *net, struct socket *sock, int protocol, ...@@ -338,7 +338,7 @@ static int inet_create(struct net *net, struct socket *sock, int protocol,
if (SOCK_RAW == sock->type) { if (SOCK_RAW == sock->type) {
inet->inet_num = protocol; inet->inet_num = protocol;
if (IPPROTO_RAW == protocol) if (IPPROTO_RAW == protocol)
inet->hdrincl = 1; inet_set_bit(HDRINCL, sk);
} }
if (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc)) if (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc))
......
...@@ -185,7 +185,7 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, ...@@ -185,7 +185,7 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
inet_sockopt.recverr = inet_test_bit(RECVERR, sk); inet_sockopt.recverr = inet_test_bit(RECVERR, sk);
inet_sockopt.is_icsk = inet->is_icsk; inet_sockopt.is_icsk = inet->is_icsk;
inet_sockopt.freebind = inet_test_bit(FREEBIND, sk); inet_sockopt.freebind = inet_test_bit(FREEBIND, sk);
inet_sockopt.hdrincl = inet->hdrincl; inet_sockopt.hdrincl = inet_test_bit(HDRINCL, sk);
inet_sockopt.mc_loop = inet->mc_loop; inet_sockopt.mc_loop = inet->mc_loop;
inet_sockopt.transparent = inet->transparent; inet_sockopt.transparent = inet->transparent;
inet_sockopt.mc_all = inet->mc_all; inet_sockopt.mc_all = inet->mc_all;
......
...@@ -1039,7 +1039,7 @@ static int __ip_append_data(struct sock *sk, ...@@ -1039,7 +1039,7 @@ static int __ip_append_data(struct sock *sk,
} }
} }
} else if ((flags & MSG_SPLICE_PAGES) && length) { } else if ((flags & MSG_SPLICE_PAGES) && length) {
if (inet->hdrincl) if (inet_test_bit(HDRINCL, sk))
return -EPERM; return -EPERM;
if (rt->dst.dev->features & NETIF_F_SG && if (rt->dst.dev->features & NETIF_F_SG &&
getfrag == ip_generic_getfrag) getfrag == ip_generic_getfrag)
...@@ -1467,7 +1467,8 @@ struct sk_buff *__ip_make_skb(struct sock *sk, ...@@ -1467,7 +1467,8 @@ struct sk_buff *__ip_make_skb(struct sock *sk,
* so icmphdr does not in skb linear region and can not get icmp_type * so icmphdr does not in skb linear region and can not get icmp_type
* by icmp_hdr(skb)->type. * by icmp_hdr(skb)->type.
*/ */
if (sk->sk_type == SOCK_RAW && !inet_sk(sk)->hdrincl) if (sk->sk_type == SOCK_RAW &&
!inet_test_bit(HDRINCL, sk))
icmp_type = fl4->fl4_icmp_type; icmp_type = fl4->fl4_icmp_type;
else else
icmp_type = icmp_hdr(skb)->type; icmp_type = icmp_hdr(skb)->type;
......
...@@ -988,6 +988,11 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname, ...@@ -988,6 +988,11 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
return -EINVAL; return -EINVAL;
inet_assign_bit(FREEBIND, sk, val); inet_assign_bit(FREEBIND, sk, val);
return 0; return 0;
case IP_HDRINCL:
if (sk->sk_type != SOCK_RAW)
return -ENOPROTOOPT;
inet_assign_bit(HDRINCL, sk, val);
return 0;
} }
err = 0; err = 0;
...@@ -1052,13 +1057,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname, ...@@ -1052,13 +1057,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
goto e_inval; goto e_inval;
inet->uc_ttl = val; inet->uc_ttl = val;
break; break;
case IP_HDRINCL:
if (sk->sk_type != SOCK_RAW) {
err = -ENOPROTOOPT;
break;
}
inet->hdrincl = val ? 1 : 0;
break;
case IP_NODEFRAG: case IP_NODEFRAG:
if (sk->sk_type != SOCK_RAW) { if (sk->sk_type != SOCK_RAW) {
err = -ENOPROTOOPT; err = -ENOPROTOOPT;
...@@ -1578,6 +1576,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1578,6 +1576,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
case IP_FREEBIND: case IP_FREEBIND:
val = inet_test_bit(FREEBIND, sk); val = inet_test_bit(FREEBIND, sk);
goto copyval; goto copyval;
case IP_HDRINCL:
val = inet_test_bit(HDRINCL, sk);
goto copyval;
} }
if (needs_rtnl) if (needs_rtnl)
...@@ -1625,9 +1626,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1625,9 +1626,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
inet->uc_ttl); inet->uc_ttl);
break; break;
} }
case IP_HDRINCL:
val = inet->hdrincl;
break;
case IP_NODEFRAG: case IP_NODEFRAG:
val = inet->nodefrag; val = inet->nodefrag;
break; break;
......
...@@ -251,7 +251,7 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) ...@@ -251,7 +251,7 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
const struct iphdr *iph = (const struct iphdr *)skb->data; const struct iphdr *iph = (const struct iphdr *)skb->data;
u8 *payload = skb->data + (iph->ihl << 2); u8 *payload = skb->data + (iph->ihl << 2);
if (inet->hdrincl) if (inet_test_bit(HDRINCL, sk))
payload = skb->data; payload = skb->data;
ip_icmp_error(sk, skb, err, 0, info, payload); ip_icmp_error(sk, skb, err, 0, info, payload);
} }
...@@ -491,12 +491,8 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -491,12 +491,8 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
if (len > 0xFFFF) if (len > 0xFFFF)
goto out; goto out;
/* hdrincl should be READ_ONCE(inet->hdrincl) hdrincl = inet_test_bit(HDRINCL, sk);
* but READ_ONCE() doesn't work with bit fields.
* Doing this indirectly yields the same result.
*/
hdrincl = inet->hdrincl;
hdrincl = READ_ONCE(hdrincl);
/* /*
* Check the flags. * Check the flags.
*/ */
......
...@@ -515,13 +515,12 @@ static void __build_flow_key(const struct net *net, struct flowi4 *fl4, ...@@ -515,13 +515,12 @@ static void __build_flow_key(const struct net *net, struct flowi4 *fl4,
__u8 scope = RT_SCOPE_UNIVERSE; __u8 scope = RT_SCOPE_UNIVERSE;
if (sk) { if (sk) {
const struct inet_sock *inet = inet_sk(sk);
oif = sk->sk_bound_dev_if; oif = sk->sk_bound_dev_if;
mark = READ_ONCE(sk->sk_mark); mark = READ_ONCE(sk->sk_mark);
tos = ip_sock_rt_tos(sk); tos = ip_sock_rt_tos(sk);
scope = ip_sock_rt_scope(sk); scope = ip_sock_rt_scope(sk);
prot = inet->hdrincl ? IPPROTO_RAW : sk->sk_protocol; prot = inet_test_bit(HDRINCL, sk) ? IPPROTO_RAW :
sk->sk_protocol;
} }
flowi4_init_output(fl4, oif, mark, tos & IPTOS_RT_MASK, scope, flowi4_init_output(fl4, oif, mark, tos & IPTOS_RT_MASK, scope,
...@@ -555,7 +554,8 @@ static void build_sk_flow_key(struct flowi4 *fl4, const struct sock *sk) ...@@ -555,7 +554,8 @@ static void build_sk_flow_key(struct flowi4 *fl4, const struct sock *sk)
flowi4_init_output(fl4, sk->sk_bound_dev_if, READ_ONCE(sk->sk_mark), flowi4_init_output(fl4, sk->sk_bound_dev_if, READ_ONCE(sk->sk_mark),
ip_sock_rt_tos(sk) & IPTOS_RT_MASK, ip_sock_rt_tos(sk) & IPTOS_RT_MASK,
ip_sock_rt_scope(sk), ip_sock_rt_scope(sk),
inet->hdrincl ? IPPROTO_RAW : sk->sk_protocol, inet_test_bit(HDRINCL, sk) ?
IPPROTO_RAW : sk->sk_protocol,
inet_sk_flowi_flags(sk), inet_sk_flowi_flags(sk),
daddr, inet->inet_saddr, 0, 0, sk->sk_uid); daddr, inet->inet_saddr, 0, 0, sk->sk_uid);
rcu_read_unlock(); rcu_read_unlock();
......
...@@ -205,7 +205,7 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol, ...@@ -205,7 +205,7 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol,
if (SOCK_RAW == sock->type) { if (SOCK_RAW == sock->type) {
inet->inet_num = protocol; inet->inet_num = protocol;
if (IPPROTO_RAW == protocol) if (IPPROTO_RAW == protocol)
inet->hdrincl = 1; inet_set_bit(HDRINCL, sk);
} }
sk->sk_destruct = inet6_sock_destruct; sk->sk_destruct = inet6_sock_destruct;
......
...@@ -1591,7 +1591,7 @@ static int __ip6_append_data(struct sock *sk, ...@@ -1591,7 +1591,7 @@ static int __ip6_append_data(struct sock *sk,
} }
} }
} else if ((flags & MSG_SPLICE_PAGES) && length) { } else if ((flags & MSG_SPLICE_PAGES) && length) {
if (inet_sk(sk)->hdrincl) if (inet_test_bit(HDRINCL, sk))
return -EPERM; return -EPERM;
if (rt->dst.dev->features & NETIF_F_SG && if (rt->dst.dev->features & NETIF_F_SG &&
getfrag == ip_generic_getfrag) getfrag == ip_generic_getfrag)
...@@ -1995,7 +1995,8 @@ struct sk_buff *__ip6_make_skb(struct sock *sk, ...@@ -1995,7 +1995,8 @@ struct sk_buff *__ip6_make_skb(struct sock *sk,
struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb)); struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
u8 icmp6_type; u8 icmp6_type;
if (sk->sk_socket->type == SOCK_RAW && !inet_sk(sk)->hdrincl) if (sk->sk_socket->type == SOCK_RAW &&
!inet_test_bit(HDRINCL, sk))
icmp6_type = fl6->fl6_icmp_type; icmp6_type = fl6->fl6_icmp_type;
else else
icmp6_type = icmp6_hdr(skb)->icmp6_type; icmp6_type = icmp6_hdr(skb)->icmp6_type;
......
...@@ -291,7 +291,6 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb, ...@@ -291,7 +291,6 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb,
struct inet6_skb_parm *opt, struct inet6_skb_parm *opt,
u8 type, u8 code, int offset, __be32 info) u8 type, u8 code, int offset, __be32 info)
{ {
struct inet_sock *inet = inet_sk(sk);
struct ipv6_pinfo *np = inet6_sk(sk); struct ipv6_pinfo *np = inet6_sk(sk);
int err; int err;
int harderr; int harderr;
...@@ -315,7 +314,7 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb, ...@@ -315,7 +314,7 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb,
} }
if (np->recverr) { if (np->recverr) {
u8 *payload = skb->data; u8 *payload = skb->data;
if (!inet->hdrincl) if (!inet_test_bit(HDRINCL, sk))
payload += offset; payload += offset;
ipv6_icmp_error(sk, skb, err, 0, ntohl(info), payload); ipv6_icmp_error(sk, skb, err, 0, ntohl(info), payload);
} }
...@@ -406,7 +405,7 @@ int rawv6_rcv(struct sock *sk, struct sk_buff *skb) ...@@ -406,7 +405,7 @@ int rawv6_rcv(struct sock *sk, struct sk_buff *skb)
skb->len, skb->len,
inet->inet_num, 0)); inet->inet_num, 0));
if (inet->hdrincl) { if (inet_test_bit(HDRINCL, sk)) {
if (skb_checksum_complete(skb)) { if (skb_checksum_complete(skb)) {
atomic_inc(&sk->sk_drops); atomic_inc(&sk->sk_drops);
kfree_skb_reason(skb, SKB_DROP_REASON_SKB_CSUM); kfree_skb_reason(skb, SKB_DROP_REASON_SKB_CSUM);
...@@ -762,12 +761,7 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -762,12 +761,7 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
if (msg->msg_flags & MSG_OOB) if (msg->msg_flags & MSG_OOB)
return -EOPNOTSUPP; return -EOPNOTSUPP;
/* hdrincl should be READ_ONCE(inet->hdrincl) hdrincl = inet_test_bit(HDRINCL, sk);
* but READ_ONCE() doesn't work with bit fields.
* Doing this indirectly yields the same result.
*/
hdrincl = inet->hdrincl;
hdrincl = READ_ONCE(hdrincl);
/* /*
* Get and verify the address. * Get and verify the address.
...@@ -1000,7 +994,7 @@ static int do_rawv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -1000,7 +994,7 @@ static int do_rawv6_setsockopt(struct sock *sk, int level, int optname,
case IPV6_HDRINCL: case IPV6_HDRINCL:
if (sk->sk_type != SOCK_RAW) if (sk->sk_type != SOCK_RAW)
return -EINVAL; return -EINVAL;
inet_sk(sk)->hdrincl = !!val; inet_assign_bit(HDRINCL, sk, val);
return 0; return 0;
case IPV6_CHECKSUM: case IPV6_CHECKSUM:
if (inet_sk(sk)->inet_num == IPPROTO_ICMPV6 && if (inet_sk(sk)->inet_num == IPPROTO_ICMPV6 &&
...@@ -1068,7 +1062,7 @@ static int do_rawv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1068,7 +1062,7 @@ static int do_rawv6_getsockopt(struct sock *sk, int level, int optname,
switch (optname) { switch (optname) {
case IPV6_HDRINCL: case IPV6_HDRINCL:
val = inet_sk(sk)->hdrincl; val = inet_test_bit(HDRINCL, sk);
break; break;
case IPV6_CHECKSUM: case IPV6_CHECKSUM:
/* /*
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment