Commit 77247bbb authored by Patrick McHardy's avatar Patrick McHardy Committed by David S. Miller

[NETLINK]: Fix module refcounting problems

Use-after-free: the struct proto_ops containing the module pointer
is freed when a socket with pid=0 is released, which besides for kernel
sockets is true for all unbound sockets.

Module refcount leak: when the kernel socket is closed before all user
sockets have been closed the proto_ops struct for this family is
replaced by the generic one and the module refcount can't be dropped.

The second problem can't be solved cleanly using module refcounting in the
generic socket code, so this patch adds explicit refcounting to
netlink_create/netlink_release.
Signed-off-by: default avatarPatrick McHardy <kaber@trash.net>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent db080529
...@@ -73,8 +73,12 @@ struct netlink_sock { ...@@ -73,8 +73,12 @@ struct netlink_sock {
struct netlink_callback *cb; struct netlink_callback *cb;
spinlock_t cb_lock; spinlock_t cb_lock;
void (*data_ready)(struct sock *sk, int bytes); void (*data_ready)(struct sock *sk, int bytes);
struct module *module;
u32 flags;
}; };
#define NETLINK_KERNEL_SOCKET 0x1
static inline struct netlink_sock *nlk_sk(struct sock *sk) static inline struct netlink_sock *nlk_sk(struct sock *sk)
{ {
return (struct netlink_sock *)sk; return (struct netlink_sock *)sk;
...@@ -97,7 +101,7 @@ struct netlink_table { ...@@ -97,7 +101,7 @@ struct netlink_table {
struct nl_pid_hash hash; struct nl_pid_hash hash;
struct hlist_head mc_list; struct hlist_head mc_list;
unsigned int nl_nonroot; unsigned int nl_nonroot;
struct proto_ops *p_ops; struct module *module;
}; };
static struct netlink_table *nl_table; static struct netlink_table *nl_table;
...@@ -338,6 +342,7 @@ static int netlink_create(struct socket *sock, int protocol) ...@@ -338,6 +342,7 @@ static int netlink_create(struct socket *sock, int protocol)
{ {
struct sock *sk; struct sock *sk;
struct netlink_sock *nlk; struct netlink_sock *nlk;
struct module *module;
sock->state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED;
...@@ -347,30 +352,36 @@ static int netlink_create(struct socket *sock, int protocol) ...@@ -347,30 +352,36 @@ static int netlink_create(struct socket *sock, int protocol)
if (protocol<0 || protocol >= MAX_LINKS) if (protocol<0 || protocol >= MAX_LINKS)
return -EPROTONOSUPPORT; return -EPROTONOSUPPORT;
netlink_table_grab(); netlink_lock_table();
if (!nl_table[protocol].hash.entries) { if (!nl_table[protocol].hash.entries) {
#ifdef CONFIG_KMOD #ifdef CONFIG_KMOD
/* We do 'best effort'. If we find a matching module, /* We do 'best effort'. If we find a matching module,
* it is loaded. If not, we don't return an error to * it is loaded. If not, we don't return an error to
* allow pure userspace<->userspace communication. -HW * allow pure userspace<->userspace communication. -HW
*/ */
netlink_table_ungrab(); netlink_unlock_table();
request_module("net-pf-%d-proto-%d", PF_NETLINK, protocol); request_module("net-pf-%d-proto-%d", PF_NETLINK, protocol);
netlink_table_grab(); netlink_lock_table();
#endif #endif
} }
netlink_table_ungrab(); module = nl_table[protocol].module;
if (!try_module_get(module))
module = NULL;
netlink_unlock_table();
sock->ops = nl_table[protocol].p_ops; sock->ops = &netlink_ops;
sk = sk_alloc(PF_NETLINK, GFP_KERNEL, &netlink_proto, 1); sk = sk_alloc(PF_NETLINK, GFP_KERNEL, &netlink_proto, 1);
if (!sk) if (!sk) {
module_put(module);
return -ENOMEM; return -ENOMEM;
}
sock_init_data(sock, sk); sock_init_data(sock, sk);
nlk = nlk_sk(sk); nlk = nlk_sk(sk);
nlk->module = module;
spin_lock_init(&nlk->cb_lock); spin_lock_init(&nlk->cb_lock);
init_waitqueue_head(&nlk->wait); init_waitqueue_head(&nlk->wait);
sk->sk_destruct = netlink_sock_destruct; sk->sk_destruct = netlink_sock_destruct;
...@@ -415,19 +426,12 @@ static int netlink_release(struct socket *sock) ...@@ -415,19 +426,12 @@ static int netlink_release(struct socket *sock)
notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n); notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n);
} }
/* When this is a kernel socket, we need to remove the owner pointer, if (nlk->module)
* since we don't know whether the module will be dying at any given module_put(nlk->module);
* point - HW
*/
if (!nlk->pid) {
struct proto_ops *p_tmp;
if (nlk->flags & NETLINK_KERNEL_SOCKET) {
netlink_table_grab(); netlink_table_grab();
p_tmp = nl_table[sk->sk_protocol].p_ops; nl_table[sk->sk_protocol].module = NULL;
if (p_tmp != &netlink_ops) {
nl_table[sk->sk_protocol].p_ops = &netlink_ops;
kfree(p_tmp);
}
netlink_table_ungrab(); netlink_table_ungrab();
} }
...@@ -1060,9 +1064,9 @@ static void netlink_data_ready(struct sock *sk, int len) ...@@ -1060,9 +1064,9 @@ static void netlink_data_ready(struct sock *sk, int len)
struct sock * struct sock *
netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct module *module) netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct module *module)
{ {
struct proto_ops *p_ops;
struct socket *sock; struct socket *sock;
struct sock *sk; struct sock *sk;
struct netlink_sock *nlk;
if (!nl_table) if (!nl_table)
return NULL; return NULL;
...@@ -1070,64 +1074,32 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct ...@@ -1070,64 +1074,32 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct
if (unit<0 || unit>=MAX_LINKS) if (unit<0 || unit>=MAX_LINKS)
return NULL; return NULL;
/* Do a quick check, to make us not go down to netlink_insert()
* if protocol already has kernel socket.
*/
sk = netlink_lookup(unit, 0);
if (unlikely(sk)) {
sock_put(sk);
return NULL;
}
if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock)) if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock))
return NULL; return NULL;
sk = NULL; if (netlink_create(sock, unit) < 0)
if (module) {
/* Every registering protocol implemented in a module needs
* it's own p_ops, since the socket code cannot deal with
* module refcounting otherwise. -HW
*/
p_ops = kmalloc(sizeof(*p_ops), GFP_KERNEL);
if (!p_ops)
goto out_sock_release; goto out_sock_release;
memcpy(p_ops, &netlink_ops, sizeof(*p_ops));
p_ops->owner = module;
} else
p_ops = &netlink_ops;
netlink_table_grab();
nl_table[unit].p_ops = p_ops;
netlink_table_ungrab();
if (netlink_create(sock, unit) < 0) {
sk = NULL;
goto out_kfree_p_ops;
}
sk = sock->sk; sk = sock->sk;
sk->sk_data_ready = netlink_data_ready; sk->sk_data_ready = netlink_data_ready;
if (input) if (input)
nlk_sk(sk)->data_ready = input; nlk_sk(sk)->data_ready = input;
if (netlink_insert(sk, 0)) { if (netlink_insert(sk, 0))
sk = NULL; goto out_sock_release;
goto out_kfree_p_ops;
}
return sk; nlk = nlk_sk(sk);
nlk->flags |= NETLINK_KERNEL_SOCKET;
out_kfree_p_ops:
netlink_table_grab(); netlink_table_grab();
if (nl_table[unit].p_ops != &netlink_ops) { nl_table[unit].module = module;
kfree(nl_table[unit].p_ops);
nl_table[unit].p_ops = &netlink_ops;
}
netlink_table_ungrab(); netlink_table_ungrab();
return sk;
out_sock_release: out_sock_release:
sock_release(sock); sock_release(sock);
return sk; return NULL;
} }
void netlink_set_nonroot(int protocol, unsigned int flags) void netlink_set_nonroot(int protocol, unsigned int flags)
...@@ -1490,8 +1462,6 @@ static int __init netlink_proto_init(void) ...@@ -1490,8 +1462,6 @@ static int __init netlink_proto_init(void)
for (i = 0; i < MAX_LINKS; i++) { for (i = 0; i < MAX_LINKS; i++) {
struct nl_pid_hash *hash = &nl_table[i].hash; struct nl_pid_hash *hash = &nl_table[i].hash;
nl_table[i].p_ops = &netlink_ops;
hash->table = nl_pid_hash_alloc(1 * sizeof(*hash->table)); hash->table = nl_pid_hash_alloc(1 * sizeof(*hash->table));
if (!hash->table) { if (!hash->table) {
while (i-- > 0) while (i-- > 0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment