Commit f9242b6b authored by David S. Miller's avatar David S. Miller

inet: Sanitize inet{,6} protocol demux.

Don't pretend that inet_protos[] and inet6_protos[] are hashes, thay
are just a straight arrays.  Remove all unnecessary hash masking.

Document MAX_INET_PROTOS.

Use RAW_HTABLE_SIZE when appropriate.
Reported-by: default avatarBen Hutchings <bhutchings@solarflare.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 677a3d60
...@@ -29,8 +29,11 @@ ...@@ -29,8 +29,11 @@
#include <linux/ipv6.h> #include <linux/ipv6.h>
#endif #endif
#define MAX_INET_PROTOS 256 /* Must be a power of 2 */ /* This is one larger than the largest protocol value that can be
* found in an ipv4 or ipv6 header. Since in both cases the protocol
* value is presented in a __u8, this is defined to be 256.
*/
#define MAX_INET_PROTOS 256
/* This is used to register protocols. */ /* This is used to register protocols. */
struct net_protocol { struct net_protocol {
......
...@@ -242,20 +242,18 @@ void build_ehash_secret(void) ...@@ -242,20 +242,18 @@ void build_ehash_secret(void)
} }
EXPORT_SYMBOL(build_ehash_secret); EXPORT_SYMBOL(build_ehash_secret);
static inline int inet_netns_ok(struct net *net, int protocol) static inline int inet_netns_ok(struct net *net, __u8 protocol)
{ {
int hash;
const struct net_protocol *ipprot; const struct net_protocol *ipprot;
if (net_eq(net, &init_net)) if (net_eq(net, &init_net))
return 1; return 1;
hash = protocol & (MAX_INET_PROTOS - 1); ipprot = rcu_dereference(inet_protos[protocol]);
ipprot = rcu_dereference(inet_protos[hash]); if (ipprot == NULL) {
if (ipprot == NULL)
/* raw IP is OK */ /* raw IP is OK */
return 1; return 1;
}
return ipprot->netns_ok; return ipprot->netns_ok;
} }
...@@ -1216,8 +1214,8 @@ EXPORT_SYMBOL(inet_sk_rebuild_header); ...@@ -1216,8 +1214,8 @@ EXPORT_SYMBOL(inet_sk_rebuild_header);
static int inet_gso_send_check(struct sk_buff *skb) static int inet_gso_send_check(struct sk_buff *skb)
{ {
const struct iphdr *iph;
const struct net_protocol *ops; const struct net_protocol *ops;
const struct iphdr *iph;
int proto; int proto;
int ihl; int ihl;
int err = -EINVAL; int err = -EINVAL;
...@@ -1236,7 +1234,7 @@ static int inet_gso_send_check(struct sk_buff *skb) ...@@ -1236,7 +1234,7 @@ static int inet_gso_send_check(struct sk_buff *skb)
__skb_pull(skb, ihl); __skb_pull(skb, ihl);
skb_reset_transport_header(skb); skb_reset_transport_header(skb);
iph = ip_hdr(skb); iph = ip_hdr(skb);
proto = iph->protocol & (MAX_INET_PROTOS - 1); proto = iph->protocol;
err = -EPROTONOSUPPORT; err = -EPROTONOSUPPORT;
rcu_read_lock(); rcu_read_lock();
...@@ -1253,8 +1251,8 @@ static struct sk_buff *inet_gso_segment(struct sk_buff *skb, ...@@ -1253,8 +1251,8 @@ static struct sk_buff *inet_gso_segment(struct sk_buff *skb,
netdev_features_t features) netdev_features_t features)
{ {
struct sk_buff *segs = ERR_PTR(-EINVAL); struct sk_buff *segs = ERR_PTR(-EINVAL);
struct iphdr *iph;
const struct net_protocol *ops; const struct net_protocol *ops;
struct iphdr *iph;
int proto; int proto;
int ihl; int ihl;
int id; int id;
...@@ -1286,7 +1284,7 @@ static struct sk_buff *inet_gso_segment(struct sk_buff *skb, ...@@ -1286,7 +1284,7 @@ static struct sk_buff *inet_gso_segment(struct sk_buff *skb,
skb_reset_transport_header(skb); skb_reset_transport_header(skb);
iph = ip_hdr(skb); iph = ip_hdr(skb);
id = ntohs(iph->id); id = ntohs(iph->id);
proto = iph->protocol & (MAX_INET_PROTOS - 1); proto = iph->protocol;
segs = ERR_PTR(-EPROTONOSUPPORT); segs = ERR_PTR(-EPROTONOSUPPORT);
rcu_read_lock(); rcu_read_lock();
...@@ -1340,7 +1338,7 @@ static struct sk_buff **inet_gro_receive(struct sk_buff **head, ...@@ -1340,7 +1338,7 @@ static struct sk_buff **inet_gro_receive(struct sk_buff **head,
goto out; goto out;
} }
proto = iph->protocol & (MAX_INET_PROTOS - 1); proto = iph->protocol;
rcu_read_lock(); rcu_read_lock();
ops = rcu_dereference(inet_protos[proto]); ops = rcu_dereference(inet_protos[proto]);
...@@ -1398,11 +1396,11 @@ static struct sk_buff **inet_gro_receive(struct sk_buff **head, ...@@ -1398,11 +1396,11 @@ static struct sk_buff **inet_gro_receive(struct sk_buff **head,
static int inet_gro_complete(struct sk_buff *skb) static int inet_gro_complete(struct sk_buff *skb)
{ {
const struct net_protocol *ops; __be16 newlen = htons(skb->len - skb_network_offset(skb));
struct iphdr *iph = ip_hdr(skb); struct iphdr *iph = ip_hdr(skb);
int proto = iph->protocol & (MAX_INET_PROTOS - 1); const struct net_protocol *ops;
int proto = iph->protocol;
int err = -ENOSYS; int err = -ENOSYS;
__be16 newlen = htons(skb->len - skb_network_offset(skb));
csum_replace2(&iph->check, iph->tot_len, newlen); csum_replace2(&iph->check, iph->tot_len, newlen);
iph->tot_len = newlen; iph->tot_len = newlen;
......
...@@ -637,12 +637,12 @@ EXPORT_SYMBOL(icmp_send); ...@@ -637,12 +637,12 @@ EXPORT_SYMBOL(icmp_send);
static void icmp_unreach(struct sk_buff *skb) static void icmp_unreach(struct sk_buff *skb)
{ {
const struct net_protocol *ipprot;
const struct iphdr *iph; const struct iphdr *iph;
struct icmphdr *icmph; struct icmphdr *icmph;
int hash, protocol;
const struct net_protocol *ipprot;
u32 info = 0;
struct net *net; struct net *net;
u32 info = 0;
int protocol;
net = dev_net(skb_dst(skb)->dev); net = dev_net(skb_dst(skb)->dev);
...@@ -731,9 +731,8 @@ static void icmp_unreach(struct sk_buff *skb) ...@@ -731,9 +731,8 @@ static void icmp_unreach(struct sk_buff *skb)
*/ */
raw_icmp_error(skb, protocol, info); raw_icmp_error(skb, protocol, info);
hash = protocol & (MAX_INET_PROTOS - 1);
rcu_read_lock(); rcu_read_lock();
ipprot = rcu_dereference(inet_protos[hash]); ipprot = rcu_dereference(inet_protos[protocol]);
if (ipprot && ipprot->err_handler) if (ipprot && ipprot->err_handler)
ipprot->err_handler(skb, info); ipprot->err_handler(skb, info);
rcu_read_unlock(); rcu_read_unlock();
......
...@@ -198,14 +198,13 @@ static int ip_local_deliver_finish(struct sk_buff *skb) ...@@ -198,14 +198,13 @@ static int ip_local_deliver_finish(struct sk_buff *skb)
rcu_read_lock(); rcu_read_lock();
{ {
int protocol = ip_hdr(skb)->protocol; int protocol = ip_hdr(skb)->protocol;
int hash, raw;
const struct net_protocol *ipprot; const struct net_protocol *ipprot;
int raw;
resubmit: resubmit:
raw = raw_local_deliver(skb, protocol); raw = raw_local_deliver(skb, protocol);
hash = protocol & (MAX_INET_PROTOS - 1); ipprot = rcu_dereference(inet_protos[protocol]);
ipprot = rcu_dereference(inet_protos[hash]);
if (ipprot != NULL) { if (ipprot != NULL) {
int ret; int ret;
......
...@@ -36,9 +36,7 @@ const struct net_protocol __rcu *inet_protos[MAX_INET_PROTOS] __read_mostly; ...@@ -36,9 +36,7 @@ const struct net_protocol __rcu *inet_protos[MAX_INET_PROTOS] __read_mostly;
int inet_add_protocol(const struct net_protocol *prot, unsigned char protocol) int inet_add_protocol(const struct net_protocol *prot, unsigned char protocol)
{ {
int hash = protocol & (MAX_INET_PROTOS - 1); return !cmpxchg((const struct net_protocol **)&inet_protos[protocol],
return !cmpxchg((const struct net_protocol **)&inet_protos[hash],
NULL, prot) ? 0 : -1; NULL, prot) ? 0 : -1;
} }
EXPORT_SYMBOL(inet_add_protocol); EXPORT_SYMBOL(inet_add_protocol);
...@@ -49,9 +47,9 @@ EXPORT_SYMBOL(inet_add_protocol); ...@@ -49,9 +47,9 @@ EXPORT_SYMBOL(inet_add_protocol);
int inet_del_protocol(const struct net_protocol *prot, unsigned char protocol) int inet_del_protocol(const struct net_protocol *prot, unsigned char protocol)
{ {
int ret, hash = protocol & (MAX_INET_PROTOS - 1); int ret;
ret = (cmpxchg((const struct net_protocol **)&inet_protos[hash], ret = (cmpxchg((const struct net_protocol **)&inet_protos[protocol],
prot, NULL) == prot) ? 0 : -1; prot, NULL) == prot) ? 0 : -1;
synchronize_net(); synchronize_net();
......
...@@ -600,9 +600,8 @@ static void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info) ...@@ -600,9 +600,8 @@ static void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info)
{ {
const struct inet6_protocol *ipprot; const struct inet6_protocol *ipprot;
int inner_offset; int inner_offset;
int hash;
u8 nexthdr;
__be16 frag_off; __be16 frag_off;
u8 nexthdr;
if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
return; return;
...@@ -629,10 +628,8 @@ static void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info) ...@@ -629,10 +628,8 @@ static void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info)
--ANK (980726) --ANK (980726)
*/ */
hash = nexthdr & (MAX_INET_PROTOS - 1);
rcu_read_lock(); rcu_read_lock();
ipprot = rcu_dereference(inet6_protos[hash]); ipprot = rcu_dereference(inet6_protos[nexthdr]);
if (ipprot && ipprot->err_handler) if (ipprot && ipprot->err_handler)
ipprot->err_handler(skb, NULL, type, code, inner_offset, info); ipprot->err_handler(skb, NULL, type, code, inner_offset, info);
rcu_read_unlock(); rcu_read_unlock();
......
...@@ -168,13 +168,12 @@ int ipv6_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt ...@@ -168,13 +168,12 @@ int ipv6_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt
static int ip6_input_finish(struct sk_buff *skb) static int ip6_input_finish(struct sk_buff *skb)
{ {
struct net *net = dev_net(skb_dst(skb)->dev);
const struct inet6_protocol *ipprot; const struct inet6_protocol *ipprot;
struct inet6_dev *idev;
unsigned int nhoff; unsigned int nhoff;
int nexthdr; int nexthdr;
bool raw; bool raw;
u8 hash;
struct inet6_dev *idev;
struct net *net = dev_net(skb_dst(skb)->dev);
/* /*
* Parse extension headers * Parse extension headers
...@@ -189,9 +188,7 @@ static int ip6_input_finish(struct sk_buff *skb) ...@@ -189,9 +188,7 @@ static int ip6_input_finish(struct sk_buff *skb)
nexthdr = skb_network_header(skb)[nhoff]; nexthdr = skb_network_header(skb)[nhoff];
raw = raw6_local_deliver(skb, nexthdr); raw = raw6_local_deliver(skb, nexthdr);
if ((ipprot = rcu_dereference(inet6_protos[nexthdr])) != NULL) {
hash = nexthdr & (MAX_INET_PROTOS - 1);
if ((ipprot = rcu_dereference(inet6_protos[hash])) != NULL) {
int ret; int ret;
if (ipprot->flags & INET6_PROTO_FINAL) { if (ipprot->flags & INET6_PROTO_FINAL) {
......
...@@ -29,9 +29,7 @@ const struct inet6_protocol __rcu *inet6_protos[MAX_INET_PROTOS] __read_mostly; ...@@ -29,9 +29,7 @@ const struct inet6_protocol __rcu *inet6_protos[MAX_INET_PROTOS] __read_mostly;
int inet6_add_protocol(const struct inet6_protocol *prot, unsigned char protocol) int inet6_add_protocol(const struct inet6_protocol *prot, unsigned char protocol)
{ {
int hash = protocol & (MAX_INET_PROTOS - 1); return !cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol],
return !cmpxchg((const struct inet6_protocol **)&inet6_protos[hash],
NULL, prot) ? 0 : -1; NULL, prot) ? 0 : -1;
} }
EXPORT_SYMBOL(inet6_add_protocol); EXPORT_SYMBOL(inet6_add_protocol);
...@@ -42,9 +40,9 @@ EXPORT_SYMBOL(inet6_add_protocol); ...@@ -42,9 +40,9 @@ EXPORT_SYMBOL(inet6_add_protocol);
int inet6_del_protocol(const struct inet6_protocol *prot, unsigned char protocol) int inet6_del_protocol(const struct inet6_protocol *prot, unsigned char protocol)
{ {
int ret, hash = protocol & (MAX_INET_PROTOS - 1); int ret;
ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[hash], ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol],
prot, NULL) == prot) ? 0 : -1; prot, NULL) == prot) ? 0 : -1;
synchronize_net(); synchronize_net();
......
...@@ -165,7 +165,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) ...@@ -165,7 +165,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
saddr = &ipv6_hdr(skb)->saddr; saddr = &ipv6_hdr(skb)->saddr;
daddr = saddr + 1; daddr = saddr + 1;
hash = nexthdr & (MAX_INET_PROTOS - 1); hash = nexthdr & (RAW_HTABLE_SIZE - 1);
read_lock(&raw_v6_hashinfo.lock); read_lock(&raw_v6_hashinfo.lock);
sk = sk_head(&raw_v6_hashinfo.ht[hash]); sk = sk_head(&raw_v6_hashinfo.ht[hash]);
...@@ -229,7 +229,7 @@ bool raw6_local_deliver(struct sk_buff *skb, int nexthdr) ...@@ -229,7 +229,7 @@ bool raw6_local_deliver(struct sk_buff *skb, int nexthdr)
{ {
struct sock *raw_sk; struct sock *raw_sk;
raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (MAX_INET_PROTOS - 1)]); raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (RAW_HTABLE_SIZE - 1)]);
if (raw_sk && !ipv6_raw_deliver(skb, nexthdr)) if (raw_sk && !ipv6_raw_deliver(skb, nexthdr))
raw_sk = NULL; raw_sk = NULL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment