Commit 20ae1d6a authored by Jason A. Donenfeld's avatar Jason A. Donenfeld Committed by Jakub Kicinski

wireguard: device: reset peer src endpoint when netns exits

Each peer's endpoint contains a dst_cache entry that takes a reference
to another netdev. When the containing namespace exits, we take down the
socket and prevent future sockets from being created (by setting
creating_net to NULL), which removes that potential reference on the
netns. However, it doesn't release references to the netns that a netdev
cached in dst_cache might be taking, so the netns still might fail to
exit. Since the socket is gimped anyway, we can simply clear all the
dst_caches (by way of clearing the endpoint src), which will release all
references.

However, the current dst_cache_reset function only releases those
references lazily. But it turns out that all of our usages of
wg_socket_clear_peer_endpoint_src are called from contexts that are not
exactly high-speed or bottle-necked. For example, when there's
connection difficulty, or when userspace is reconfiguring the interface.
And in particular for this patch, when the netns is exiting. So for
those cases, it makes more sense to call dst_release immediately. For
that, we add a small helper function to dst_cache.

This patch also adds a test to netns.sh from Hangbin Liu to ensure this
doesn't regress.
Tested-by: default avatarHangbin Liu <liuhangbin@gmail.com>
Reported-by: default avatarXiumei Mu <xmu@redhat.com>
Cc: Toke Høiland-Jørgensen <toke@redhat.com>
Cc: Paolo Abeni <pabeni@redhat.com>
Fixes: 900575aa ("wireguard: device: avoid circular netns references")
Signed-off-by: default avatarJason A. Donenfeld <Jason@zx2c4.com>
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent 7e938beb
...@@ -398,6 +398,7 @@ static struct rtnl_link_ops link_ops __read_mostly = { ...@@ -398,6 +398,7 @@ static struct rtnl_link_ops link_ops __read_mostly = {
static void wg_netns_pre_exit(struct net *net) static void wg_netns_pre_exit(struct net *net)
{ {
struct wg_device *wg; struct wg_device *wg;
struct wg_peer *peer;
rtnl_lock(); rtnl_lock();
list_for_each_entry(wg, &device_list, device_list) { list_for_each_entry(wg, &device_list, device_list) {
...@@ -407,6 +408,8 @@ static void wg_netns_pre_exit(struct net *net) ...@@ -407,6 +408,8 @@ static void wg_netns_pre_exit(struct net *net)
mutex_lock(&wg->device_update_lock); mutex_lock(&wg->device_update_lock);
rcu_assign_pointer(wg->creating_net, NULL); rcu_assign_pointer(wg->creating_net, NULL);
wg_socket_reinit(wg, NULL, NULL); wg_socket_reinit(wg, NULL, NULL);
list_for_each_entry(peer, &wg->peer_list, peer_list)
wg_socket_clear_peer_endpoint_src(peer);
mutex_unlock(&wg->device_update_lock); mutex_unlock(&wg->device_update_lock);
} }
} }
......
...@@ -308,7 +308,7 @@ void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer) ...@@ -308,7 +308,7 @@ void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer)
{ {
write_lock_bh(&peer->endpoint_lock); write_lock_bh(&peer->endpoint_lock);
memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6)); memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6));
dst_cache_reset(&peer->endpoint_cache); dst_cache_reset_now(&peer->endpoint_cache);
write_unlock_bh(&peer->endpoint_lock); write_unlock_bh(&peer->endpoint_lock);
} }
......
...@@ -79,6 +79,17 @@ static inline void dst_cache_reset(struct dst_cache *dst_cache) ...@@ -79,6 +79,17 @@ static inline void dst_cache_reset(struct dst_cache *dst_cache)
dst_cache->reset_ts = jiffies; dst_cache->reset_ts = jiffies;
} }
/**
* dst_cache_reset_now - invalidate the cache contents immediately
* @dst_cache: the cache
*
* The caller must be sure there are no concurrent users, as this frees
* all dst_cache users immediately, rather than waiting for the next
* per-cpu usage like dst_cache_reset does. Most callers should use the
* higher speed lazily-freed dst_cache_reset function instead.
*/
void dst_cache_reset_now(struct dst_cache *dst_cache);
/** /**
* dst_cache_init - initialize the cache, allocating the required storage * dst_cache_init - initialize the cache, allocating the required storage
* @dst_cache: the cache * @dst_cache: the cache
......
...@@ -162,3 +162,22 @@ void dst_cache_destroy(struct dst_cache *dst_cache) ...@@ -162,3 +162,22 @@ void dst_cache_destroy(struct dst_cache *dst_cache)
free_percpu(dst_cache->cache); free_percpu(dst_cache->cache);
} }
EXPORT_SYMBOL_GPL(dst_cache_destroy); EXPORT_SYMBOL_GPL(dst_cache_destroy);
void dst_cache_reset_now(struct dst_cache *dst_cache)
{
int i;
if (!dst_cache->cache)
return;
dst_cache->reset_ts = jiffies;
for_each_possible_cpu(i) {
struct dst_cache_pcpu *idst = per_cpu_ptr(dst_cache->cache, i);
struct dst_entry *dst = idst->dst;
idst->cookie = 0;
idst->dst = NULL;
dst_release(dst);
}
}
EXPORT_SYMBOL_GPL(dst_cache_reset_now);
...@@ -613,6 +613,28 @@ ip0 link set wg0 up ...@@ -613,6 +613,28 @@ ip0 link set wg0 up
kill $ncat_pid kill $ncat_pid
ip0 link del wg0 ip0 link del wg0
# Ensure that dst_cache references don't outlive netns lifetime
ip1 link add dev wg0 type wireguard
ip2 link add dev wg0 type wireguard
configure_peers
ip1 link add veth1 type veth peer name veth2
ip1 link set veth2 netns $netns2
ip1 addr add fd00:aa::1/64 dev veth1
ip2 addr add fd00:aa::2/64 dev veth2
ip1 link set veth1 up
ip2 link set veth2 up
waitiface $netns1 veth1
waitiface $netns2 veth2
ip1 -6 route add default dev veth1 via fd00:aa::2
ip2 -6 route add default dev veth2 via fd00:aa::1
n1 wg set wg0 peer "$pub2" endpoint [fd00:aa::2]:2
n2 wg set wg0 peer "$pub1" endpoint [fd00:aa::1]:1
n1 ping6 -c 1 fd00::2
pp ip netns delete $netns1
pp ip netns delete $netns2
pp ip netns add $netns1
pp ip netns add $netns2
# Ensure there aren't circular reference loops # Ensure there aren't circular reference loops
ip1 link add wg1 type wireguard ip1 link add wg1 type wireguard
ip2 link add wg2 type wireguard ip2 link add wg2 type wireguard
...@@ -631,7 +653,7 @@ while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do ...@@ -631,7 +653,7 @@ while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
done < /dev/kmsg done < /dev/kmsg
alldeleted=1 alldeleted=1
for object in "${!objects[@]}"; do for object in "${!objects[@]}"; do
if [[ ${objects["$object"]} != *createddestroyed ]]; then if [[ ${objects["$object"]} != *createddestroyed && ${objects["$object"]} != *createdcreateddestroyeddestroyed ]]; then
echo "Error: $object: merely ${objects["$object"]}" >&3 echo "Error: $object: merely ${objects["$object"]}" >&3
alldeleted=0 alldeleted=0
fi fi
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment