Commit af65bdfc authored by Patrick McHardy's avatar Patrick McHardy Committed by David S. Miller

[NETLINK]: Switch cb_lock spinlock to mutex and allow to override it

Switch cb_lock to mutex and allow netlink kernel users to override it
with a subsystem specific mutex for consistent locking in dump callbacks.
All netlink_dump_start users have been audited not to rely on any
side-effects of the previously used spinlock.
Signed-off-by: default avatarPatrick McHardy <kaber@trash.net>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent b076deb8
...@@ -448,7 +448,7 @@ static int __devinit cn_init(void) ...@@ -448,7 +448,7 @@ static int __devinit cn_init(void)
dev->nls = netlink_kernel_create(NETLINK_CONNECTOR, dev->nls = netlink_kernel_create(NETLINK_CONNECTOR,
CN_NETLINK_USERS + 0xf, CN_NETLINK_USERS + 0xf,
dev->input, THIS_MODULE); dev->input, NULL, THIS_MODULE);
if (!dev->nls) if (!dev->nls)
return -EIO; return -EIO;
......
...@@ -168,7 +168,8 @@ scsi_netlink_init(void) ...@@ -168,7 +168,8 @@ scsi_netlink_init(void)
} }
scsi_nl_sock = netlink_kernel_create(NETLINK_SCSITRANSPORT, scsi_nl_sock = netlink_kernel_create(NETLINK_SCSITRANSPORT,
SCSI_NL_GRP_CNT, scsi_nl_rcv, THIS_MODULE); SCSI_NL_GRP_CNT, scsi_nl_rcv, NULL,
THIS_MODULE);
if (!scsi_nl_sock) { if (!scsi_nl_sock) {
printk(KERN_ERR "%s: register of recieve handler failed\n", printk(KERN_ERR "%s: register of recieve handler failed\n",
__FUNCTION__); __FUNCTION__);
......
...@@ -1435,7 +1435,7 @@ static __init int iscsi_transport_init(void) ...@@ -1435,7 +1435,7 @@ static __init int iscsi_transport_init(void)
if (err) if (err)
goto unregister_conn_class; goto unregister_conn_class;
nls = netlink_kernel_create(NETLINK_ISCSI, 1, iscsi_if_rx, nls = netlink_kernel_create(NETLINK_ISCSI, 1, iscsi_if_rx, NULL,
THIS_MODULE); THIS_MODULE);
if (!nls) { if (!nls) {
err = -ENOBUFS; err = -ENOBUFS;
......
...@@ -229,7 +229,7 @@ int ecryptfs_init_netlink(void) ...@@ -229,7 +229,7 @@ int ecryptfs_init_netlink(void)
ecryptfs_nl_sock = netlink_kernel_create(NETLINK_ECRYPTFS, 0, ecryptfs_nl_sock = netlink_kernel_create(NETLINK_ECRYPTFS, 0,
ecryptfs_receive_nl_message, ecryptfs_receive_nl_message,
THIS_MODULE); NULL, THIS_MODULE);
if (!ecryptfs_nl_sock) { if (!ecryptfs_nl_sock) {
rc = -EIO; rc = -EIO;
ecryptfs_printk(KERN_ERR, "Failed to create netlink socket\n"); ecryptfs_printk(KERN_ERR, "Failed to create netlink socket\n");
......
...@@ -157,7 +157,10 @@ struct netlink_skb_parms ...@@ -157,7 +157,10 @@ struct netlink_skb_parms
#define NETLINK_CREDS(skb) (&NETLINK_CB((skb)).creds) #define NETLINK_CREDS(skb) (&NETLINK_CB((skb)).creds)
extern struct sock *netlink_kernel_create(int unit, unsigned int groups, void (*input)(struct sock *sk, int len), struct module *module); extern struct sock *netlink_kernel_create(int unit, unsigned int groups,
void (*input)(struct sock *sk, int len),
struct mutex *cb_mutex,
struct module *module);
extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err); extern void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err);
extern int netlink_has_listeners(struct sock *sk, unsigned int group); extern int netlink_has_listeners(struct sock *sk, unsigned int group);
extern int netlink_unicast(struct sock *ssk, struct sk_buff *skb, __u32 pid, int nonblock); extern int netlink_unicast(struct sock *ssk, struct sk_buff *skb, __u32 pid, int nonblock);
......
...@@ -795,7 +795,7 @@ static int __init audit_init(void) ...@@ -795,7 +795,7 @@ static int __init audit_init(void)
printk(KERN_INFO "audit: initializing netlink socket (%s)\n", printk(KERN_INFO "audit: initializing netlink socket (%s)\n",
audit_default ? "enabled" : "disabled"); audit_default ? "enabled" : "disabled");
audit_sock = netlink_kernel_create(NETLINK_AUDIT, 0, audit_receive, audit_sock = netlink_kernel_create(NETLINK_AUDIT, 0, audit_receive,
THIS_MODULE); NULL, THIS_MODULE);
if (!audit_sock) if (!audit_sock)
audit_panic("cannot initialize netlink socket"); audit_panic("cannot initialize netlink socket");
else else
......
...@@ -293,7 +293,7 @@ EXPORT_SYMBOL_GPL(add_uevent_var); ...@@ -293,7 +293,7 @@ EXPORT_SYMBOL_GPL(add_uevent_var);
static int __init kobject_uevent_init(void) static int __init kobject_uevent_init(void)
{ {
uevent_sock = netlink_kernel_create(NETLINK_KOBJECT_UEVENT, 1, NULL, uevent_sock = netlink_kernel_create(NETLINK_KOBJECT_UEVENT, 1, NULL,
THIS_MODULE); NULL, THIS_MODULE);
if (!uevent_sock) { if (!uevent_sock) {
printk(KERN_ERR printk(KERN_ERR
......
...@@ -302,7 +302,7 @@ static int __init ebt_ulog_init(void) ...@@ -302,7 +302,7 @@ static int __init ebt_ulog_init(void)
} }
ebtulognl = netlink_kernel_create(NETLINK_NFLOG, EBT_ULOG_MAXNLGROUPS, ebtulognl = netlink_kernel_create(NETLINK_NFLOG, EBT_ULOG_MAXNLGROUPS,
NULL, THIS_MODULE); NULL, NULL, THIS_MODULE);
if (!ebtulognl) if (!ebtulognl)
ret = -ENOMEM; ret = -ENOMEM;
else if ((ret = ebt_register_watcher(&ulog))) else if ((ret = ebt_register_watcher(&ulog)))
......
...@@ -972,7 +972,7 @@ void __init rtnetlink_init(void) ...@@ -972,7 +972,7 @@ void __init rtnetlink_init(void)
panic("rtnetlink_init: cannot allocate rta_buf\n"); panic("rtnetlink_init: cannot allocate rta_buf\n");
rtnl = netlink_kernel_create(NETLINK_ROUTE, RTNLGRP_MAX, rtnetlink_rcv, rtnl = netlink_kernel_create(NETLINK_ROUTE, RTNLGRP_MAX, rtnetlink_rcv,
THIS_MODULE); NULL, THIS_MODULE);
if (rtnl == NULL) if (rtnl == NULL)
panic("rtnetlink_init: cannot initialize rtnetlink\n"); panic("rtnetlink_init: cannot initialize rtnetlink\n");
netlink_set_nonroot(NETLINK_ROUTE, NL_NONROOT_RECV); netlink_set_nonroot(NETLINK_ROUTE, NL_NONROOT_RECV);
......
...@@ -138,7 +138,7 @@ static int __init dn_rtmsg_init(void) ...@@ -138,7 +138,7 @@ static int __init dn_rtmsg_init(void)
int rv = 0; int rv = 0;
dnrmg = netlink_kernel_create(NETLINK_DNRTMSG, DNRNG_NLGRP_MAX, dnrmg = netlink_kernel_create(NETLINK_DNRTMSG, DNRNG_NLGRP_MAX,
dnrmg_receive_user_sk, THIS_MODULE); dnrmg_receive_user_sk, NULL, THIS_MODULE);
if (dnrmg == NULL) { if (dnrmg == NULL) {
printk(KERN_ERR "dn_rtmsg: Cannot create netlink socket"); printk(KERN_ERR "dn_rtmsg: Cannot create netlink socket");
return -ENOMEM; return -ENOMEM;
......
...@@ -827,7 +827,8 @@ static void nl_fib_input(struct sock *sk, int len) ...@@ -827,7 +827,8 @@ static void nl_fib_input(struct sock *sk, int len)
static void nl_fib_lookup_init(void) static void nl_fib_lookup_init(void)
{ {
netlink_kernel_create(NETLINK_FIB_LOOKUP, 0, nl_fib_input, THIS_MODULE); netlink_kernel_create(NETLINK_FIB_LOOKUP, 0, nl_fib_input, NULL,
THIS_MODULE);
} }
static void fib_disable_ip(struct net_device *dev, int force) static void fib_disable_ip(struct net_device *dev, int force)
......
...@@ -893,7 +893,7 @@ static int __init inet_diag_init(void) ...@@ -893,7 +893,7 @@ static int __init inet_diag_init(void)
goto out; goto out;
idiagnl = netlink_kernel_create(NETLINK_INET_DIAG, 0, inet_diag_rcv, idiagnl = netlink_kernel_create(NETLINK_INET_DIAG, 0, inet_diag_rcv,
THIS_MODULE); NULL, THIS_MODULE);
if (idiagnl == NULL) if (idiagnl == NULL)
goto out_free_table; goto out_free_table;
err = 0; err = 0;
......
...@@ -668,7 +668,7 @@ static int __init ip_queue_init(void) ...@@ -668,7 +668,7 @@ static int __init ip_queue_init(void)
netlink_register_notifier(&ipq_nl_notifier); netlink_register_notifier(&ipq_nl_notifier);
ipqnl = netlink_kernel_create(NETLINK_FIREWALL, 0, ipq_rcv_sk, ipqnl = netlink_kernel_create(NETLINK_FIREWALL, 0, ipq_rcv_sk,
THIS_MODULE); NULL, THIS_MODULE);
if (ipqnl == NULL) { if (ipqnl == NULL) {
printk(KERN_ERR "ip_queue: failed to create netlink socket\n"); printk(KERN_ERR "ip_queue: failed to create netlink socket\n");
goto cleanup_netlink_notifier; goto cleanup_netlink_notifier;
......
...@@ -420,7 +420,7 @@ static int __init ipt_ulog_init(void) ...@@ -420,7 +420,7 @@ static int __init ipt_ulog_init(void)
setup_timer(&ulog_buffers[i].timer, ulog_timer, i); setup_timer(&ulog_buffers[i].timer, ulog_timer, i);
nflognl = netlink_kernel_create(NETLINK_NFLOG, ULOG_MAXNLGROUPS, NULL, nflognl = netlink_kernel_create(NETLINK_NFLOG, ULOG_MAXNLGROUPS, NULL,
THIS_MODULE); NULL, THIS_MODULE);
if (!nflognl) if (!nflognl)
return -ENOMEM; return -ENOMEM;
......
...@@ -657,7 +657,7 @@ static int __init ip6_queue_init(void) ...@@ -657,7 +657,7 @@ static int __init ip6_queue_init(void)
struct proc_dir_entry *proc; struct proc_dir_entry *proc;
netlink_register_notifier(&ipq_nl_notifier); netlink_register_notifier(&ipq_nl_notifier);
ipqnl = netlink_kernel_create(NETLINK_IP6_FW, 0, ipq_rcv_sk, ipqnl = netlink_kernel_create(NETLINK_IP6_FW, 0, ipq_rcv_sk, NULL,
THIS_MODULE); THIS_MODULE);
if (ipqnl == NULL) { if (ipqnl == NULL) {
printk(KERN_ERR "ip6_queue: failed to create netlink socket\n"); printk(KERN_ERR "ip6_queue: failed to create netlink socket\n");
......
...@@ -265,7 +265,7 @@ static int __init nfnetlink_init(void) ...@@ -265,7 +265,7 @@ static int __init nfnetlink_init(void)
printk("Netfilter messages via NETLINK v%s.\n", nfversion); printk("Netfilter messages via NETLINK v%s.\n", nfversion);
nfnl = netlink_kernel_create(NETLINK_NETFILTER, NFNLGRP_MAX, nfnl = netlink_kernel_create(NETLINK_NETFILTER, NFNLGRP_MAX,
nfnetlink_rcv, THIS_MODULE); nfnetlink_rcv, NULL, THIS_MODULE);
if (!nfnl) { if (!nfnl) {
printk(KERN_ERR "cannot initialize nfnetlink!\n"); printk(KERN_ERR "cannot initialize nfnetlink!\n");
return -1; return -1;
......
...@@ -56,6 +56,7 @@ ...@@ -56,6 +56,7 @@
#include <linux/types.h> #include <linux/types.h>
#include <linux/audit.h> #include <linux/audit.h>
#include <linux/selinux.h> #include <linux/selinux.h>
#include <linux/mutex.h>
#include <net/sock.h> #include <net/sock.h>
#include <net/scm.h> #include <net/scm.h>
...@@ -76,7 +77,8 @@ struct netlink_sock { ...@@ -76,7 +77,8 @@ struct netlink_sock {
unsigned long state; unsigned long state;
wait_queue_head_t wait; wait_queue_head_t wait;
struct netlink_callback *cb; struct netlink_callback *cb;
spinlock_t cb_lock; struct mutex *cb_mutex;
struct mutex cb_def_mutex;
void (*data_ready)(struct sock *sk, int bytes); void (*data_ready)(struct sock *sk, int bytes);
struct module *module; struct module *module;
}; };
...@@ -108,6 +110,7 @@ struct netlink_table { ...@@ -108,6 +110,7 @@ struct netlink_table {
unsigned long *listeners; unsigned long *listeners;
unsigned int nl_nonroot; unsigned int nl_nonroot;
unsigned int groups; unsigned int groups;
struct mutex *cb_mutex;
struct module *module; struct module *module;
int registered; int registered;
}; };
...@@ -370,7 +373,8 @@ static struct proto netlink_proto = { ...@@ -370,7 +373,8 @@ static struct proto netlink_proto = {
.obj_size = sizeof(struct netlink_sock), .obj_size = sizeof(struct netlink_sock),
}; };
static int __netlink_create(struct socket *sock, int protocol) static int __netlink_create(struct socket *sock, struct mutex *cb_mutex,
int protocol)
{ {
struct sock *sk; struct sock *sk;
struct netlink_sock *nlk; struct netlink_sock *nlk;
...@@ -384,7 +388,8 @@ static int __netlink_create(struct socket *sock, int protocol) ...@@ -384,7 +388,8 @@ static int __netlink_create(struct socket *sock, int protocol)
sock_init_data(sock, sk); sock_init_data(sock, sk);
nlk = nlk_sk(sk); nlk = nlk_sk(sk);
spin_lock_init(&nlk->cb_lock); nlk->cb_mutex = cb_mutex ? : &nlk->cb_def_mutex;
mutex_init(nlk->cb_mutex);
init_waitqueue_head(&nlk->wait); init_waitqueue_head(&nlk->wait);
sk->sk_destruct = netlink_sock_destruct; sk->sk_destruct = netlink_sock_destruct;
...@@ -395,6 +400,7 @@ static int __netlink_create(struct socket *sock, int protocol) ...@@ -395,6 +400,7 @@ static int __netlink_create(struct socket *sock, int protocol)
static int netlink_create(struct socket *sock, int protocol) static int netlink_create(struct socket *sock, int protocol)
{ {
struct module *module = NULL; struct module *module = NULL;
struct mutex *cb_mutex;
struct netlink_sock *nlk; struct netlink_sock *nlk;
int err = 0; int err = 0;
...@@ -417,9 +423,10 @@ static int netlink_create(struct socket *sock, int protocol) ...@@ -417,9 +423,10 @@ static int netlink_create(struct socket *sock, int protocol)
if (nl_table[protocol].registered && if (nl_table[protocol].registered &&
try_module_get(nl_table[protocol].module)) try_module_get(nl_table[protocol].module))
module = nl_table[protocol].module; module = nl_table[protocol].module;
cb_mutex = nl_table[protocol].cb_mutex;
netlink_unlock_table(); netlink_unlock_table();
if ((err = __netlink_create(sock, protocol)) < 0) if ((err = __netlink_create(sock, cb_mutex, protocol)) < 0)
goto out_module; goto out_module;
nlk = nlk_sk(sock->sk); nlk = nlk_sk(sock->sk);
...@@ -444,14 +451,14 @@ static int netlink_release(struct socket *sock) ...@@ -444,14 +451,14 @@ static int netlink_release(struct socket *sock)
sock_orphan(sk); sock_orphan(sk);
nlk = nlk_sk(sk); nlk = nlk_sk(sk);
spin_lock(&nlk->cb_lock); mutex_lock(nlk->cb_mutex);
if (nlk->cb) { if (nlk->cb) {
if (nlk->cb->done) if (nlk->cb->done)
nlk->cb->done(nlk->cb); nlk->cb->done(nlk->cb);
netlink_destroy_callback(nlk->cb); netlink_destroy_callback(nlk->cb);
nlk->cb = NULL; nlk->cb = NULL;
} }
spin_unlock(&nlk->cb_lock); mutex_unlock(nlk->cb_mutex);
/* OK. Socket is unlinked, and, therefore, /* OK. Socket is unlinked, and, therefore,
no new packets will arrive */ no new packets will arrive */
...@@ -1266,7 +1273,7 @@ static void netlink_data_ready(struct sock *sk, int len) ...@@ -1266,7 +1273,7 @@ static void netlink_data_ready(struct sock *sk, int len)
struct sock * struct sock *
netlink_kernel_create(int unit, unsigned int groups, netlink_kernel_create(int unit, unsigned int groups,
void (*input)(struct sock *sk, int len), void (*input)(struct sock *sk, int len),
struct module *module) struct mutex *cb_mutex, struct module *module)
{ {
struct socket *sock; struct socket *sock;
struct sock *sk; struct sock *sk;
...@@ -1281,7 +1288,7 @@ netlink_kernel_create(int unit, unsigned int groups, ...@@ -1281,7 +1288,7 @@ netlink_kernel_create(int unit, unsigned int groups,
if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock)) if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock))
return NULL; return NULL;
if (__netlink_create(sock, unit) < 0) if (__netlink_create(sock, cb_mutex, unit) < 0)
goto out_sock_release; goto out_sock_release;
if (groups < 32) if (groups < 32)
...@@ -1305,6 +1312,7 @@ netlink_kernel_create(int unit, unsigned int groups, ...@@ -1305,6 +1312,7 @@ netlink_kernel_create(int unit, unsigned int groups,
netlink_table_grab(); netlink_table_grab();
nl_table[unit].groups = groups; nl_table[unit].groups = groups;
nl_table[unit].listeners = listeners; nl_table[unit].listeners = listeners;
nl_table[unit].cb_mutex = cb_mutex;
nl_table[unit].module = module; nl_table[unit].module = module;
nl_table[unit].registered = 1; nl_table[unit].registered = 1;
netlink_table_ungrab(); netlink_table_ungrab();
...@@ -1347,7 +1355,7 @@ static int netlink_dump(struct sock *sk) ...@@ -1347,7 +1355,7 @@ static int netlink_dump(struct sock *sk)
if (!skb) if (!skb)
goto errout; goto errout;
spin_lock(&nlk->cb_lock); mutex_lock(nlk->cb_mutex);
cb = nlk->cb; cb = nlk->cb;
if (cb == NULL) { if (cb == NULL) {
...@@ -1358,7 +1366,7 @@ static int netlink_dump(struct sock *sk) ...@@ -1358,7 +1366,7 @@ static int netlink_dump(struct sock *sk)
len = cb->dump(skb, cb); len = cb->dump(skb, cb);
if (len > 0) { if (len > 0) {
spin_unlock(&nlk->cb_lock); mutex_unlock(nlk->cb_mutex);
skb_queue_tail(&sk->sk_receive_queue, skb); skb_queue_tail(&sk->sk_receive_queue, skb);
sk->sk_data_ready(sk, len); sk->sk_data_ready(sk, len);
return 0; return 0;
...@@ -1376,13 +1384,13 @@ static int netlink_dump(struct sock *sk) ...@@ -1376,13 +1384,13 @@ static int netlink_dump(struct sock *sk)
if (cb->done) if (cb->done)
cb->done(cb); cb->done(cb);
nlk->cb = NULL; nlk->cb = NULL;
spin_unlock(&nlk->cb_lock); mutex_unlock(nlk->cb_mutex);
netlink_destroy_callback(cb); netlink_destroy_callback(cb);
return 0; return 0;
errout_skb: errout_skb:
spin_unlock(&nlk->cb_lock); mutex_unlock(nlk->cb_mutex);
kfree_skb(skb); kfree_skb(skb);
errout: errout:
return err; return err;
...@@ -1414,15 +1422,15 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb, ...@@ -1414,15 +1422,15 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
} }
nlk = nlk_sk(sk); nlk = nlk_sk(sk);
/* A dump or destruction is in progress... */ /* A dump or destruction is in progress... */
spin_lock(&nlk->cb_lock); mutex_lock(nlk->cb_mutex);
if (nlk->cb || sock_flag(sk, SOCK_DEAD)) { if (nlk->cb || sock_flag(sk, SOCK_DEAD)) {
spin_unlock(&nlk->cb_lock); mutex_unlock(nlk->cb_mutex);
netlink_destroy_callback(cb); netlink_destroy_callback(cb);
sock_put(sk); sock_put(sk);
return -EBUSY; return -EBUSY;
} }
nlk->cb = cb; nlk->cb = cb;
spin_unlock(&nlk->cb_lock); mutex_unlock(nlk->cb_mutex);
netlink_dump(sk); netlink_dump(sk);
sock_put(sk); sock_put(sk);
......
...@@ -558,7 +558,7 @@ static int __init genl_init(void) ...@@ -558,7 +558,7 @@ static int __init genl_init(void)
netlink_set_nonroot(NETLINK_GENERIC, NL_NONROOT_RECV); netlink_set_nonroot(NETLINK_GENERIC, NL_NONROOT_RECV);
genl_sock = netlink_kernel_create(NETLINK_GENERIC, GENL_MAX_ID, genl_sock = netlink_kernel_create(NETLINK_GENERIC, GENL_MAX_ID,
genl_rcv, THIS_MODULE); genl_rcv, NULL, THIS_MODULE);
if (genl_sock == NULL) if (genl_sock == NULL)
panic("GENL: Cannot initialize generic netlink\n"); panic("GENL: Cannot initialize generic netlink\n");
......
...@@ -2444,7 +2444,7 @@ static int __init xfrm_user_init(void) ...@@ -2444,7 +2444,7 @@ static int __init xfrm_user_init(void)
printk(KERN_INFO "Initializing XFRM netlink socket\n"); printk(KERN_INFO "Initializing XFRM netlink socket\n");
nlsk = netlink_kernel_create(NETLINK_XFRM, XFRMNLGRP_MAX, nlsk = netlink_kernel_create(NETLINK_XFRM, XFRMNLGRP_MAX,
xfrm_netlink_rcv, THIS_MODULE); xfrm_netlink_rcv, NULL, THIS_MODULE);
if (nlsk == NULL) if (nlsk == NULL)
return -ENOMEM; return -ENOMEM;
rcu_assign_pointer(xfrm_nl, nlsk); rcu_assign_pointer(xfrm_nl, nlsk);
......
...@@ -104,7 +104,7 @@ void selnl_notify_policyload(u32 seqno) ...@@ -104,7 +104,7 @@ void selnl_notify_policyload(u32 seqno)
static int __init selnl_init(void) static int __init selnl_init(void)
{ {
selnl = netlink_kernel_create(NETLINK_SELINUX, SELNLGRP_MAX, NULL, selnl = netlink_kernel_create(NETLINK_SELINUX, SELNLGRP_MAX, NULL, NULL,
THIS_MODULE); THIS_MODULE);
if (selnl == NULL) if (selnl == NULL)
panic("SELinux: Cannot create netlink socket."); panic("SELinux: Cannot create netlink socket.");
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment