Commit 6853dd48 authored by Florian Westphal's avatar Florian Westphal Committed by David S. Miller

rtnetlink: protect handler table with rcu

Note that netlink dumps still acquire rtnl mutex via the netlink
dump infrastructure.
Signed-off-by: default avatarFlorian Westphal <fw@strlen.de>
Reviewed-by: default avatarHannes Frederic Sowa <hannes@stressinduktion.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 0cc09020
...@@ -126,7 +126,7 @@ bool lockdep_rtnl_is_held(void) ...@@ -126,7 +126,7 @@ bool lockdep_rtnl_is_held(void)
EXPORT_SYMBOL(lockdep_rtnl_is_held); EXPORT_SYMBOL(lockdep_rtnl_is_held);
#endif /* #ifdef CONFIG_PROVE_LOCKING */ #endif /* #ifdef CONFIG_PROVE_LOCKING */
static struct rtnl_link *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1]; static struct rtnl_link __rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1]; static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];
static inline int rtm_msgindex(int msgtype) static inline int rtm_msgindex(int msgtype)
...@@ -143,36 +143,6 @@ static inline int rtm_msgindex(int msgtype) ...@@ -143,36 +143,6 @@ static inline int rtm_msgindex(int msgtype)
return msgindex; return msgindex;
} }
static rtnl_doit_func rtnl_get_doit(int protocol, int msgindex)
{
struct rtnl_link *tab;
if (protocol <= RTNL_FAMILY_MAX)
tab = rtnl_msg_handlers[protocol];
else
tab = NULL;
if (tab == NULL || tab[msgindex].doit == NULL)
tab = rtnl_msg_handlers[PF_UNSPEC];
return tab[msgindex].doit;
}
static rtnl_dumpit_func rtnl_get_dumpit(int protocol, int msgindex)
{
struct rtnl_link *tab;
if (protocol <= RTNL_FAMILY_MAX)
tab = rtnl_msg_handlers[protocol];
else
tab = NULL;
if (tab == NULL || tab[msgindex].dumpit == NULL)
tab = rtnl_msg_handlers[PF_UNSPEC];
return tab[msgindex].dumpit;
}
/** /**
* __rtnl_register - Register a rtnetlink message type * __rtnl_register - Register a rtnetlink message type
* @protocol: Protocol family or PF_UNSPEC * @protocol: Protocol family or PF_UNSPEC
...@@ -201,18 +171,17 @@ int __rtnl_register(int protocol, int msgtype, ...@@ -201,18 +171,17 @@ int __rtnl_register(int protocol, int msgtype,
BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX); BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
msgindex = rtm_msgindex(msgtype); msgindex = rtm_msgindex(msgtype);
tab = rtnl_msg_handlers[protocol]; tab = rcu_dereference(rtnl_msg_handlers[protocol]);
if (tab == NULL) { if (tab == NULL) {
tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL); tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL);
if (tab == NULL) if (tab == NULL)
return -ENOBUFS; return -ENOBUFS;
rtnl_msg_handlers[protocol] = tab; rcu_assign_pointer(rtnl_msg_handlers[protocol], tab);
} }
if (doit) if (doit)
tab[msgindex].doit = doit; tab[msgindex].doit = doit;
if (dumpit) if (dumpit)
tab[msgindex].dumpit = dumpit; tab[msgindex].dumpit = dumpit;
...@@ -249,16 +218,22 @@ EXPORT_SYMBOL_GPL(rtnl_register); ...@@ -249,16 +218,22 @@ EXPORT_SYMBOL_GPL(rtnl_register);
*/ */
int rtnl_unregister(int protocol, int msgtype) int rtnl_unregister(int protocol, int msgtype)
{ {
struct rtnl_link *handlers;
int msgindex; int msgindex;
BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX); BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
msgindex = rtm_msgindex(msgtype); msgindex = rtm_msgindex(msgtype);
if (rtnl_msg_handlers[protocol] == NULL) rtnl_lock();
handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
if (!handlers) {
rtnl_unlock();
return -ENOENT; return -ENOENT;
}
rtnl_msg_handlers[protocol][msgindex].doit = NULL; handlers[msgindex].doit = NULL;
rtnl_msg_handlers[protocol][msgindex].dumpit = NULL; handlers[msgindex].dumpit = NULL;
rtnl_unlock();
return 0; return 0;
} }
...@@ -278,10 +253,12 @@ void rtnl_unregister_all(int protocol) ...@@ -278,10 +253,12 @@ void rtnl_unregister_all(int protocol)
BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX); BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
rtnl_lock(); rtnl_lock();
handlers = rtnl_msg_handlers[protocol]; handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
rtnl_msg_handlers[protocol] = NULL; RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL);
rtnl_unlock(); rtnl_unlock();
synchronize_net();
while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 0) while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 0)
schedule(); schedule();
kfree(handlers); kfree(handlers);
...@@ -2820,11 +2797,13 @@ static u16 rtnl_calcit(struct sk_buff *skb, struct nlmsghdr *nlh) ...@@ -2820,11 +2797,13 @@ static u16 rtnl_calcit(struct sk_buff *skb, struct nlmsghdr *nlh)
* traverse the list of net devices and compute the minimum * traverse the list of net devices and compute the minimum
* buffer size based upon the filter mask. * buffer size based upon the filter mask.
*/ */
list_for_each_entry(dev, &net->dev_base_head, dev_list) { rcu_read_lock();
for_each_netdev_rcu(net, dev) {
min_ifinfo_dump_size = max_t(u16, min_ifinfo_dump_size, min_ifinfo_dump_size = max_t(u16, min_ifinfo_dump_size,
if_nlmsg_size(dev, if_nlmsg_size(dev,
ext_filter_mask)); ext_filter_mask));
} }
rcu_read_unlock();
return nlmsg_total_size(min_ifinfo_dump_size); return nlmsg_total_size(min_ifinfo_dump_size);
} }
...@@ -2836,19 +2815,29 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -2836,19 +2815,29 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb)
if (s_idx == 0) if (s_idx == 0)
s_idx = 1; s_idx = 1;
for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) { for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) {
int type = cb->nlh->nlmsg_type-RTM_BASE; int type = cb->nlh->nlmsg_type-RTM_BASE;
struct rtnl_link *handlers;
rtnl_dumpit_func dumpit;
if (idx < s_idx || idx == PF_PACKET) if (idx < s_idx || idx == PF_PACKET)
continue; continue;
if (rtnl_msg_handlers[idx] == NULL ||
rtnl_msg_handlers[idx][type].dumpit == NULL) handlers = rtnl_dereference(rtnl_msg_handlers[idx]);
if (!handlers)
continue; continue;
dumpit = READ_ONCE(handlers[type].dumpit);
if (!dumpit)
continue;
if (idx > s_idx) { if (idx > s_idx) {
memset(&cb->args[0], 0, sizeof(cb->args)); memset(&cb->args[0], 0, sizeof(cb->args));
cb->prev_seq = 0; cb->prev_seq = 0;
cb->seq = 0; cb->seq = 0;
} }
if (rtnl_msg_handlers[idx][type].dumpit(skb, cb)) if (dumpit(skb, cb))
break; break;
} }
cb->family = idx; cb->family = idx;
...@@ -4151,11 +4140,12 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4151,11 +4140,12 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack) struct netlink_ext_ack *extack)
{ {
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct rtnl_link *handlers;
int err = -EOPNOTSUPP;
rtnl_doit_func doit; rtnl_doit_func doit;
int kind; int kind;
int family; int family;
int type; int type;
int err;
type = nlh->nlmsg_type; type = nlh->nlmsg_type;
if (type > RTM_MAX) if (type > RTM_MAX)
...@@ -4173,23 +4163,40 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4173,23 +4163,40 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN)) if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN))
return -EPERM; return -EPERM;
if (family > ARRAY_SIZE(rtnl_msg_handlers))
family = PF_UNSPEC;
rcu_read_lock();
handlers = rcu_dereference(rtnl_msg_handlers[family]);
if (!handlers) {
family = PF_UNSPEC;
handlers = rcu_dereference(rtnl_msg_handlers[family]);
}
if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) { if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) {
struct sock *rtnl; struct sock *rtnl;
rtnl_dumpit_func dumpit; rtnl_dumpit_func dumpit;
u16 min_dump_alloc = 0; u16 min_dump_alloc = 0;
rtnl_lock(); dumpit = READ_ONCE(handlers[type].dumpit);
if (!dumpit) {
family = PF_UNSPEC;
handlers = rcu_dereference(rtnl_msg_handlers[PF_UNSPEC]);
if (!handlers)
goto err_unlock;
dumpit = rtnl_get_dumpit(family, type); dumpit = READ_ONCE(handlers[type].dumpit);
if (dumpit == NULL) if (!dumpit)
goto err_unlock; goto err_unlock;
}
refcount_inc(&rtnl_msg_handlers_ref[family]); refcount_inc(&rtnl_msg_handlers_ref[family]);
if (type == RTM_GETLINK) if (type == RTM_GETLINK)
min_dump_alloc = rtnl_calcit(skb, nlh); min_dump_alloc = rtnl_calcit(skb, nlh);
__rtnl_unlock(); rcu_read_unlock();
rtnl = net->rtnl; rtnl = net->rtnl;
{ {
struct netlink_dump_control c = { struct netlink_dump_control c = {
...@@ -4202,18 +4209,20 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -4202,18 +4209,20 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
return err; return err;
} }
rtnl_lock(); rcu_read_unlock();
doit = rtnl_get_doit(family, type);
if (doit == NULL)
goto err_unlock;
rtnl_lock();
handlers = rtnl_dereference(rtnl_msg_handlers[family]);
if (handlers) {
doit = READ_ONCE(handlers[type].doit);
if (doit)
err = doit(skb, nlh, extack); err = doit(skb, nlh, extack);
}
rtnl_unlock(); rtnl_unlock();
return err; return err;
err_unlock: err_unlock:
rtnl_unlock(); rcu_read_unlock();
return -EOPNOTSUPP; return -EOPNOTSUPP;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment