Commit e4202511 authored by Florian Westphal's avatar Florian Westphal Committed by David S. Miller

rtnetlink: get reference on module before invoking handlers

Add yet another rtnl_register function.  It will be used by modules
that can be removed.

The passed module struct is used to prevent module unload while
a netlink dump is in progress or when a DOIT_UNLOCKED doit callback
is called.

Cc: Peter Zijlstra <peterz@infradead.org>
Signed-off-by: default avatarFlorian Westphal <fw@strlen.de>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent addf9b90
...@@ -17,6 +17,8 @@ int __rtnl_register(int protocol, int msgtype, ...@@ -17,6 +17,8 @@ int __rtnl_register(int protocol, int msgtype,
rtnl_doit_func, rtnl_dumpit_func, unsigned int flags); rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
void rtnl_register(int protocol, int msgtype, void rtnl_register(int protocol, int msgtype,
rtnl_doit_func, rtnl_dumpit_func, unsigned int flags); rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
int rtnl_register_module(struct module *owner, int protocol, int msgtype,
rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
int rtnl_unregister(int protocol, int msgtype); int rtnl_unregister(int protocol, int msgtype);
void rtnl_unregister_all(int protocol); void rtnl_unregister_all(int protocol);
......
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
struct rtnl_link { struct rtnl_link {
rtnl_doit_func doit; rtnl_doit_func doit;
rtnl_dumpit_func dumpit; rtnl_dumpit_func dumpit;
struct module *owner;
unsigned int flags; unsigned int flags;
struct rcu_head rcu; struct rcu_head rcu;
}; };
...@@ -129,7 +130,6 @@ EXPORT_SYMBOL(lockdep_rtnl_is_held); ...@@ -129,7 +130,6 @@ EXPORT_SYMBOL(lockdep_rtnl_is_held);
#endif /* #ifdef CONFIG_PROVE_LOCKING */ #endif /* #ifdef CONFIG_PROVE_LOCKING */
static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1]; static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];
static inline int rtm_msgindex(int msgtype) static inline int rtm_msgindex(int msgtype)
{ {
...@@ -159,27 +159,10 @@ static struct rtnl_link *rtnl_get_link(int protocol, int msgtype) ...@@ -159,27 +159,10 @@ static struct rtnl_link *rtnl_get_link(int protocol, int msgtype)
return tab[msgtype]; return tab[msgtype];
} }
/** static int rtnl_register_internal(struct module *owner,
* __rtnl_register - Register a rtnetlink message type int protocol, int msgtype,
* @protocol: Protocol family or PF_UNSPEC rtnl_doit_func doit, rtnl_dumpit_func dumpit,
* @msgtype: rtnetlink message type unsigned int flags)
* @doit: Function pointer called for each request message
* @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
* @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
*
* Registers the specified function pointers (at least one of them has
* to be non-NULL) to be called whenever a request message for the
* specified protocol family and message type is received.
*
* The special protocol family PF_UNSPEC may be used to define fallback
* function pointers for the case when no entry for the specific protocol
* family exists.
*
* Returns 0 on success or a negative error code.
*/
int __rtnl_register(int protocol, int msgtype,
rtnl_doit_func doit, rtnl_dumpit_func dumpit,
unsigned int flags)
{ {
struct rtnl_link **tab, *link, *old; struct rtnl_link **tab, *link, *old;
int msgindex; int msgindex;
...@@ -210,6 +193,9 @@ int __rtnl_register(int protocol, int msgtype, ...@@ -210,6 +193,9 @@ int __rtnl_register(int protocol, int msgtype,
goto unlock; goto unlock;
} }
WARN_ON(link->owner && link->owner != owner);
link->owner = owner;
WARN_ON(doit && link->doit && link->doit != doit); WARN_ON(doit && link->doit && link->doit != doit);
if (doit) if (doit)
link->doit = doit; link->doit = doit;
...@@ -228,6 +214,54 @@ int __rtnl_register(int protocol, int msgtype, ...@@ -228,6 +214,54 @@ int __rtnl_register(int protocol, int msgtype,
rtnl_unlock(); rtnl_unlock();
return ret; return ret;
} }
/**
* rtnl_register_module - Register a rtnetlink message type
*
* @owner: module registering the hook (THIS_MODULE)
* @protocol: Protocol family or PF_UNSPEC
* @msgtype: rtnetlink message type
* @doit: Function pointer called for each request message
* @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
* @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
*
* Like rtnl_register, but for use by removable modules.
*/
int rtnl_register_module(struct module *owner,
int protocol, int msgtype,
rtnl_doit_func doit, rtnl_dumpit_func dumpit,
unsigned int flags)
{
return rtnl_register_internal(owner, protocol, msgtype,
doit, dumpit, flags);
}
EXPORT_SYMBOL_GPL(rtnl_register_module);
/**
* __rtnl_register - Register a rtnetlink message type
* @protocol: Protocol family or PF_UNSPEC
* @msgtype: rtnetlink message type
* @doit: Function pointer called for each request message
* @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
* @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
*
* Registers the specified function pointers (at least one of them has
* to be non-NULL) to be called whenever a request message for the
* specified protocol family and message type is received.
*
* The special protocol family PF_UNSPEC may be used to define fallback
* function pointers for the case when no entry for the specific protocol
* family exists.
*
* Returns 0 on success or a negative error code.
*/
int __rtnl_register(int protocol, int msgtype,
rtnl_doit_func doit, rtnl_dumpit_func dumpit,
unsigned int flags)
{
return rtnl_register_internal(NULL, protocol, msgtype,
doit, dumpit, flags);
}
EXPORT_SYMBOL_GPL(__rtnl_register); EXPORT_SYMBOL_GPL(__rtnl_register);
/** /**
...@@ -311,8 +345,6 @@ void rtnl_unregister_all(int protocol) ...@@ -311,8 +345,6 @@ void rtnl_unregister_all(int protocol)
synchronize_net(); synchronize_net();
while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 1)
schedule();
kfree(tab); kfree(tab);
} }
EXPORT_SYMBOL_GPL(rtnl_unregister_all); EXPORT_SYMBOL_GPL(rtnl_unregister_all);
...@@ -4372,6 +4404,7 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4372,6 +4404,7 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
{ {
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct rtnl_link *link; struct rtnl_link *link;
struct module *owner;
int err = -EOPNOTSUPP; int err = -EOPNOTSUPP;
rtnl_doit_func doit; rtnl_doit_func doit;
unsigned int flags; unsigned int flags;
...@@ -4408,24 +4441,32 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4408,24 +4441,32 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
if (!link || !link->dumpit) if (!link || !link->dumpit)
goto err_unlock; goto err_unlock;
} }
owner = link->owner;
dumpit = link->dumpit; dumpit = link->dumpit;
refcount_inc(&rtnl_msg_handlers_ref[family]);
if (type == RTM_GETLINK - RTM_BASE) if (type == RTM_GETLINK - RTM_BASE)
min_dump_alloc = rtnl_calcit(skb, nlh); min_dump_alloc = rtnl_calcit(skb, nlh);
err = 0;
/* need to do this before rcu_read_unlock() */
if (!try_module_get(owner))
err = -EPROTONOSUPPORT;
rcu_read_unlock(); rcu_read_unlock();
rtnl = net->rtnl; rtnl = net->rtnl;
{ if (err == 0) {
struct netlink_dump_control c = { struct netlink_dump_control c = {
.dump = dumpit, .dump = dumpit,
.min_dump_alloc = min_dump_alloc, .min_dump_alloc = min_dump_alloc,
.module = owner,
}; };
err = netlink_dump_start(rtnl, skb, nlh, &c); err = netlink_dump_start(rtnl, skb, nlh, &c);
/* netlink_dump_start() will keep a reference on
* module if dump is still in progress.
*/
module_put(owner);
} }
refcount_dec(&rtnl_msg_handlers_ref[family]);
return err; return err;
} }
...@@ -4437,14 +4478,19 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4437,14 +4478,19 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
goto out_unlock; goto out_unlock;
} }
owner = link->owner;
if (!try_module_get(owner)) {
err = -EPROTONOSUPPORT;
goto out_unlock;
}
flags = link->flags; flags = link->flags;
if (flags & RTNL_FLAG_DOIT_UNLOCKED) { if (flags & RTNL_FLAG_DOIT_UNLOCKED) {
refcount_inc(&rtnl_msg_handlers_ref[family]);
doit = link->doit; doit = link->doit;
rcu_read_unlock(); rcu_read_unlock();
if (doit) if (doit)
err = doit(skb, nlh, extack); err = doit(skb, nlh, extack);
refcount_dec(&rtnl_msg_handlers_ref[family]); module_put(owner);
return err; return err;
} }
rcu_read_unlock(); rcu_read_unlock();
...@@ -4455,6 +4501,8 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4455,6 +4501,8 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
err = link->doit(skb, nlh, extack); err = link->doit(skb, nlh, extack);
rtnl_unlock(); rtnl_unlock();
module_put(owner);
return err; return err;
out_unlock: out_unlock:
...@@ -4546,11 +4594,6 @@ static struct pernet_operations rtnetlink_net_ops = { ...@@ -4546,11 +4594,6 @@ static struct pernet_operations rtnetlink_net_ops = {
void __init rtnetlink_init(void) void __init rtnetlink_init(void)
{ {
int i;
for (i = 0; i < ARRAY_SIZE(rtnl_msg_handlers_ref); i++)
refcount_set(&rtnl_msg_handlers_ref[i], 1);
if (register_pernet_subsys(&rtnetlink_net_ops)) if (register_pernet_subsys(&rtnetlink_net_ops))
panic("rtnetlink_init: cannot initialize rtnetlink\n"); panic("rtnetlink_init: cannot initialize rtnetlink\n");
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment