Commit 4f520900 authored by Richard Guy Briggs's avatar Richard Guy Briggs Committed by David S. Miller

netlink: have netlink per-protocol bind function return an error code.

Have the netlink per-protocol optional bind function return an int error code
rather than void to signal a failure.

This will enable netlink protocols to perform extra checks including
capabilities and permissions verifications when updating memberships in
multicast groups.

In netlink_bind() and netlink_setsockopt() the call to the per-protocol bind
function was moved above the multicast group update to prevent any access to
the multicast socket groups before checking with the per-protocol bind
function.  This will enable the per-protocol bind function to be used to check
permissions which could be denied before making them available, and to avoid
the messy job of undoing the addition should the per-protocol bind function
fail.

The netfilter subsystem seems to be the only one currently using the
per-protocol bind function.
Signed-off-by: default avatarRichard Guy Briggs <rgb@redhat.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent bfe4bc71
...@@ -45,7 +45,8 @@ struct netlink_kernel_cfg { ...@@ -45,7 +45,8 @@ struct netlink_kernel_cfg {
unsigned int flags; unsigned int flags;
void (*input)(struct sk_buff *skb); void (*input)(struct sk_buff *skb);
struct mutex *cb_mutex; struct mutex *cb_mutex;
void (*bind)(int group); int (*bind)(int group);
void (*unbind)(int group);
bool (*compare)(struct net *net, struct sock *sk); bool (*compare)(struct net *net, struct sock *sk);
}; };
......
...@@ -400,7 +400,7 @@ static void nfnetlink_rcv(struct sk_buff *skb) ...@@ -400,7 +400,7 @@ static void nfnetlink_rcv(struct sk_buff *skb)
} }
#ifdef CONFIG_MODULES #ifdef CONFIG_MODULES
static void nfnetlink_bind(int group) static int nfnetlink_bind(int group)
{ {
const struct nfnetlink_subsystem *ss; const struct nfnetlink_subsystem *ss;
int type = nfnl_group2type[group]; int type = nfnl_group2type[group];
...@@ -410,6 +410,7 @@ static void nfnetlink_bind(int group) ...@@ -410,6 +410,7 @@ static void nfnetlink_bind(int group)
rcu_read_unlock(); rcu_read_unlock();
if (!ss) if (!ss)
request_module("nfnetlink-subsys-%d", type); request_module("nfnetlink-subsys-%d", type);
return 0;
} }
#endif #endif
......
...@@ -1206,7 +1206,8 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, ...@@ -1206,7 +1206,8 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
struct module *module = NULL; struct module *module = NULL;
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct netlink_sock *nlk; struct netlink_sock *nlk;
void (*bind)(int group); int (*bind)(int group);
void (*unbind)(int group);
int err = 0; int err = 0;
sock->state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED;
...@@ -1232,6 +1233,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, ...@@ -1232,6 +1233,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
err = -EPROTONOSUPPORT; err = -EPROTONOSUPPORT;
cb_mutex = nl_table[protocol].cb_mutex; cb_mutex = nl_table[protocol].cb_mutex;
bind = nl_table[protocol].bind; bind = nl_table[protocol].bind;
unbind = nl_table[protocol].unbind;
netlink_unlock_table(); netlink_unlock_table();
if (err < 0) if (err < 0)
...@@ -1248,6 +1250,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, ...@@ -1248,6 +1250,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
nlk = nlk_sk(sock->sk); nlk = nlk_sk(sock->sk);
nlk->module = module; nlk->module = module;
nlk->netlink_bind = bind; nlk->netlink_bind = bind;
nlk->netlink_unbind = unbind;
out: out:
return err; return err;
...@@ -1301,6 +1304,7 @@ static int netlink_release(struct socket *sock) ...@@ -1301,6 +1304,7 @@ static int netlink_release(struct socket *sock)
kfree_rcu(old, rcu); kfree_rcu(old, rcu);
nl_table[sk->sk_protocol].module = NULL; nl_table[sk->sk_protocol].module = NULL;
nl_table[sk->sk_protocol].bind = NULL; nl_table[sk->sk_protocol].bind = NULL;
nl_table[sk->sk_protocol].unbind = NULL;
nl_table[sk->sk_protocol].flags = 0; nl_table[sk->sk_protocol].flags = 0;
nl_table[sk->sk_protocol].registered = 0; nl_table[sk->sk_protocol].registered = 0;
} }
...@@ -1411,6 +1415,19 @@ static int netlink_realloc_groups(struct sock *sk) ...@@ -1411,6 +1415,19 @@ static int netlink_realloc_groups(struct sock *sk)
return err; return err;
} }
static void netlink_unbind(int group, long unsigned int groups,
struct netlink_sock *nlk)
{
int undo;
if (!nlk->netlink_unbind)
return;
for (undo = 0; undo < group; undo++)
if (test_bit(group, &groups))
nlk->netlink_unbind(undo);
}
static int netlink_bind(struct socket *sock, struct sockaddr *addr, static int netlink_bind(struct socket *sock, struct sockaddr *addr,
int addr_len) int addr_len)
{ {
...@@ -1419,6 +1436,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, ...@@ -1419,6 +1436,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
struct netlink_sock *nlk = nlk_sk(sk); struct netlink_sock *nlk = nlk_sk(sk);
struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr; struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
int err; int err;
long unsigned int groups = nladdr->nl_groups;
if (addr_len < sizeof(struct sockaddr_nl)) if (addr_len < sizeof(struct sockaddr_nl))
return -EINVAL; return -EINVAL;
...@@ -1427,7 +1445,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, ...@@ -1427,7 +1445,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
return -EINVAL; return -EINVAL;
/* Only superuser is allowed to listen multicasts */ /* Only superuser is allowed to listen multicasts */
if (nladdr->nl_groups) { if (groups) {
if (!netlink_capable(sock, NL_CFG_F_NONROOT_RECV)) if (!netlink_capable(sock, NL_CFG_F_NONROOT_RECV))
return -EPERM; return -EPERM;
err = netlink_realloc_groups(sk); err = netlink_realloc_groups(sk);
...@@ -1435,37 +1453,45 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, ...@@ -1435,37 +1453,45 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
return err; return err;
} }
if (nlk->portid) { if (nlk->portid)
if (nladdr->nl_pid != nlk->portid) if (nladdr->nl_pid != nlk->portid)
return -EINVAL; return -EINVAL;
} else {
if (nlk->netlink_bind && groups) {
int group;
for (group = 0; group < nlk->ngroups; group++) {
if (!test_bit(group, &groups))
continue;
err = nlk->netlink_bind(group);
if (!err)
continue;
netlink_unbind(group, groups, nlk);
return err;
}
}
if (!nlk->portid) {
err = nladdr->nl_pid ? err = nladdr->nl_pid ?
netlink_insert(sk, net, nladdr->nl_pid) : netlink_insert(sk, net, nladdr->nl_pid) :
netlink_autobind(sock); netlink_autobind(sock);
if (err) if (err) {
netlink_unbind(nlk->ngroups - 1, groups, nlk);
return err; return err;
} }
}
if (!nladdr->nl_groups && (nlk->groups == NULL || !(u32)nlk->groups[0])) if (!groups && (nlk->groups == NULL || !(u32)nlk->groups[0]))
return 0; return 0;
netlink_table_grab(); netlink_table_grab();
netlink_update_subscriptions(sk, nlk->subscriptions + netlink_update_subscriptions(sk, nlk->subscriptions +
hweight32(nladdr->nl_groups) - hweight32(groups) -
hweight32(nlk->groups[0])); hweight32(nlk->groups[0]));
nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | groups;
netlink_update_listeners(sk); netlink_update_listeners(sk);
netlink_table_ungrab(); netlink_table_ungrab();
if (nlk->netlink_bind && nlk->groups[0]) {
int i;
for (i = 0; i < nlk->ngroups; i++) {
if (test_bit(i, nlk->groups))
nlk->netlink_bind(i);
}
}
return 0; return 0;
} }
...@@ -2103,14 +2129,16 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, ...@@ -2103,14 +2129,16 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
return err; return err;
if (!val || val - 1 >= nlk->ngroups) if (!val || val - 1 >= nlk->ngroups)
return -EINVAL; return -EINVAL;
if (nlk->netlink_bind) {
err = nlk->netlink_bind(val);
if (err)
return err;
}
netlink_table_grab(); netlink_table_grab();
netlink_update_socket_mc(nlk, val, netlink_update_socket_mc(nlk, val,
optname == NETLINK_ADD_MEMBERSHIP); optname == NETLINK_ADD_MEMBERSHIP);
netlink_table_ungrab(); netlink_table_ungrab();
if (nlk->netlink_bind)
nlk->netlink_bind(val);
err = 0; err = 0;
break; break;
} }
......
...@@ -38,7 +38,8 @@ struct netlink_sock { ...@@ -38,7 +38,8 @@ struct netlink_sock {
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct mutex cb_def_mutex; struct mutex cb_def_mutex;
void (*netlink_rcv)(struct sk_buff *skb); void (*netlink_rcv)(struct sk_buff *skb);
void (*netlink_bind)(int group); int (*netlink_bind)(int group);
void (*netlink_unbind)(int group);
struct module *module; struct module *module;
#ifdef CONFIG_NETLINK_MMAP #ifdef CONFIG_NETLINK_MMAP
struct mutex pg_vec_lock; struct mutex pg_vec_lock;
...@@ -74,7 +75,8 @@ struct netlink_table { ...@@ -74,7 +75,8 @@ struct netlink_table {
unsigned int groups; unsigned int groups;
struct mutex *cb_mutex; struct mutex *cb_mutex;
struct module *module; struct module *module;
void (*bind)(int group); int (*bind)(int group);
void (*unbind)(int group);
bool (*compare)(struct net *net, struct sock *sock); bool (*compare)(struct net *net, struct sock *sock);
int registered; int registered;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment