Commit 023e2cfa authored by Johannes Berg's avatar Johannes Berg Committed by David S. Miller

netlink/genetlink: pass network namespace to bind/unbind

Netlink families can exist in multiple namespaces, and for the most
part multicast subscriptions are per network namespace. Thus it only
makes sense to have bind/unbind notifications per network namespace.

To achieve this, pass the network namespace of a given client socket
to the bind/unbind functions.

Also do this in generic netlink, and there also make sure that any
bind for multicast groups that only exist in init_net is rejected.
This isn't really a problem if it is accepted since a client in a
different namespace will never receive any notifications from such
a group, but it can confuse the family if not rejected (it's also
possible to silently (without telling the family) accept it, but it
would also have to be ignored on unbind so families that take any
kind of action on bind/unbind won't do unnecessary work for invalid
clients like that.
Signed-off-by: default avatarJohannes Berg <johannes.berg@intel.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent eb69c5bf
...@@ -46,8 +46,8 @@ struct netlink_kernel_cfg { ...@@ -46,8 +46,8 @@ struct netlink_kernel_cfg {
unsigned int flags; unsigned int flags;
void (*input)(struct sk_buff *skb); void (*input)(struct sk_buff *skb);
struct mutex *cb_mutex; struct mutex *cb_mutex;
int (*bind)(int group); int (*bind)(struct net *net, int group);
void (*unbind)(int group); void (*unbind)(struct net *net, int group);
bool (*compare)(struct net *net, struct sock *sk); bool (*compare)(struct net *net, struct sock *sk);
}; };
......
...@@ -56,8 +56,8 @@ struct genl_family { ...@@ -56,8 +56,8 @@ struct genl_family {
void (*post_doit)(const struct genl_ops *ops, void (*post_doit)(const struct genl_ops *ops,
struct sk_buff *skb, struct sk_buff *skb,
struct genl_info *info); struct genl_info *info);
int (*mcast_bind)(int group); int (*mcast_bind)(struct net *net, int group);
void (*mcast_unbind)(int group); void (*mcast_unbind)(struct net *net, int group);
struct nlattr ** attrbuf; /* private */ struct nlattr ** attrbuf; /* private */
const struct genl_ops * ops; /* private */ const struct genl_ops * ops; /* private */
const struct genl_multicast_group *mcgrps; /* private */ const struct genl_multicast_group *mcgrps; /* private */
......
...@@ -1100,7 +1100,7 @@ static void audit_receive(struct sk_buff *skb) ...@@ -1100,7 +1100,7 @@ static void audit_receive(struct sk_buff *skb)
} }
/* Run custom bind function on netlink socket group connect or bind requests. */ /* Run custom bind function on netlink socket group connect or bind requests. */
static int audit_bind(int group) static int audit_bind(struct net *net, int group)
{ {
if (!capable(CAP_AUDIT_READ)) if (!capable(CAP_AUDIT_READ))
return -EPERM; return -EPERM;
......
...@@ -463,7 +463,7 @@ static void nfnetlink_rcv(struct sk_buff *skb) ...@@ -463,7 +463,7 @@ static void nfnetlink_rcv(struct sk_buff *skb)
} }
#ifdef CONFIG_MODULES #ifdef CONFIG_MODULES
static int nfnetlink_bind(int group) static int nfnetlink_bind(struct net *net, int group)
{ {
const struct nfnetlink_subsystem *ss; const struct nfnetlink_subsystem *ss;
int type; int type;
......
...@@ -1141,8 +1141,8 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, ...@@ -1141,8 +1141,8 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
struct module *module = NULL; struct module *module = NULL;
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct netlink_sock *nlk; struct netlink_sock *nlk;
int (*bind)(int group); int (*bind)(struct net *net, int group);
void (*unbind)(int group); void (*unbind)(struct net *net, int group);
int err = 0; int err = 0;
sock->state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED;
...@@ -1251,7 +1251,7 @@ static int netlink_release(struct socket *sock) ...@@ -1251,7 +1251,7 @@ static int netlink_release(struct socket *sock)
for (i = 0; i < nlk->ngroups; i++) for (i = 0; i < nlk->ngroups; i++)
if (test_bit(i, nlk->groups)) if (test_bit(i, nlk->groups))
nlk->netlink_unbind(i + 1); nlk->netlink_unbind(sock_net(sk), i + 1);
} }
kfree(nlk->groups); kfree(nlk->groups);
nlk->groups = NULL; nlk->groups = NULL;
...@@ -1418,8 +1418,9 @@ static int netlink_realloc_groups(struct sock *sk) ...@@ -1418,8 +1418,9 @@ static int netlink_realloc_groups(struct sock *sk)
} }
static void netlink_undo_bind(int group, long unsigned int groups, static void netlink_undo_bind(int group, long unsigned int groups,
struct netlink_sock *nlk) struct sock *sk)
{ {
struct netlink_sock *nlk = nlk_sk(sk);
int undo; int undo;
if (!nlk->netlink_unbind) if (!nlk->netlink_unbind)
...@@ -1427,7 +1428,7 @@ static void netlink_undo_bind(int group, long unsigned int groups, ...@@ -1427,7 +1428,7 @@ static void netlink_undo_bind(int group, long unsigned int groups,
for (undo = 0; undo < group; undo++) for (undo = 0; undo < group; undo++)
if (test_bit(undo, &groups)) if (test_bit(undo, &groups))
nlk->netlink_unbind(undo); nlk->netlink_unbind(sock_net(sk), undo);
} }
static int netlink_bind(struct socket *sock, struct sockaddr *addr, static int netlink_bind(struct socket *sock, struct sockaddr *addr,
...@@ -1465,10 +1466,10 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, ...@@ -1465,10 +1466,10 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
for (group = 0; group < nlk->ngroups; group++) { for (group = 0; group < nlk->ngroups; group++) {
if (!test_bit(group, &groups)) if (!test_bit(group, &groups))
continue; continue;
err = nlk->netlink_bind(group); err = nlk->netlink_bind(net, group);
if (!err) if (!err)
continue; continue;
netlink_undo_bind(group, groups, nlk); netlink_undo_bind(group, groups, sk);
return err; return err;
} }
} }
...@@ -1478,7 +1479,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, ...@@ -1478,7 +1479,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
netlink_insert(sk, net, nladdr->nl_pid) : netlink_insert(sk, net, nladdr->nl_pid) :
netlink_autobind(sock); netlink_autobind(sock);
if (err) { if (err) {
netlink_undo_bind(nlk->ngroups, groups, nlk); netlink_undo_bind(nlk->ngroups, groups, sk);
return err; return err;
} }
} }
...@@ -2129,7 +2130,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, ...@@ -2129,7 +2130,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
if (!val || val - 1 >= nlk->ngroups) if (!val || val - 1 >= nlk->ngroups)
return -EINVAL; return -EINVAL;
if (optname == NETLINK_ADD_MEMBERSHIP && nlk->netlink_bind) { if (optname == NETLINK_ADD_MEMBERSHIP && nlk->netlink_bind) {
err = nlk->netlink_bind(val); err = nlk->netlink_bind(sock_net(sk), val);
if (err) if (err)
return err; return err;
} }
...@@ -2138,7 +2139,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, ...@@ -2138,7 +2139,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
optname == NETLINK_ADD_MEMBERSHIP); optname == NETLINK_ADD_MEMBERSHIP);
netlink_table_ungrab(); netlink_table_ungrab();
if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind) if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind)
nlk->netlink_unbind(val); nlk->netlink_unbind(sock_net(sk), val);
err = 0; err = 0;
break; break;
......
...@@ -39,8 +39,8 @@ struct netlink_sock { ...@@ -39,8 +39,8 @@ struct netlink_sock {
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct mutex cb_def_mutex; struct mutex cb_def_mutex;
void (*netlink_rcv)(struct sk_buff *skb); void (*netlink_rcv)(struct sk_buff *skb);
int (*netlink_bind)(int group); int (*netlink_bind)(struct net *net, int group);
void (*netlink_unbind)(int group); void (*netlink_unbind)(struct net *net, int group);
struct module *module; struct module *module;
#ifdef CONFIG_NETLINK_MMAP #ifdef CONFIG_NETLINK_MMAP
struct mutex pg_vec_lock; struct mutex pg_vec_lock;
...@@ -65,8 +65,8 @@ struct netlink_table { ...@@ -65,8 +65,8 @@ struct netlink_table {
unsigned int groups; unsigned int groups;
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct module *module; struct module *module;
int (*bind)(int group); int (*bind)(struct net *net, int group);
void (*unbind)(int group); void (*unbind)(struct net *net, int group);
bool (*compare)(struct net *net, struct sock *sock); bool (*compare)(struct net *net, struct sock *sock);
int registered; int registered;
}; };
......
...@@ -983,7 +983,7 @@ static struct genl_multicast_group genl_ctrl_groups[] = { ...@@ -983,7 +983,7 @@ static struct genl_multicast_group genl_ctrl_groups[] = {
{ .name = "notify", }, { .name = "notify", },
}; };
static int genl_bind(int group) static int genl_bind(struct net *net, int group)
{ {
int i, err; int i, err;
bool found = false; bool found = false;
...@@ -997,8 +997,10 @@ static int genl_bind(int group) ...@@ -997,8 +997,10 @@ static int genl_bind(int group)
group < f->mcgrp_offset + f->n_mcgrps) { group < f->mcgrp_offset + f->n_mcgrps) {
int fam_grp = group - f->mcgrp_offset; int fam_grp = group - f->mcgrp_offset;
if (f->mcast_bind) if (!f->netnsok && net != &init_net)
err = f->mcast_bind(fam_grp); err = -ENOENT;
else if (f->mcast_bind)
err = f->mcast_bind(net, fam_grp);
else else
err = 0; err = 0;
found = true; found = true;
...@@ -1014,7 +1016,7 @@ static int genl_bind(int group) ...@@ -1014,7 +1016,7 @@ static int genl_bind(int group)
return err; return err;
} }
static void genl_unbind(int group) static void genl_unbind(struct net *net, int group)
{ {
int i; int i;
bool found = false; bool found = false;
...@@ -1029,7 +1031,7 @@ static void genl_unbind(int group) ...@@ -1029,7 +1031,7 @@ static void genl_unbind(int group)
int fam_grp = group - f->mcgrp_offset; int fam_grp = group - f->mcgrp_offset;
if (f->mcast_unbind) if (f->mcast_unbind)
f->mcast_unbind(fam_grp); f->mcast_unbind(net, fam_grp);
found = true; found = true;
break; break;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment