Commit baf606d9 authored by Marcelo Ricardo Leitner's avatar Marcelo Ricardo Leitner Committed by David S. Miller

ipv4,ipv6: grab rtnl before locking the socket

There are some setsockopt operations in ipv4 and ipv6 that are grabbing
rtnl after having grabbed the socket lock. Yet this makes it impossible
to do operations that have to lock the socket when already within a rtnl
protected scope, like ndo dev_open and dev_stop.

We normally take coarse grained locks first but setsockopt inverted that.

So this patch invert the lock logic for these operations and makes
setsockopt grab rtnl if it will be needed prior to grabbing socket lock.
Signed-off-by: default avatarMarcelo Ricardo Leitner <marcelo.leitner@gmail.com>
Acked-by: default avatarHannes Frederic Sowa <hannes@stressinduktion.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent fdf9ef89
...@@ -536,12 +536,25 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) ...@@ -536,12 +536,25 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
* Socket option code for IP. This is the end of the line after any * Socket option code for IP. This is the end of the line after any
* TCP,UDP etc options on an IP socket. * TCP,UDP etc options on an IP socket.
*/ */
static bool setsockopt_needs_rtnl(int optname)
{
switch (optname) {
case IP_ADD_MEMBERSHIP:
case IP_ADD_SOURCE_MEMBERSHIP:
case IP_DROP_MEMBERSHIP:
case MCAST_JOIN_GROUP:
case MCAST_LEAVE_GROUP:
return true;
}
return false;
}
static int do_ip_setsockopt(struct sock *sk, int level, static int do_ip_setsockopt(struct sock *sk, int level,
int optname, char __user *optval, unsigned int optlen) int optname, char __user *optval, unsigned int optlen)
{ {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
int val = 0, err; int val = 0, err;
bool needs_rtnl = setsockopt_needs_rtnl(optname);
switch (optname) { switch (optname) {
case IP_PKTINFO: case IP_PKTINFO:
...@@ -584,6 +597,8 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -584,6 +597,8 @@ static int do_ip_setsockopt(struct sock *sk, int level,
return ip_mroute_setsockopt(sk, optname, optval, optlen); return ip_mroute_setsockopt(sk, optname, optval, optlen);
err = 0; err = 0;
if (needs_rtnl)
rtnl_lock();
lock_sock(sk); lock_sock(sk);
switch (optname) { switch (optname) {
...@@ -846,9 +861,9 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -846,9 +861,9 @@ static int do_ip_setsockopt(struct sock *sk, int level,
} }
if (optname == IP_ADD_MEMBERSHIP) if (optname == IP_ADD_MEMBERSHIP)
err = ip_mc_join_group(sk, &mreq); err = __ip_mc_join_group(sk, &mreq);
else else
err = ip_mc_leave_group(sk, &mreq); err = __ip_mc_leave_group(sk, &mreq);
break; break;
} }
case IP_MSFILTER: case IP_MSFILTER:
...@@ -913,7 +928,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -913,7 +928,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
mreq.imr_multiaddr.s_addr = mreqs.imr_multiaddr; mreq.imr_multiaddr.s_addr = mreqs.imr_multiaddr;
mreq.imr_address.s_addr = mreqs.imr_interface; mreq.imr_address.s_addr = mreqs.imr_interface;
mreq.imr_ifindex = 0; mreq.imr_ifindex = 0;
err = ip_mc_join_group(sk, &mreq); err = __ip_mc_join_group(sk, &mreq);
if (err && err != -EADDRINUSE) if (err && err != -EADDRINUSE)
break; break;
omode = MCAST_INCLUDE; omode = MCAST_INCLUDE;
...@@ -945,9 +960,9 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -945,9 +960,9 @@ static int do_ip_setsockopt(struct sock *sk, int level,
mreq.imr_ifindex = greq.gr_interface; mreq.imr_ifindex = greq.gr_interface;
if (optname == MCAST_JOIN_GROUP) if (optname == MCAST_JOIN_GROUP)
err = ip_mc_join_group(sk, &mreq); err = __ip_mc_join_group(sk, &mreq);
else else
err = ip_mc_leave_group(sk, &mreq); err = __ip_mc_leave_group(sk, &mreq);
break; break;
} }
case MCAST_JOIN_SOURCE_GROUP: case MCAST_JOIN_SOURCE_GROUP:
...@@ -990,7 +1005,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -990,7 +1005,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
mreq.imr_multiaddr = psin->sin_addr; mreq.imr_multiaddr = psin->sin_addr;
mreq.imr_address.s_addr = 0; mreq.imr_address.s_addr = 0;
mreq.imr_ifindex = greqs.gsr_interface; mreq.imr_ifindex = greqs.gsr_interface;
err = ip_mc_join_group(sk, &mreq); err = __ip_mc_join_group(sk, &mreq);
if (err && err != -EADDRINUSE) if (err && err != -EADDRINUSE)
break; break;
greqs.gsr_interface = mreq.imr_ifindex; greqs.gsr_interface = mreq.imr_ifindex;
...@@ -1118,10 +1133,14 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -1118,10 +1133,14 @@ static int do_ip_setsockopt(struct sock *sk, int level,
break; break;
} }
release_sock(sk); release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return err; return err;
e_inval: e_inval:
release_sock(sk); release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return -EINVAL; return -EINVAL;
} }
......
...@@ -117,6 +117,18 @@ struct ipv6_txoptions *ipv6_update_options(struct sock *sk, ...@@ -117,6 +117,18 @@ struct ipv6_txoptions *ipv6_update_options(struct sock *sk,
return opt; return opt;
} }
static bool setsockopt_needs_rtnl(int optname)
{
switch (optname) {
case IPV6_ADD_MEMBERSHIP:
case IPV6_DROP_MEMBERSHIP:
case MCAST_JOIN_GROUP:
case MCAST_LEAVE_GROUP:
return true;
}
return false;
}
static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
char __user *optval, unsigned int optlen) char __user *optval, unsigned int optlen)
{ {
...@@ -124,6 +136,7 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -124,6 +136,7 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
int val, valbool; int val, valbool;
int retv = -ENOPROTOOPT; int retv = -ENOPROTOOPT;
bool needs_rtnl = setsockopt_needs_rtnl(optname);
if (optval == NULL) if (optval == NULL)
val = 0; val = 0;
...@@ -140,6 +153,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -140,6 +153,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
if (ip6_mroute_opt(optname)) if (ip6_mroute_opt(optname))
return ip6_mroute_setsockopt(sk, optname, optval, optlen); return ip6_mroute_setsockopt(sk, optname, optval, optlen);
if (needs_rtnl)
rtnl_lock();
lock_sock(sk); lock_sock(sk);
switch (optname) { switch (optname) {
...@@ -582,9 +597,9 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -582,9 +597,9 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
break; break;
if (optname == IPV6_ADD_MEMBERSHIP) if (optname == IPV6_ADD_MEMBERSHIP)
retv = ipv6_sock_mc_join(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_multiaddr); retv = __ipv6_sock_mc_join(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_multiaddr);
else else
retv = ipv6_sock_mc_drop(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_multiaddr); retv = __ipv6_sock_mc_drop(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_multiaddr);
break; break;
} }
case IPV6_JOIN_ANYCAST: case IPV6_JOIN_ANYCAST:
...@@ -623,11 +638,11 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -623,11 +638,11 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
} }
psin6 = (struct sockaddr_in6 *)&greq.gr_group; psin6 = (struct sockaddr_in6 *)&greq.gr_group;
if (optname == MCAST_JOIN_GROUP) if (optname == MCAST_JOIN_GROUP)
retv = ipv6_sock_mc_join(sk, greq.gr_interface, retv = __ipv6_sock_mc_join(sk, greq.gr_interface,
&psin6->sin6_addr); &psin6->sin6_addr);
else else
retv = ipv6_sock_mc_drop(sk, greq.gr_interface, retv = __ipv6_sock_mc_drop(sk, greq.gr_interface,
&psin6->sin6_addr); &psin6->sin6_addr);
break; break;
} }
case MCAST_JOIN_SOURCE_GROUP: case MCAST_JOIN_SOURCE_GROUP:
...@@ -659,8 +674,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -659,8 +674,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
struct sockaddr_in6 *psin6; struct sockaddr_in6 *psin6;
psin6 = (struct sockaddr_in6 *)&greqs.gsr_group; psin6 = (struct sockaddr_in6 *)&greqs.gsr_group;
retv = ipv6_sock_mc_join(sk, greqs.gsr_interface, retv = __ipv6_sock_mc_join(sk, greqs.gsr_interface,
&psin6->sin6_addr); &psin6->sin6_addr);
/* prior join w/ different source is ok */ /* prior join w/ different source is ok */
if (retv && retv != -EADDRINUSE) if (retv && retv != -EADDRINUSE)
break; break;
...@@ -837,11 +852,15 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -837,11 +852,15 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
} }
release_sock(sk); release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return retv; return retv;
e_inval: e_inval:
release_sock(sk); release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return -EINVAL; return -EINVAL;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment