Commit 19d0c341 authored by Eric W. Biederman's avatar Eric W. Biederman Committed by David S. Miller

mpls: Cleanup the rcu usage in the code.

Sparse was generating a lot of warnings mostly from missing annotations
in the code.  Add missing annotations and in a few cases tweak the code
for performance by moving work before loops.

This also fixes a problematic ommision of rcu_assign_pointer and
rcu_dereference.

Hopefully with complete rcu annotations any new rcu errors will stick
out like a sore thumb.
Signed-off-by: default avatar"Eric W. Biederman" <ebiederm@xmission.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent d865616e
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long))) #define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
struct mpls_route { /* next hop label forwarding entry */ struct mpls_route { /* next hop label forwarding entry */
struct net_device *rt_dev; struct net_device __rcu *rt_dev;
struct rcu_head rt_rcu; struct rcu_head rt_rcu;
u32 rt_label[MAX_NEW_LABELS]; u32 rt_label[MAX_NEW_LABELS];
u8 rt_protocol; /* routing protocol that set this entry */ u8 rt_protocol; /* routing protocol that set this entry */
...@@ -152,7 +152,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev, ...@@ -152,7 +152,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
goto drop; goto drop;
/* Find the output device */ /* Find the output device */
out_dev = rt->rt_dev; out_dev = rcu_dereference(rt->rt_dev);
if (!mpls_output_possible(out_dev)) if (!mpls_output_possible(out_dev))
goto drop; goto drop;
...@@ -269,13 +269,15 @@ static void mpls_route_update(struct net *net, unsigned index, ...@@ -269,13 +269,15 @@ static void mpls_route_update(struct net *net, unsigned index,
struct net_device *dev, struct mpls_route *new, struct net_device *dev, struct mpls_route *new,
const struct nl_info *info) const struct nl_info *info)
{ {
struct mpls_route __rcu **platform_label;
struct mpls_route *rt, *old = NULL; struct mpls_route *rt, *old = NULL;
ASSERT_RTNL(); ASSERT_RTNL();
rt = net->mpls.platform_label[index]; platform_label = rtnl_dereference(net->mpls.platform_label);
if (!dev || (rt && (rt->rt_dev == dev))) { rt = rtnl_dereference(platform_label[index]);
rcu_assign_pointer(net->mpls.platform_label[index], new); if (!dev || (rt && (rtnl_dereference(rt->rt_dev) == dev))) {
rcu_assign_pointer(platform_label[index], new);
old = rt; old = rt;
} }
...@@ -287,9 +289,14 @@ static void mpls_route_update(struct net *net, unsigned index, ...@@ -287,9 +289,14 @@ static void mpls_route_update(struct net *net, unsigned index,
static unsigned find_free_label(struct net *net) static unsigned find_free_label(struct net *net)
{ {
struct mpls_route __rcu **platform_label;
size_t platform_labels;
unsigned index; unsigned index;
for (index = 16; index < net->mpls.platform_labels; index++) {
if (!net->mpls.platform_label[index]) platform_label = rtnl_dereference(net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
for (index = 16; index < platform_labels; index++) {
if (!rtnl_dereference(platform_label[index]))
return index; return index;
} }
return LABEL_NOT_SPECIFIED; return LABEL_NOT_SPECIFIED;
...@@ -297,6 +304,7 @@ static unsigned find_free_label(struct net *net) ...@@ -297,6 +304,7 @@ static unsigned find_free_label(struct net *net)
static int mpls_route_add(struct mpls_route_config *cfg) static int mpls_route_add(struct mpls_route_config *cfg)
{ {
struct mpls_route __rcu **platform_label;
struct net *net = cfg->rc_nlinfo.nl_net; struct net *net = cfg->rc_nlinfo.nl_net;
struct net_device *dev = NULL; struct net_device *dev = NULL;
struct mpls_route *rt, *old; struct mpls_route *rt, *old;
...@@ -345,7 +353,8 @@ static int mpls_route_add(struct mpls_route_config *cfg) ...@@ -345,7 +353,8 @@ static int mpls_route_add(struct mpls_route_config *cfg)
goto errout; goto errout;
err = -EEXIST; err = -EEXIST;
old = net->mpls.platform_label[index]; platform_label = rtnl_dereference(net->mpls.platform_label);
old = rtnl_dereference(platform_label[index]);
if ((cfg->rc_nlflags & NLM_F_EXCL) && old) if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
goto errout; goto errout;
...@@ -366,7 +375,7 @@ static int mpls_route_add(struct mpls_route_config *cfg) ...@@ -366,7 +375,7 @@ static int mpls_route_add(struct mpls_route_config *cfg)
for (i = 0; i < rt->rt_labels; i++) for (i = 0; i < rt->rt_labels; i++)
rt->rt_label[i] = cfg->rc_output_label[i]; rt->rt_label[i] = cfg->rc_output_label[i];
rt->rt_protocol = cfg->rc_protocol; rt->rt_protocol = cfg->rc_protocol;
rt->rt_dev = dev; RCU_INIT_POINTER(rt->rt_dev, dev);
rt->rt_via_family = cfg->rc_via_family; rt->rt_via_family = cfg->rc_via_family;
memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen); memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
...@@ -406,14 +415,16 @@ static int mpls_route_del(struct mpls_route_config *cfg) ...@@ -406,14 +415,16 @@ static int mpls_route_del(struct mpls_route_config *cfg)
static void mpls_ifdown(struct net_device *dev) static void mpls_ifdown(struct net_device *dev)
{ {
struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev); struct net *net = dev_net(dev);
unsigned index; unsigned index;
platform_label = rtnl_dereference(net->mpls.platform_label);
for (index = 0; index < net->mpls.platform_labels; index++) { for (index = 0; index < net->mpls.platform_labels; index++) {
struct mpls_route *rt = net->mpls.platform_label[index]; struct mpls_route *rt = rtnl_dereference(platform_label[index]);
if (!rt) if (!rt)
continue; continue;
if (rt->rt_dev != dev) if (rtnl_dereference(rt->rt_dev) != dev)
continue; continue;
rt->rt_dev = NULL; rt->rt_dev = NULL;
} }
...@@ -653,6 +664,7 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh) ...@@ -653,6 +664,7 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event, static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
u32 label, struct mpls_route *rt, int flags) u32 label, struct mpls_route *rt, int flags)
{ {
struct net_device *dev;
struct nlmsghdr *nlh; struct nlmsghdr *nlh;
struct rtmsg *rtm; struct rtmsg *rtm;
...@@ -676,7 +688,8 @@ static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event, ...@@ -676,7 +688,8 @@ static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
goto nla_put_failure; goto nla_put_failure;
if (nla_put_via(skb, rt->rt_via_family, rt->rt_via, rt->rt_via_alen)) if (nla_put_via(skb, rt->rt_via_family, rt->rt_via, rt->rt_via_alen))
goto nla_put_failure; goto nla_put_failure;
if (rt->rt_dev && nla_put_u32(skb, RTA_OIF, rt->rt_dev->ifindex)) dev = rtnl_dereference(rt->rt_dev);
if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
goto nla_put_failure; goto nla_put_failure;
if (nla_put_labels(skb, RTA_DST, 1, &label)) if (nla_put_labels(skb, RTA_DST, 1, &label))
goto nla_put_failure; goto nla_put_failure;
...@@ -692,6 +705,8 @@ static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event, ...@@ -692,6 +705,8 @@ static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
{ {
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct mpls_route __rcu **platform_label;
size_t platform_labels;
unsigned int index; unsigned int index;
ASSERT_RTNL(); ASSERT_RTNL();
...@@ -700,9 +715,11 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -700,9 +715,11 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
if (index < 16) if (index < 16)
index = 16; index = 16;
for (; index < net->mpls.platform_labels; index++) { platform_label = rtnl_dereference(net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
for (; index < platform_labels; index++) {
struct mpls_route *rt; struct mpls_route *rt;
rt = net->mpls.platform_label[index]; rt = rtnl_dereference(platform_label[index]);
if (!rt) if (!rt)
continue; continue;
...@@ -780,7 +797,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) ...@@ -780,7 +797,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
rt0 = mpls_rt_alloc(lo->addr_len); rt0 = mpls_rt_alloc(lo->addr_len);
if (!rt0) if (!rt0)
goto nort0; goto nort0;
rt0->rt_dev = lo; RCU_INIT_POINTER(rt0->rt_dev, lo);
rt0->rt_protocol = RTPROT_KERNEL; rt0->rt_protocol = RTPROT_KERNEL;
rt0->rt_via_family = AF_PACKET; rt0->rt_via_family = AF_PACKET;
memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len); memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
...@@ -790,7 +807,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) ...@@ -790,7 +807,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
rt2 = mpls_rt_alloc(lo->addr_len); rt2 = mpls_rt_alloc(lo->addr_len);
if (!rt2) if (!rt2)
goto nort2; goto nort2;
rt2->rt_dev = lo; RCU_INIT_POINTER(rt2->rt_dev, lo);
rt2->rt_protocol = RTPROT_KERNEL; rt2->rt_protocol = RTPROT_KERNEL;
rt2->rt_via_family = AF_PACKET; rt2->rt_via_family = AF_PACKET;
memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len); memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
...@@ -798,7 +815,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) ...@@ -798,7 +815,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
rtnl_lock(); rtnl_lock();
/* Remember the original table */ /* Remember the original table */
old = net->mpls.platform_label; old = rtnl_dereference(net->mpls.platform_label);
old_limit = net->mpls.platform_labels; old_limit = net->mpls.platform_labels;
/* Free any labels beyond the new table */ /* Free any labels beyond the new table */
...@@ -815,19 +832,19 @@ static int resize_platform_label_table(struct net *net, size_t limit) ...@@ -815,19 +832,19 @@ static int resize_platform_label_table(struct net *net, size_t limit)
/* If needed set the predefined labels */ /* If needed set the predefined labels */
if ((old_limit <= LABEL_IPV6_EXPLICIT_NULL) && if ((old_limit <= LABEL_IPV6_EXPLICIT_NULL) &&
(limit > LABEL_IPV6_EXPLICIT_NULL)) { (limit > LABEL_IPV6_EXPLICIT_NULL)) {
labels[LABEL_IPV6_EXPLICIT_NULL] = rt2; RCU_INIT_POINTER(labels[LABEL_IPV6_EXPLICIT_NULL], rt2);
rt2 = NULL; rt2 = NULL;
} }
if ((old_limit <= LABEL_IPV4_EXPLICIT_NULL) && if ((old_limit <= LABEL_IPV4_EXPLICIT_NULL) &&
(limit > LABEL_IPV4_EXPLICIT_NULL)) { (limit > LABEL_IPV4_EXPLICIT_NULL)) {
labels[LABEL_IPV4_EXPLICIT_NULL] = rt0; RCU_INIT_POINTER(labels[LABEL_IPV4_EXPLICIT_NULL], rt0);
rt0 = NULL; rt0 = NULL;
} }
/* Update the global pointers */ /* Update the global pointers */
net->mpls.platform_labels = limit; net->mpls.platform_labels = limit;
net->mpls.platform_label = labels; rcu_assign_pointer(net->mpls.platform_label, labels);
rtnl_unlock(); rtnl_unlock();
...@@ -903,6 +920,8 @@ static int mpls_net_init(struct net *net) ...@@ -903,6 +920,8 @@ static int mpls_net_init(struct net *net)
static void mpls_net_exit(struct net *net) static void mpls_net_exit(struct net *net)
{ {
struct mpls_route __rcu **platform_label;
size_t platform_labels;
struct ctl_table *table; struct ctl_table *table;
unsigned int index; unsigned int index;
...@@ -910,8 +929,8 @@ static void mpls_net_exit(struct net *net) ...@@ -910,8 +929,8 @@ static void mpls_net_exit(struct net *net)
unregister_net_sysctl_table(net->mpls.ctl); unregister_net_sysctl_table(net->mpls.ctl);
kfree(table); kfree(table);
/* An rcu grace period haselapsed since there was a device in /* An rcu grace period has passed since there was a device in
* the network namespace (and thus the last in fqlight packet) * the network namespace (and thus the last in flight packet)
* left this network namespace. This is because * left this network namespace. This is because
* unregister_netdevice_many and netdev_run_todo has completed * unregister_netdevice_many and netdev_run_todo has completed
* for each network device that was in this network namespace. * for each network device that was in this network namespace.
...@@ -920,14 +939,16 @@ static void mpls_net_exit(struct net *net) ...@@ -920,14 +939,16 @@ static void mpls_net_exit(struct net *net)
* freeing the platform_label table. * freeing the platform_label table.
*/ */
rtnl_lock(); rtnl_lock();
for (index = 0; index < net->mpls.platform_labels; index++) { platform_label = rtnl_dereference(net->mpls.platform_label);
struct mpls_route *rt = net->mpls.platform_label[index]; platform_labels = net->mpls.platform_labels;
rcu_assign_pointer(net->mpls.platform_label[index], NULL); for (index = 0; index < platform_labels; index++) {
struct mpls_route *rt = rtnl_dereference(platform_label[index]);
RCU_INIT_POINTER(platform_label[index], NULL);
mpls_rt_free(rt); mpls_rt_free(rt);
} }
rtnl_unlock(); rtnl_unlock();
kvfree(net->mpls.platform_label); kvfree(platform_label);
} }
static struct pernet_operations mpls_net_ops = { static struct pernet_operations mpls_net_ops = {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment