Commit 10bc9563 authored by Herbert Xu's avatar Herbert Xu Committed by Patrick McHardy

[NET]: Add reference counting to neigh_parms.

I've added a refcnt on neigh_parms as well as a dead flag.  The latter
is checked under the tbl_lock before adding a neigh entry to the hash
table.

The non-trivial bit of the patch is the first chunk of net/core/neighbour.c.
I removed that line because not doing so would mean that I have to drop
the reference to the parms right there.  That would've lead to race
conditions since many places dereference neigh->parms without holding
locks.  It's also unnecessary to reset n->parms since we're no longer
in a hurry to see it go due to the new ref counting.

You'll also notice that I've put all dereferences of dev->*_ptr under
the rcu_read_lock().  Without this we may get a neigh_parms that's
already been released.

Incidentally a lot of these places were racy even before the RCU change.
For example, in the IPv6 case neigh->parms may be set to a value that's
just been released.

Finally in order to make sure that all stale entries are purged as
quickly as possible I've added neigh_ifdown/arp_ifdown calls after
every neigh_parms_release call.  In many cases we now have multiple
calls to neigh_ifdown in the shutdown path.  I didn't remove the
earlier calls because there may be hidden dependencies for them to
be there.  Once the respective maintainers have looked at them we
can probably remove most of them.
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 954bebf6
...@@ -6710,19 +6710,28 @@ static int ...@@ -6710,19 +6710,28 @@ static int
qeth_arp_constructor(struct neighbour *neigh) qeth_arp_constructor(struct neighbour *neigh)
{ {
struct net_device *dev = neigh->dev; struct net_device *dev = neigh->dev;
struct in_device *in_dev = in_dev_get(dev); struct in_device *in_dev;
struct neigh_parms *parms;
if (in_dev == NULL)
return -EINVAL;
if (!qeth_verify_dev(dev)) { if (!qeth_verify_dev(dev)) {
in_dev_put(in_dev);
return qeth_old_arp_constructor(neigh); return qeth_old_arp_constructor(neigh);
} }
rcu_read_lock();
in_dev = __in_dev_get(dev);
if (in_dev == NULL) {
rcu_read_unlock();
return -EINVAL;
}
parms = in_dev->arp_parms;
if (parms) {
__neigh_parms_put(neigh->parms);
neigh->parms = neigh_parms_clone(parms);
}
rcu_read_unlock();
neigh->type = inet_addr_type(*(u32 *) neigh->primary_key); neigh->type = inet_addr_type(*(u32 *) neigh->primary_key);
if (in_dev->arp_parms)
neigh->parms = in_dev->arp_parms;
in_dev_put(in_dev);
neigh->nud_state = NUD_NOARP; neigh->nud_state = NUD_NOARP;
neigh->ops = arp_direct_ops; neigh->ops = arp_direct_ops;
neigh->output = neigh->ops->queue_xmit; neigh->output = neigh->ops->queue_xmit;
......
...@@ -67,6 +67,8 @@ struct neigh_parms ...@@ -67,6 +67,8 @@ struct neigh_parms
void *sysctl_table; void *sysctl_table;
int dead;
atomic_t refcnt;
struct rcu_head rcu_head; struct rcu_head rcu_head;
int base_reachable_time; int base_reachable_time;
...@@ -199,6 +201,7 @@ extern struct neighbour *neigh_event_ns(struct neigh_table *tbl, ...@@ -199,6 +201,7 @@ extern struct neighbour *neigh_event_ns(struct neigh_table *tbl,
extern struct neigh_parms *neigh_parms_alloc(struct net_device *dev, struct neigh_table *tbl); extern struct neigh_parms *neigh_parms_alloc(struct net_device *dev, struct neigh_table *tbl);
extern void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms); extern void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms);
extern void neigh_parms_destroy(struct neigh_parms *parms);
extern unsigned long neigh_rand_reach_time(unsigned long base); extern unsigned long neigh_rand_reach_time(unsigned long base);
extern void pneigh_enqueue(struct neigh_table *tbl, struct neigh_parms *p, extern void pneigh_enqueue(struct neigh_table *tbl, struct neigh_parms *p,
...@@ -220,6 +223,23 @@ extern int neigh_sysctl_register(struct net_device *dev, ...@@ -220,6 +223,23 @@ extern int neigh_sysctl_register(struct net_device *dev,
proc_handler *proc_handler); proc_handler *proc_handler);
extern void neigh_sysctl_unregister(struct neigh_parms *p); extern void neigh_sysctl_unregister(struct neigh_parms *p);
static inline void __neigh_parms_put(struct neigh_parms *parms)
{
atomic_dec(&parms->refcnt);
}
static inline void neigh_parms_put(struct neigh_parms *parms)
{
if (atomic_dec_and_test(&parms->refcnt))
neigh_parms_destroy(parms);
}
static inline struct neigh_parms *neigh_parms_clone(struct neigh_parms *parms)
{
atomic_inc(&parms->refcnt);
return parms;
}
/* /*
* Neighbour references * Neighbour references
*/ */
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <linux/bitops.h> #include <linux/bitops.h>
#include <linux/proc_fs.h> #include <linux/proc_fs.h>
#include <linux/seq_file.h> #include <linux/seq_file.h>
#include <linux/rcupdate.h>
#include <net/route.h> /* for struct rtable and routing */ #include <net/route.h> /* for struct rtable and routing */
#include <net/icmp.h> /* icmp_send */ #include <net/icmp.h> /* icmp_send */
#include <asm/param.h> /* for HZ */ #include <asm/param.h> /* for HZ */
...@@ -311,13 +312,27 @@ static int clip_constructor(struct neighbour *neigh) ...@@ -311,13 +312,27 @@ static int clip_constructor(struct neighbour *neigh)
{ {
struct atmarp_entry *entry = NEIGH2ENTRY(neigh); struct atmarp_entry *entry = NEIGH2ENTRY(neigh);
struct net_device *dev = neigh->dev; struct net_device *dev = neigh->dev;
struct in_device *in_dev = dev->ip_ptr; struct in_device *in_dev;
struct neigh_parms *parms;
DPRINTK("clip_constructor (neigh %p, entry %p)\n",neigh,entry); DPRINTK("clip_constructor (neigh %p, entry %p)\n",neigh,entry);
if (!in_dev) return -EINVAL;
neigh->type = inet_addr_type(entry->ip); neigh->type = inet_addr_type(entry->ip);
if (neigh->type != RTN_UNICAST) return -EINVAL; if (neigh->type != RTN_UNICAST) return -EINVAL;
if (in_dev->arp_parms) neigh->parms = in_dev->arp_parms;
rcu_read_lock();
in_dev = __in_dev_get(dev);
if (!in_dev) {
rcu_read_unlock();
return -EINVAL;
}
parms = in_dev->arp_parms;
if (parms) {
__neigh_parms_put(neigh->parms);
neigh->parms = neigh_parms_clone(parms);
}
rcu_read_unlock();
neigh->ops = &clip_neigh_ops; neigh->ops = &clip_neigh_ops;
neigh->output = neigh->nud_state & NUD_VALID ? neigh->output = neigh->nud_state & NUD_VALID ?
neigh->ops->connected_output : neigh->ops->output; neigh->ops->connected_output : neigh->ops->output;
......
...@@ -227,7 +227,6 @@ int neigh_ifdown(struct neigh_table *tbl, struct net_device *dev) ...@@ -227,7 +227,6 @@ int neigh_ifdown(struct neigh_table *tbl, struct net_device *dev)
we must kill timers etc. and move we must kill timers etc. and move
it to safe state. it to safe state.
*/ */
n->parms = &tbl->parms;
skb_queue_purge(&n->arp_queue); skb_queue_purge(&n->arp_queue);
n->output = neigh_blackhole; n->output = neigh_blackhole;
if (n->nud_state & NUD_VALID) if (n->nud_state & NUD_VALID)
...@@ -273,7 +272,7 @@ static struct neighbour *neigh_alloc(struct neigh_table *tbl) ...@@ -273,7 +272,7 @@ static struct neighbour *neigh_alloc(struct neigh_table *tbl)
n->updated = n->used = now; n->updated = n->used = now;
n->nud_state = NUD_NONE; n->nud_state = NUD_NONE;
n->output = neigh_blackhole; n->output = neigh_blackhole;
n->parms = &tbl->parms; n->parms = neigh_parms_clone(&tbl->parms);
init_timer(&n->timer); init_timer(&n->timer);
n->timer.function = neigh_timer_handler; n->timer.function = neigh_timer_handler;
n->timer.data = (unsigned long)n; n->timer.data = (unsigned long)n;
...@@ -340,12 +339,16 @@ struct neighbour *neigh_create(struct neigh_table *tbl, const void *pkey, ...@@ -340,12 +339,16 @@ struct neighbour *neigh_create(struct neigh_table *tbl, const void *pkey,
hash_val = tbl->hash(pkey, dev); hash_val = tbl->hash(pkey, dev);
write_lock_bh(&tbl->lock); write_lock_bh(&tbl->lock);
if (n->parms->dead) {
rc = ERR_PTR(-EINVAL);
goto out_tbl_unlock;
}
for (n1 = tbl->hash_buckets[hash_val]; n1; n1 = n1->next) { for (n1 = tbl->hash_buckets[hash_val]; n1; n1 = n1->next) {
if (dev == n1->dev && !memcmp(n1->primary_key, pkey, key_len)) { if (dev == n1->dev && !memcmp(n1->primary_key, pkey, key_len)) {
neigh_hold(n1); neigh_hold(n1);
write_unlock_bh(&tbl->lock);
rc = n1; rc = n1;
goto out_neigh_release; goto out_tbl_unlock;
} }
} }
...@@ -358,6 +361,8 @@ struct neighbour *neigh_create(struct neigh_table *tbl, const void *pkey, ...@@ -358,6 +361,8 @@ struct neighbour *neigh_create(struct neigh_table *tbl, const void *pkey,
rc = n; rc = n;
out: out:
return rc; return rc;
out_tbl_unlock:
write_unlock_bh(&tbl->lock);
out_neigh_release: out_neigh_release:
neigh_release(n); neigh_release(n);
goto out; goto out;
...@@ -494,6 +499,7 @@ void neigh_destroy(struct neighbour *neigh) ...@@ -494,6 +499,7 @@ void neigh_destroy(struct neighbour *neigh)
skb_queue_purge(&neigh->arp_queue); skb_queue_purge(&neigh->arp_queue);
dev_put(neigh->dev); dev_put(neigh->dev);
neigh_parms_put(neigh->parms);
NEIGH_PRINTK2("neigh %p is destroyed.\n", neigh); NEIGH_PRINTK2("neigh %p is destroyed.\n", neigh);
...@@ -1120,6 +1126,7 @@ struct neigh_parms *neigh_parms_alloc(struct net_device *dev, ...@@ -1120,6 +1126,7 @@ struct neigh_parms *neigh_parms_alloc(struct net_device *dev,
if (p) { if (p) {
memcpy(p, &tbl->parms, sizeof(*p)); memcpy(p, &tbl->parms, sizeof(*p));
p->tbl = tbl; p->tbl = tbl;
atomic_set(&p->refcnt, 1);
INIT_RCU_HEAD(&p->rcu_head); INIT_RCU_HEAD(&p->rcu_head);
p->reachable_time = p->reachable_time =
neigh_rand_reach_time(p->base_reachable_time); neigh_rand_reach_time(p->base_reachable_time);
...@@ -1141,7 +1148,7 @@ static void neigh_rcu_free_parms(struct rcu_head *head) ...@@ -1141,7 +1148,7 @@ static void neigh_rcu_free_parms(struct rcu_head *head)
struct neigh_parms *parms = struct neigh_parms *parms =
container_of(head, struct neigh_parms, rcu_head); container_of(head, struct neigh_parms, rcu_head);
kfree(parms); neigh_parms_put(parms);
} }
void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms) void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms)
...@@ -1154,6 +1161,7 @@ void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms) ...@@ -1154,6 +1161,7 @@ void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms)
for (p = &tbl->parms.next; *p; p = &(*p)->next) { for (p = &tbl->parms.next; *p; p = &(*p)->next) {
if (*p == parms) { if (*p == parms) {
*p = parms->next; *p = parms->next;
parms->dead = 1;
write_unlock_bh(&tbl->lock); write_unlock_bh(&tbl->lock);
call_rcu(&parms->rcu_head, neigh_rcu_free_parms); call_rcu(&parms->rcu_head, neigh_rcu_free_parms);
return; return;
...@@ -1163,11 +1171,17 @@ void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms) ...@@ -1163,11 +1171,17 @@ void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms)
NEIGH_PRINTK1("neigh_parms_release: not found\n"); NEIGH_PRINTK1("neigh_parms_release: not found\n");
} }
void neigh_parms_destroy(struct neigh_parms *parms)
{
kfree(parms);
}
void neigh_table_init(struct neigh_table *tbl) void neigh_table_init(struct neigh_table *tbl)
{ {
unsigned long now = jiffies; unsigned long now = jiffies;
atomic_set(&tbl->parms.refcnt, 1);
INIT_RCU_HEAD(&tbl->parms.rcu_head); INIT_RCU_HEAD(&tbl->parms.rcu_head);
tbl->parms.reachable_time = tbl->parms.reachable_time =
neigh_rand_reach_time(tbl->parms.base_reachable_time); neigh_rand_reach_time(tbl->parms.base_reachable_time);
......
...@@ -1215,6 +1215,7 @@ static void dn_dev_delete(struct net_device *dev) ...@@ -1215,6 +1215,7 @@ static void dn_dev_delete(struct net_device *dev)
dev->dn_ptr = NULL; dev->dn_ptr = NULL;
neigh_parms_release(&dn_neigh_table, dn_db->neigh_parms); neigh_parms_release(&dn_neigh_table, dn_db->neigh_parms);
neigh_ifdown(&dn_neigh_table, dev);
if (dn_db->router) if (dn_db->router)
neigh_release(dn_db->router); neigh_release(dn_db->router);
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <linux/netfilter_decnet.h> #include <linux/netfilter_decnet.h>
#include <linux/spinlock.h> #include <linux/spinlock.h>
#include <linux/seq_file.h> #include <linux/seq_file.h>
#include <linux/rcupdate.h>
#include <asm/atomic.h> #include <asm/atomic.h>
#include <net/neighbour.h> #include <net/neighbour.h>
#include <net/dst.h> #include <net/dst.h>
...@@ -134,13 +135,22 @@ static int dn_neigh_construct(struct neighbour *neigh) ...@@ -134,13 +135,22 @@ static int dn_neigh_construct(struct neighbour *neigh)
{ {
struct net_device *dev = neigh->dev; struct net_device *dev = neigh->dev;
struct dn_neigh *dn = (struct dn_neigh *)neigh; struct dn_neigh *dn = (struct dn_neigh *)neigh;
struct dn_dev *dn_db = (struct dn_dev *)dev->dn_ptr; struct dn_dev *dn_db;
struct neigh_parms *parms;
if (dn_db == NULL) rcu_read_lock();
dn_db = dev->dn_ptr;
if (dn_db == NULL) {
rcu_read_unlock();
return -EINVAL; return -EINVAL;
}
if (dn_db->neigh_parms) parms = dn_db->neigh_parms;
neigh->parms = dn_db->neigh_parms; if (parms) {
__neigh_parms_put(neigh->parms);
neigh->parms = neigh_parms_clone(parms);
}
rcu_read_unlock();
if (dn_db->use_long) if (dn_db->use_long)
neigh->ops = &dn_long_ops; neigh->ops = &dn_long_ops;
......
...@@ -96,6 +96,7 @@ ...@@ -96,6 +96,7 @@
#include <linux/stat.h> #include <linux/stat.h>
#include <linux/init.h> #include <linux/init.h>
#include <linux/net.h> #include <linux/net.h>
#include <linux/rcupdate.h>
#ifdef CONFIG_SYSCTL #ifdef CONFIG_SYSCTL
#include <linux/sysctl.h> #include <linux/sysctl.h>
#endif #endif
...@@ -237,16 +238,24 @@ static int arp_constructor(struct neighbour *neigh) ...@@ -237,16 +238,24 @@ static int arp_constructor(struct neighbour *neigh)
{ {
u32 addr = *(u32*)neigh->primary_key; u32 addr = *(u32*)neigh->primary_key;
struct net_device *dev = neigh->dev; struct net_device *dev = neigh->dev;
struct in_device *in_dev = in_dev_get(dev); struct in_device *in_dev;
struct neigh_parms *parms;
if (in_dev == NULL)
return -EINVAL;
neigh->type = inet_addr_type(addr); neigh->type = inet_addr_type(addr);
if (in_dev->arp_parms)
neigh->parms = in_dev->arp_parms;
in_dev_put(in_dev); rcu_read_lock();
in_dev = __in_dev_get(dev);
if (in_dev == NULL) {
rcu_read_unlock();
return -EINVAL;
}
parms = in_dev->arp_parms;
if (parms) {
__neigh_parms_put(neigh->parms);
neigh->parms = neigh_parms_clone(parms);
}
rcu_read_unlock();
if (dev->hard_header == NULL) { if (dev->hard_header == NULL) {
neigh->nud_state = NUD_NOARP; neigh->nud_state = NUD_NOARP;
......
...@@ -184,6 +184,7 @@ static void in_dev_rcu_put(struct rcu_head *head) ...@@ -184,6 +184,7 @@ static void in_dev_rcu_put(struct rcu_head *head)
static void inetdev_destroy(struct in_device *in_dev) static void inetdev_destroy(struct in_device *in_dev)
{ {
struct in_ifaddr *ifa; struct in_ifaddr *ifa;
struct net_device *dev;
ASSERT_RTNL(); ASSERT_RTNL();
...@@ -200,12 +201,15 @@ static void inetdev_destroy(struct in_device *in_dev) ...@@ -200,12 +201,15 @@ static void inetdev_destroy(struct in_device *in_dev)
devinet_sysctl_unregister(&in_dev->cnf); devinet_sysctl_unregister(&in_dev->cnf);
#endif #endif
in_dev->dev->ip_ptr = NULL; dev = in_dev->dev;
dev->ip_ptr = NULL;
#ifdef CONFIG_SYSCTL #ifdef CONFIG_SYSCTL
neigh_sysctl_unregister(in_dev->arp_parms); neigh_sysctl_unregister(in_dev->arp_parms);
#endif #endif
neigh_parms_release(&arp_tbl, in_dev->arp_parms); neigh_parms_release(&arp_tbl, in_dev->arp_parms);
arp_ifdown(dev);
call_rcu(&in_dev->rcu_head, in_dev_rcu_put); call_rcu(&in_dev->rcu_head, in_dev_rcu_put);
} }
......
...@@ -2072,6 +2072,7 @@ static int addrconf_ifdown(struct net_device *dev, int how) ...@@ -2072,6 +2072,7 @@ static int addrconf_ifdown(struct net_device *dev, int how)
neigh_sysctl_unregister(idev->nd_parms); neigh_sysctl_unregister(idev->nd_parms);
#endif #endif
neigh_parms_release(&nd_tbl, idev->nd_parms); neigh_parms_release(&nd_tbl, idev->nd_parms);
neigh_ifdown(&nd_tbl, dev);
in6_dev_put(idev); in6_dev_put(idev);
} }
return 0; return 0;
......
...@@ -58,6 +58,7 @@ ...@@ -58,6 +58,7 @@
#include <linux/in6.h> #include <linux/in6.h>
#include <linux/route.h> #include <linux/route.h>
#include <linux/init.h> #include <linux/init.h>
#include <linux/rcupdate.h>
#ifdef CONFIG_SYSCTL #ifdef CONFIG_SYSCTL
#include <linux/sysctl.h> #include <linux/sysctl.h>
#endif #endif
...@@ -284,14 +285,23 @@ static int ndisc_constructor(struct neighbour *neigh) ...@@ -284,14 +285,23 @@ static int ndisc_constructor(struct neighbour *neigh)
{ {
struct in6_addr *addr = (struct in6_addr*)&neigh->primary_key; struct in6_addr *addr = (struct in6_addr*)&neigh->primary_key;
struct net_device *dev = neigh->dev; struct net_device *dev = neigh->dev;
struct inet6_dev *in6_dev = in6_dev_get(dev); struct inet6_dev *in6_dev;
struct neigh_parms *parms;
int is_multicast = ipv6_addr_is_multicast(addr); int is_multicast = ipv6_addr_is_multicast(addr);
if (in6_dev == NULL) rcu_read_lock();
in6_dev = in6_dev_get(dev);
if (in6_dev == NULL) {
rcu_read_unlock();
return -EINVAL; return -EINVAL;
}
if (in6_dev->nd_parms) parms = in6_dev->nd_parms;
neigh->parms = in6_dev->nd_parms; if (parms) {
__neigh_parms_put(neigh->parms);
neigh->parms = neigh_parms_clone(parms);
}
rcu_read_unlock();
neigh->type = is_multicast ? RTN_MULTICAST : RTN_UNICAST; neigh->type = is_multicast ? RTN_MULTICAST : RTN_UNICAST;
if (dev->hard_header == NULL) { if (dev->hard_header == NULL) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment