Commit da12c90e authored by Gao feng's avatar Gao feng Committed by David S. Miller

netlink: Add compare function for netlink_table

As we know, netlink sockets are private resource of
net namespace, they can communicate with each other
only when they in the same net namespace. this works
well until we try to add namespace support for other
subsystems which use netlink.

Don't like ipv4 and route table.., it is not suited to
make these subsytems belong to net namespace, Such as
audit and crypto subsystems,they are more suitable to
user namespace.

So we must have the ability to make the netlink sockets
in same user namespace can communicate with each other.

This patch adds a new function pointer "compare" for
netlink_table, we can decide if the netlink sockets can
communicate with each other through this netlink_table
self-defined compare function.

The behavior isn't changed if we don't provide the compare
function for netlink_table.
Signed-off-by: default avatarGao feng <gaofeng@cn.fujitsu.com>
Acked-by: default avatarSerge E. Hallyn <serge.hallyn@ubuntu.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 8249152c
...@@ -46,6 +46,7 @@ struct netlink_kernel_cfg { ...@@ -46,6 +46,7 @@ struct netlink_kernel_cfg {
void (*input)(struct sk_buff *skb); void (*input)(struct sk_buff *skb);
struct mutex *cb_mutex; struct mutex *cb_mutex;
void (*bind)(int group); void (*bind)(int group);
bool (*compare)(struct net *net, struct sock *sk);
}; };
extern struct sock *__netlink_kernel_create(struct net *net, int unit, extern struct sock *__netlink_kernel_create(struct net *net, int unit,
......
...@@ -858,16 +858,23 @@ netlink_unlock_table(void) ...@@ -858,16 +858,23 @@ netlink_unlock_table(void)
wake_up(&nl_table_wait); wake_up(&nl_table_wait);
} }
static bool netlink_compare(struct net *net, struct sock *sk)
{
return net_eq(sock_net(sk), net);
}
static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
{ {
struct nl_portid_hash *hash = &nl_table[protocol].hash; struct netlink_table *table = &nl_table[protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head; struct hlist_head *head;
struct sock *sk; struct sock *sk;
read_lock(&nl_table_lock); read_lock(&nl_table_lock);
head = nl_portid_hashfn(hash, portid); head = nl_portid_hashfn(hash, portid);
sk_for_each(sk, head) { sk_for_each(sk, head) {
if (net_eq(sock_net(sk), net) && (nlk_sk(sk)->portid == portid)) { if (table->compare(net, sk) &&
(nlk_sk(sk)->portid == portid)) {
sock_hold(sk); sock_hold(sk);
goto found; goto found;
} }
...@@ -980,7 +987,8 @@ netlink_update_listeners(struct sock *sk) ...@@ -980,7 +987,8 @@ netlink_update_listeners(struct sock *sk)
static int netlink_insert(struct sock *sk, struct net *net, u32 portid) static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
{ {
struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash; struct netlink_table *table = &nl_table[sk->sk_protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head; struct hlist_head *head;
int err = -EADDRINUSE; int err = -EADDRINUSE;
struct sock *osk; struct sock *osk;
...@@ -990,7 +998,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) ...@@ -990,7 +998,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
head = nl_portid_hashfn(hash, portid); head = nl_portid_hashfn(hash, portid);
len = 0; len = 0;
sk_for_each(osk, head) { sk_for_each(osk, head) {
if (net_eq(sock_net(osk), net) && (nlk_sk(osk)->portid == portid)) if (table->compare(net, osk) &&
(nlk_sk(osk)->portid == portid))
break; break;
len++; len++;
} }
...@@ -1165,6 +1174,7 @@ static int netlink_release(struct socket *sock) ...@@ -1165,6 +1174,7 @@ static int netlink_release(struct socket *sock)
kfree_rcu(old, rcu); kfree_rcu(old, rcu);
nl_table[sk->sk_protocol].module = NULL; nl_table[sk->sk_protocol].module = NULL;
nl_table[sk->sk_protocol].bind = NULL; nl_table[sk->sk_protocol].bind = NULL;
nl_table[sk->sk_protocol].compare = NULL;
nl_table[sk->sk_protocol].flags = 0; nl_table[sk->sk_protocol].flags = 0;
nl_table[sk->sk_protocol].registered = 0; nl_table[sk->sk_protocol].registered = 0;
} }
...@@ -1187,7 +1197,8 @@ static int netlink_autobind(struct socket *sock) ...@@ -1187,7 +1197,8 @@ static int netlink_autobind(struct socket *sock)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash; struct netlink_table *table = &nl_table[sk->sk_protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head; struct hlist_head *head;
struct sock *osk; struct sock *osk;
s32 portid = task_tgid_vnr(current); s32 portid = task_tgid_vnr(current);
...@@ -1199,7 +1210,7 @@ static int netlink_autobind(struct socket *sock) ...@@ -1199,7 +1210,7 @@ static int netlink_autobind(struct socket *sock)
netlink_table_grab(); netlink_table_grab();
head = nl_portid_hashfn(hash, portid); head = nl_portid_hashfn(hash, portid);
sk_for_each(osk, head) { sk_for_each(osk, head) {
if (!net_eq(sock_net(osk), net)) if (!table->compare(net, osk))
continue; continue;
if (nlk_sk(osk)->portid == portid) { if (nlk_sk(osk)->portid == portid) {
/* Bind collision, search negative portid values. */ /* Bind collision, search negative portid values. */
...@@ -2315,9 +2326,12 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module, ...@@ -2315,9 +2326,12 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
rcu_assign_pointer(nl_table[unit].listeners, listeners); rcu_assign_pointer(nl_table[unit].listeners, listeners);
nl_table[unit].cb_mutex = cb_mutex; nl_table[unit].cb_mutex = cb_mutex;
nl_table[unit].module = module; nl_table[unit].module = module;
nl_table[unit].compare = netlink_compare;
if (cfg) { if (cfg) {
nl_table[unit].bind = cfg->bind; nl_table[unit].bind = cfg->bind;
nl_table[unit].flags = cfg->flags; nl_table[unit].flags = cfg->flags;
if (cfg->compare)
nl_table[unit].compare = cfg->compare;
} }
nl_table[unit].registered = 1; nl_table[unit].registered = 1;
} else { } else {
...@@ -2740,6 +2754,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2740,6 +2754,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{ {
struct sock *s; struct sock *s;
struct nl_seq_iter *iter; struct nl_seq_iter *iter;
struct net *net;
int i, j; int i, j;
++*pos; ++*pos;
...@@ -2747,11 +2762,12 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2747,11 +2762,12 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
if (v == SEQ_START_TOKEN) if (v == SEQ_START_TOKEN)
return netlink_seq_socket_idx(seq, 0); return netlink_seq_socket_idx(seq, 0);
net = seq_file_net(seq);
iter = seq->private; iter = seq->private;
s = v; s = v;
do { do {
s = sk_next(s); s = sk_next(s);
} while (s && sock_net(s) != seq_file_net(seq)); } while (s && !nl_table[s->sk_protocol].compare(net, s));
if (s) if (s)
return s; return s;
...@@ -2763,7 +2779,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2763,7 +2779,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
for (; j <= hash->mask; j++) { for (; j <= hash->mask; j++) {
s = sk_head(&hash->table[j]); s = sk_head(&hash->table[j]);
while (s && sock_net(s) != seq_file_net(seq))
while (s && !nl_table[s->sk_protocol].compare(net, s))
s = sk_next(s); s = sk_next(s);
if (s) { if (s) {
iter->link = i; iter->link = i;
......
...@@ -73,6 +73,7 @@ struct netlink_table { ...@@ -73,6 +73,7 @@ struct netlink_table {
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct module *module; struct module *module;
void (*bind)(int group); void (*bind)(int group);
bool (*compare)(struct net *net, struct sock *sock);
int registered; int registered;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment