Commit a2e2ca3b authored by Linus Lüssing's avatar Linus Lüssing Committed by David S. Miller

bridge: simplify ip_mc_check_igmp() and ipv6_mc_check_mld() internals

With this patch the internal use of the skb_trimmed is reduced to
the ICMPv6/IGMP checksum verification. And for the length checks
the newly introduced helper functions are used instead of calculating
and checking with skb->len directly.

These changes should hopefully make it easier to verify that length
checks are performed properly.
Signed-off-by: default avatarLinus Lüssing <linus.luessing@c0d3.blue>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ba5ea614
...@@ -1493,22 +1493,22 @@ static int ip_mc_check_igmp_reportv3(struct sk_buff *skb) ...@@ -1493,22 +1493,22 @@ static int ip_mc_check_igmp_reportv3(struct sk_buff *skb)
len += sizeof(struct igmpv3_report); len += sizeof(struct igmpv3_report);
return pskb_may_pull(skb, len) ? 0 : -EINVAL; return ip_mc_may_pull(skb, len) ? 0 : -EINVAL;
} }
static int ip_mc_check_igmp_query(struct sk_buff *skb) static int ip_mc_check_igmp_query(struct sk_buff *skb)
{ {
unsigned int len = skb_transport_offset(skb); unsigned int transport_len = ip_transport_len(skb);
unsigned int len;
len += sizeof(struct igmphdr);
if (skb->len < len)
return -EINVAL;
/* IGMPv{1,2}? */ /* IGMPv{1,2}? */
if (skb->len != len) { if (transport_len != sizeof(struct igmphdr)) {
/* or IGMPv3? */ /* or IGMPv3? */
len += sizeof(struct igmpv3_query) - sizeof(struct igmphdr); if (transport_len < sizeof(struct igmpv3_query))
if (skb->len < len || !pskb_may_pull(skb, len)) return -EINVAL;
len = skb_transport_offset(skb) + sizeof(struct igmpv3_query);
if (!ip_mc_may_pull(skb, len))
return -EINVAL; return -EINVAL;
} }
...@@ -1544,35 +1544,24 @@ static inline __sum16 ip_mc_validate_checksum(struct sk_buff *skb) ...@@ -1544,35 +1544,24 @@ static inline __sum16 ip_mc_validate_checksum(struct sk_buff *skb)
return skb_checksum_simple_validate(skb); return skb_checksum_simple_validate(skb);
} }
static int __ip_mc_check_igmp(struct sk_buff *skb) static int ip_mc_check_igmp_csum(struct sk_buff *skb)
{ {
struct sk_buff *skb_chk;
unsigned int transport_len;
unsigned int len = skb_transport_offset(skb) + sizeof(struct igmphdr); unsigned int len = skb_transport_offset(skb) + sizeof(struct igmphdr);
int ret = -EINVAL; unsigned int transport_len = ip_transport_len(skb);
struct sk_buff *skb_chk;
transport_len = ntohs(ip_hdr(skb)->tot_len) - ip_hdrlen(skb); if (!ip_mc_may_pull(skb, len))
return -EINVAL;
skb_chk = skb_checksum_trimmed(skb, transport_len, skb_chk = skb_checksum_trimmed(skb, transport_len,
ip_mc_validate_checksum); ip_mc_validate_checksum);
if (!skb_chk) if (!skb_chk)
goto err; return -EINVAL;
if (!pskb_may_pull(skb_chk, len))
goto err;
ret = ip_mc_check_igmp_msg(skb_chk);
if (ret)
goto err;
ret = 0;
err: if (skb_chk != skb)
if (skb_chk && skb_chk != skb)
kfree_skb(skb_chk); kfree_skb(skb_chk);
return ret; return 0;
} }
/** /**
...@@ -1600,7 +1589,11 @@ int ip_mc_check_igmp(struct sk_buff *skb) ...@@ -1600,7 +1589,11 @@ int ip_mc_check_igmp(struct sk_buff *skb)
if (ip_hdr(skb)->protocol != IPPROTO_IGMP) if (ip_hdr(skb)->protocol != IPPROTO_IGMP)
return -ENOMSG; return -ENOMSG;
return __ip_mc_check_igmp(skb); ret = ip_mc_check_igmp_csum(skb);
if (ret < 0)
return ret;
return ip_mc_check_igmp_msg(skb);
} }
EXPORT_SYMBOL(ip_mc_check_igmp); EXPORT_SYMBOL(ip_mc_check_igmp);
......
...@@ -77,27 +77,27 @@ static int ipv6_mc_check_mld_reportv2(struct sk_buff *skb) ...@@ -77,27 +77,27 @@ static int ipv6_mc_check_mld_reportv2(struct sk_buff *skb)
len += sizeof(struct mld2_report); len += sizeof(struct mld2_report);
return pskb_may_pull(skb, len) ? 0 : -EINVAL; return ipv6_mc_may_pull(skb, len) ? 0 : -EINVAL;
} }
static int ipv6_mc_check_mld_query(struct sk_buff *skb) static int ipv6_mc_check_mld_query(struct sk_buff *skb)
{ {
unsigned int transport_len = ipv6_transport_len(skb);
struct mld_msg *mld; struct mld_msg *mld;
unsigned int len = skb_transport_offset(skb); unsigned int len;
/* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */ /* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */
if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL))
return -EINVAL; return -EINVAL;
len += sizeof(struct mld_msg);
if (skb->len < len)
return -EINVAL;
/* MLDv1? */ /* MLDv1? */
if (skb->len != len) { if (transport_len != sizeof(struct mld_msg)) {
/* or MLDv2? */ /* or MLDv2? */
len += sizeof(struct mld2_query) - sizeof(struct mld_msg); if (transport_len < sizeof(struct mld2_query))
if (skb->len < len || !pskb_may_pull(skb, len)) return -EINVAL;
len = skb_transport_offset(skb) + sizeof(struct mld2_query);
if (!ipv6_mc_may_pull(skb, len))
return -EINVAL; return -EINVAL;
} }
...@@ -115,7 +115,13 @@ static int ipv6_mc_check_mld_query(struct sk_buff *skb) ...@@ -115,7 +115,13 @@ static int ipv6_mc_check_mld_query(struct sk_buff *skb)
static int ipv6_mc_check_mld_msg(struct sk_buff *skb) static int ipv6_mc_check_mld_msg(struct sk_buff *skb)
{ {
struct mld_msg *mld = (struct mld_msg *)skb_transport_header(skb); unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg);
struct mld_msg *mld;
if (!ipv6_mc_may_pull(skb, len))
return -EINVAL;
mld = (struct mld_msg *)skb_transport_header(skb);
switch (mld->mld_type) { switch (mld->mld_type) {
case ICMPV6_MGM_REDUCTION: case ICMPV6_MGM_REDUCTION:
...@@ -136,36 +142,24 @@ static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb) ...@@ -136,36 +142,24 @@ static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb)
return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo); return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo);
} }
static int __ipv6_mc_check_mld(struct sk_buff *skb) static int ipv6_mc_check_icmpv6(struct sk_buff *skb)
{ {
struct sk_buff *skb_chk = NULL; unsigned int len = skb_transport_offset(skb) + sizeof(struct icmp6hdr);
unsigned int transport_len; unsigned int transport_len = ipv6_transport_len(skb);
unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg); struct sk_buff *skb_chk;
int ret = -EINVAL;
transport_len = ntohs(ipv6_hdr(skb)->payload_len); if (!ipv6_mc_may_pull(skb, len))
transport_len -= skb_transport_offset(skb) - sizeof(struct ipv6hdr); return -EINVAL;
skb_chk = skb_checksum_trimmed(skb, transport_len, skb_chk = skb_checksum_trimmed(skb, transport_len,
ipv6_mc_validate_checksum); ipv6_mc_validate_checksum);
if (!skb_chk) if (!skb_chk)
goto err; return -EINVAL;
if (!pskb_may_pull(skb_chk, len))
goto err;
ret = ipv6_mc_check_mld_msg(skb_chk);
if (ret)
goto err;
ret = 0;
err: if (skb_chk != skb)
if (skb_chk && skb_chk != skb)
kfree_skb(skb_chk); kfree_skb(skb_chk);
return ret; return 0;
} }
/** /**
...@@ -195,6 +189,10 @@ int ipv6_mc_check_mld(struct sk_buff *skb) ...@@ -195,6 +189,10 @@ int ipv6_mc_check_mld(struct sk_buff *skb)
if (ret < 0) if (ret < 0)
return ret; return ret;
return __ipv6_mc_check_mld(skb); ret = ipv6_mc_check_icmpv6(skb);
if (ret < 0)
return ret;
return ipv6_mc_check_mld_msg(skb);
} }
EXPORT_SYMBOL(ipv6_mc_check_mld); EXPORT_SYMBOL(ipv6_mc_check_mld);
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment