Commit 6fd815bb authored by David S. Miller's avatar David S. Miller

Merge branch 'wireguard-fixes'

Jason A. Donenfeld says:

====================
wireguard fixes for 5.13-rc5

Here are bug fixes to WireGuard for 5.13-rc5:

1-2,6) These are small, trivial tweaks to our test harness.

3) Linus thinks -O3 is still dangerous to enable. The code gen wasn't so
   much different with -O2 either.

4) We were accidentally calling synchronize_rcu instead of
   synchronize_net while holding the rtnl_lock, resulting in some rather
   large stalls that hit production machines.

5) Peer allocation was wasting literally hundreds of megabytes on real
   world deployments, due to oddly sized large objects not fitting
   nicely into a kmalloc slab.

7-9) We move from an insanely expensive O(n) algorithm to a fast O(1)
     algorithm, and cleanup a massive memory leak in the process, in
     which allowed ips churn would leave danging nodes hanging around
     without cleanup until the interface was removed. The O(1) algorithm
     eliminates packet stalls and high latency issues, in addition to
     bringing operations that took as much as 10 minutes down to less
     than a second.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 579028de bf7b042d
ccflags-y := -O3 ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
wireguard-y := main.o wireguard-y := main.o
wireguard-y += noise.o wireguard-y += noise.o
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include "allowedips.h" #include "allowedips.h"
#include "peer.h" #include "peer.h"
static struct kmem_cache *node_cache;
static void swap_endian(u8 *dst, const u8 *src, u8 bits) static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{ {
if (bits == 32) { if (bits == 32) {
...@@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, ...@@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bitlen = bits; node->bitlen = bits;
memcpy(node->bits, src, bits / 8U); memcpy(node->bits, src, bits / 8U);
} }
#define CHOOSE_NODE(parent, key) \
parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] static inline u8 choose(struct allowedips_node *node, const u8 *key)
{
return (key[node->bit_at_a] >> node->bit_at_b) & 1;
}
static void push_rcu(struct allowedips_node **stack, static void push_rcu(struct allowedips_node **stack,
struct allowedips_node __rcu *p, unsigned int *len) struct allowedips_node __rcu *p, unsigned int *len)
...@@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack, ...@@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack,
} }
} }
static void node_free_rcu(struct rcu_head *rcu)
{
kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
}
static void root_free_rcu(struct rcu_head *rcu) static void root_free_rcu(struct rcu_head *rcu)
{ {
struct allowedips_node *node, *stack[128] = { struct allowedips_node *node, *stack[128] = {
...@@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu) ...@@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu)
while (len > 0 && (node = stack[--len])) { while (len > 0 && (node = stack[--len])) {
push_rcu(stack, node->bit[0], &len); push_rcu(stack, node->bit[0], &len);
push_rcu(stack, node->bit[1], &len); push_rcu(stack, node->bit[1], &len);
kfree(node); kmem_cache_free(node_cache, node);
} }
} }
...@@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root) ...@@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
} }
} }
static void walk_remove_by_peer(struct allowedips_node __rcu **top,
struct wg_peer *peer, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
#define PUSH(p) ({ \
WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
stack[len++] = p; \
})
struct allowedips_node __rcu **stack[128], **nptr;
struct allowedips_node *node, *prev;
unsigned int len;
if (unlikely(!peer || !REF(*top)))
return;
for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
nptr = stack[len - 1];
node = DEREF(nptr);
if (!node) {
--len;
continue;
}
if (!prev || REF(prev->bit[0]) == node ||
REF(prev->bit[1]) == node) {
if (REF(node->bit[0]))
PUSH(&node->bit[0]);
else if (REF(node->bit[1]))
PUSH(&node->bit[1]);
} else if (REF(node->bit[0]) == prev) {
if (REF(node->bit[1]))
PUSH(&node->bit[1]);
} else {
if (rcu_dereference_protected(node->peer,
lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
list_del_init(&node->peer_list);
if (!node->bit[0] || !node->bit[1]) {
rcu_assign_pointer(*nptr, DEREF(
&node->bit[!REF(node->bit[0])]));
kfree_rcu(node, rcu);
node = DEREF(nptr);
}
}
--len;
}
}
#undef REF
#undef DEREF
#undef PUSH
}
static unsigned int fls128(u64 a, u64 b) static unsigned int fls128(u64 a, u64 b)
{ {
return a ? fls64(a) + 64U : fls64(b); return a ? fls64(a) + 64U : fls64(b);
...@@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, ...@@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
found = node; found = node;
if (node->cidr == bits) if (node->cidr == bits)
break; break;
node = rcu_dereference_bh(CHOOSE_NODE(node, key)); node = rcu_dereference_bh(node->bit[choose(node, key)]);
} }
return found; return found;
} }
...@@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, ...@@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
u8 cidr, u8 bits, struct allowedips_node **rnode, u8 cidr, u8 bits, struct allowedips_node **rnode,
struct mutex *lock) struct mutex *lock)
{ {
struct allowedips_node *node = rcu_dereference_protected(trie, struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
lockdep_is_held(lock));
struct allowedips_node *parent = NULL; struct allowedips_node *parent = NULL;
bool exact = false; bool exact = false;
...@@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, ...@@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
exact = true; exact = true;
break; break;
} }
node = rcu_dereference_protected(CHOOSE_NODE(parent, key), node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
lockdep_is_held(lock));
} }
*rnode = parent; *rnode = parent;
return exact; return exact;
} }
static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
{
node->parent_bit_packed = (unsigned long)parent | bit;
rcu_assign_pointer(*parent, node);
}
static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
{
u8 bit = choose(parent, node->bits);
connect_node(&parent->bit[bit], bit, node);
}
static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
u8 cidr, struct wg_peer *peer, struct mutex *lock) u8 cidr, struct wg_peer *peer, struct mutex *lock)
{ {
...@@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, ...@@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return -EINVAL; return -EINVAL;
if (!rcu_access_pointer(*trie)) { if (!rcu_access_pointer(*trie)) {
node = kzalloc(sizeof(*node), GFP_KERNEL); node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!node)) if (unlikely(!node))
return -ENOMEM; return -ENOMEM;
RCU_INIT_POINTER(node->peer, peer); RCU_INIT_POINTER(node->peer, peer);
list_add_tail(&node->peer_list, &peer->allowedips_list); list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits); copy_and_assign_cidr(node, key, cidr, bits);
rcu_assign_pointer(*trie, node); connect_node(trie, 2, node);
return 0; return 0;
} }
if (node_placement(*trie, key, cidr, bits, &node, lock)) { if (node_placement(*trie, key, cidr, bits, &node, lock)) {
...@@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, ...@@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0; return 0;
} }
newnode = kzalloc(sizeof(*newnode), GFP_KERNEL); newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!newnode)) if (unlikely(!newnode))
return -ENOMEM; return -ENOMEM;
RCU_INIT_POINTER(newnode->peer, peer); RCU_INIT_POINTER(newnode->peer, peer);
...@@ -243,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, ...@@ -243,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (!node) { if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else { } else {
down = rcu_dereference_protected(CHOOSE_NODE(node, key), const u8 bit = choose(node, key);
lockdep_is_held(lock)); down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
if (!down) { if (!down) {
rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); connect_node(&node->bit[bit], bit, newnode);
return 0; return 0;
} }
} }
...@@ -254,30 +220,29 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, ...@@ -254,30 +220,29 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
parent = node; parent = node;
if (newnode->cidr == cidr) { if (newnode->cidr == cidr) {
rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); choose_and_connect_node(newnode, down);
if (!parent) if (!parent)
rcu_assign_pointer(*trie, newnode); connect_node(trie, 2, newnode);
else else
rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), choose_and_connect_node(parent, newnode);
newnode); return 0;
} else { }
node = kzalloc(sizeof(*node), GFP_KERNEL);
node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
if (unlikely(!node)) { if (unlikely(!node)) {
list_del(&newnode->peer_list); list_del(&newnode->peer_list);
kfree(newnode); kmem_cache_free(node_cache, newnode);
return -ENOMEM; return -ENOMEM;
} }
INIT_LIST_HEAD(&node->peer_list); INIT_LIST_HEAD(&node->peer_list);
copy_and_assign_cidr(node, newnode->bits, cidr, bits); copy_and_assign_cidr(node, newnode->bits, cidr, bits);
rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); choose_and_connect_node(node, down);
rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); choose_and_connect_node(node, newnode);
if (!parent) if (!parent)
rcu_assign_pointer(*trie, node); connect_node(trie, 2, node);
else else
rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), choose_and_connect_node(parent, node);
node);
}
return 0; return 0;
} }
...@@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, ...@@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
void wg_allowedips_remove_by_peer(struct allowedips *table, void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock) struct wg_peer *peer, struct mutex *lock)
{ {
struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
bool free_parent;
if (list_empty(&peer->allowedips_list))
return;
++table->seq; ++table->seq;
walk_remove_by_peer(&table->root4, peer, lock); list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
walk_remove_by_peer(&table->root6, peer, lock); list_del_init(&node->peer_list);
RCU_INIT_POINTER(node->peer, NULL);
if (node->bit[0] && node->bit[1])
continue;
child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
lockdep_is_held(lock));
if (child)
child->parent_bit_packed = node->parent_bit_packed;
parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
*parent_bit = child;
parent = (void *)parent_bit -
offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
free_parent = !rcu_access_pointer(node->bit[0]) &&
!rcu_access_pointer(node->bit[1]) &&
(node->parent_bit_packed & 3) <= 1 &&
!rcu_access_pointer(parent->peer);
if (free_parent)
child = rcu_dereference_protected(
parent->bit[!(node->parent_bit_packed & 1)],
lockdep_is_held(lock));
call_rcu(&node->rcu, node_free_rcu);
if (!free_parent)
continue;
if (child)
child->parent_bit_packed = parent->parent_bit_packed;
*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
call_rcu(&parent->rcu, node_free_rcu);
}
} }
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
...@@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, ...@@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
return NULL; return NULL;
} }
int __init wg_allowedips_slab_init(void)
{
node_cache = KMEM_CACHE(allowedips_node, 0);
return node_cache ? 0 : -ENOMEM;
}
void wg_allowedips_slab_uninit(void)
{
rcu_barrier();
kmem_cache_destroy(node_cache);
}
#include "selftest/allowedips.c" #include "selftest/allowedips.c"
...@@ -15,14 +15,11 @@ struct wg_peer; ...@@ -15,14 +15,11 @@ struct wg_peer;
struct allowedips_node { struct allowedips_node {
struct wg_peer __rcu *peer; struct wg_peer __rcu *peer;
struct allowedips_node __rcu *bit[2]; struct allowedips_node __rcu *bit[2];
/* While it may seem scandalous that we waste space for v4,
* we're alloc'ing to the nearest power of 2 anyway, so this
* doesn't actually make a difference.
*/
u8 bits[16] __aligned(__alignof(u64));
u8 cidr, bit_at_a, bit_at_b, bitlen; u8 cidr, bit_at_a, bit_at_b, bitlen;
u8 bits[16] __aligned(__alignof(u64));
/* Keep rarely used list at bottom to be beyond cache line. */ /* Keep rarely used members at bottom to be beyond cache line. */
unsigned long parent_bit_packed;
union { union {
struct list_head peer_list; struct list_head peer_list;
struct rcu_head rcu; struct rcu_head rcu;
...@@ -33,7 +30,7 @@ struct allowedips { ...@@ -33,7 +30,7 @@ struct allowedips {
struct allowedips_node __rcu *root4; struct allowedips_node __rcu *root4;
struct allowedips_node __rcu *root6; struct allowedips_node __rcu *root6;
u64 seq; u64 seq;
}; } __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */
void wg_allowedips_init(struct allowedips *table); void wg_allowedips_init(struct allowedips *table);
void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
...@@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, ...@@ -56,4 +53,7 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
bool wg_allowedips_selftest(void); bool wg_allowedips_selftest(void);
#endif #endif
int wg_allowedips_slab_init(void);
void wg_allowedips_slab_uninit(void);
#endif /* _WG_ALLOWEDIPS_H */ #endif /* _WG_ALLOWEDIPS_H */
...@@ -21,13 +21,22 @@ static int __init mod_init(void) ...@@ -21,13 +21,22 @@ static int __init mod_init(void)
{ {
int ret; int ret;
ret = wg_allowedips_slab_init();
if (ret < 0)
goto err_allowedips;
#ifdef DEBUG #ifdef DEBUG
ret = -ENOTRECOVERABLE;
if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() || if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() ||
!wg_ratelimiter_selftest()) !wg_ratelimiter_selftest())
return -ENOTRECOVERABLE; goto err_peer;
#endif #endif
wg_noise_init(); wg_noise_init();
ret = wg_peer_init();
if (ret < 0)
goto err_peer;
ret = wg_device_init(); ret = wg_device_init();
if (ret < 0) if (ret < 0)
goto err_device; goto err_device;
...@@ -44,6 +53,10 @@ static int __init mod_init(void) ...@@ -44,6 +53,10 @@ static int __init mod_init(void)
err_netlink: err_netlink:
wg_device_uninit(); wg_device_uninit();
err_device: err_device:
wg_peer_uninit();
err_peer:
wg_allowedips_slab_uninit();
err_allowedips:
return ret; return ret;
} }
...@@ -51,6 +64,8 @@ static void __exit mod_exit(void) ...@@ -51,6 +64,8 @@ static void __exit mod_exit(void)
{ {
wg_genetlink_uninit(); wg_genetlink_uninit();
wg_device_uninit(); wg_device_uninit();
wg_peer_uninit();
wg_allowedips_slab_uninit();
} }
module_init(mod_init); module_init(mod_init);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <linux/rcupdate.h> #include <linux/rcupdate.h>
#include <linux/list.h> #include <linux/list.h>
static struct kmem_cache *peer_cache;
static atomic64_t peer_counter = ATOMIC64_INIT(0); static atomic64_t peer_counter = ATOMIC64_INIT(0);
struct wg_peer *wg_peer_create(struct wg_device *wg, struct wg_peer *wg_peer_create(struct wg_device *wg,
...@@ -29,10 +30,10 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, ...@@ -29,10 +30,10 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
if (wg->num_peers >= MAX_PEERS_PER_DEVICE) if (wg->num_peers >= MAX_PEERS_PER_DEVICE)
return ERR_PTR(ret); return ERR_PTR(ret);
peer = kzalloc(sizeof(*peer), GFP_KERNEL); peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL);
if (unlikely(!peer)) if (unlikely(!peer))
return ERR_PTR(ret); return ERR_PTR(ret);
if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)))
goto err; goto err;
peer->device = wg; peer->device = wg;
...@@ -64,7 +65,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, ...@@ -64,7 +65,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
return peer; return peer;
err: err:
kfree(peer); kmem_cache_free(peer_cache, peer);
return ERR_PTR(ret); return ERR_PTR(ret);
} }
...@@ -88,7 +89,7 @@ static void peer_make_dead(struct wg_peer *peer) ...@@ -88,7 +89,7 @@ static void peer_make_dead(struct wg_peer *peer)
/* Mark as dead, so that we don't allow jumping contexts after. */ /* Mark as dead, so that we don't allow jumping contexts after. */
WRITE_ONCE(peer->is_dead, true); WRITE_ONCE(peer->is_dead, true);
/* The caller must now synchronize_rcu() for this to take effect. */ /* The caller must now synchronize_net() for this to take effect. */
} }
static void peer_remove_after_dead(struct wg_peer *peer) static void peer_remove_after_dead(struct wg_peer *peer)
...@@ -160,7 +161,7 @@ void wg_peer_remove(struct wg_peer *peer) ...@@ -160,7 +161,7 @@ void wg_peer_remove(struct wg_peer *peer)
lockdep_assert_held(&peer->device->device_update_lock); lockdep_assert_held(&peer->device->device_update_lock);
peer_make_dead(peer); peer_make_dead(peer);
synchronize_rcu(); synchronize_net();
peer_remove_after_dead(peer); peer_remove_after_dead(peer);
} }
...@@ -178,7 +179,7 @@ void wg_peer_remove_all(struct wg_device *wg) ...@@ -178,7 +179,7 @@ void wg_peer_remove_all(struct wg_device *wg)
peer_make_dead(peer); peer_make_dead(peer);
list_add_tail(&peer->peer_list, &dead_peers); list_add_tail(&peer->peer_list, &dead_peers);
} }
synchronize_rcu(); synchronize_net();
list_for_each_entry_safe(peer, temp, &dead_peers, peer_list) list_for_each_entry_safe(peer, temp, &dead_peers, peer_list)
peer_remove_after_dead(peer); peer_remove_after_dead(peer);
} }
...@@ -193,7 +194,8 @@ static void rcu_release(struct rcu_head *rcu) ...@@ -193,7 +194,8 @@ static void rcu_release(struct rcu_head *rcu)
/* The final zeroing takes care of clearing any remaining handshake key /* The final zeroing takes care of clearing any remaining handshake key
* material and other potentially sensitive information. * material and other potentially sensitive information.
*/ */
kfree_sensitive(peer); memzero_explicit(peer, sizeof(*peer));
kmem_cache_free(peer_cache, peer);
} }
static void kref_release(struct kref *refcount) static void kref_release(struct kref *refcount)
...@@ -225,3 +227,14 @@ void wg_peer_put(struct wg_peer *peer) ...@@ -225,3 +227,14 @@ void wg_peer_put(struct wg_peer *peer)
return; return;
kref_put(&peer->refcount, kref_release); kref_put(&peer->refcount, kref_release);
} }
int __init wg_peer_init(void)
{
peer_cache = KMEM_CACHE(wg_peer, 0);
return peer_cache ? 0 : -ENOMEM;
}
void wg_peer_uninit(void)
{
kmem_cache_destroy(peer_cache);
}
...@@ -80,4 +80,7 @@ void wg_peer_put(struct wg_peer *peer); ...@@ -80,4 +80,7 @@ void wg_peer_put(struct wg_peer *peer);
void wg_peer_remove(struct wg_peer *peer); void wg_peer_remove(struct wg_peer *peer);
void wg_peer_remove_all(struct wg_device *wg); void wg_peer_remove_all(struct wg_device *wg);
int wg_peer_init(void);
void wg_peer_uninit(void);
#endif /* _WG_PEER_H */ #endif /* _WG_PEER_H */
...@@ -19,32 +19,22 @@ ...@@ -19,32 +19,22 @@
#include <linux/siphash.h> #include <linux/siphash.h>
static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
u8 cidr)
{
swap_endian(dst, src, bits);
memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
if (cidr)
dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
}
static __init void print_node(struct allowedips_node *node, u8 bits) static __init void print_node(struct allowedips_node *node, u8 bits)
{ {
char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
char *fmt_declaration = KERN_DEBUG char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
"\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; u8 ip1[16], ip2[16], cidr1, cidr2;
char *style = "dotted"; char *style = "dotted";
u8 ip1[16], ip2[16];
u32 color = 0; u32 color = 0;
if (node == NULL)
return;
if (bits == 32) { if (bits == 32) {
fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
fmt_declaration = KERN_DEBUG fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
"\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
} else if (bits == 128) { } else if (bits == 128) {
fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
fmt_declaration = KERN_DEBUG fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
"\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
} }
if (node->peer) { if (node->peer) {
hsiphash_key_t key = { { 0 } }; hsiphash_key_t key = { { 0 } };
...@@ -55,24 +45,20 @@ static __init void print_node(struct allowedips_node *node, u8 bits) ...@@ -55,24 +45,20 @@ static __init void print_node(struct allowedips_node *node, u8 bits)
hsiphash_1u32(0xabad1dea, &key) % 200; hsiphash_1u32(0xabad1dea, &key) % 200;
style = "bold"; style = "bold";
} }
swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); wg_allowedips_read_node(node, ip1, &cidr1);
printk(fmt_declaration, ip1, node->cidr, style, color); printk(fmt_declaration, ip1, cidr1, style, color);
if (node->bit[0]) { if (node->bit[0]) {
swap_endian_and_apply_cidr(ip2, wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2);
rcu_dereference_raw(node->bit[0])->bits, bits, printk(fmt_connection, ip1, cidr1, ip2, cidr2);
node->cidr);
printk(fmt_connection, ip1, node->cidr, ip2,
rcu_dereference_raw(node->bit[0])->cidr);
print_node(rcu_dereference_raw(node->bit[0]), bits);
} }
if (node->bit[1]) { if (node->bit[1]) {
swap_endian_and_apply_cidr(ip2, wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2);
rcu_dereference_raw(node->bit[1])->bits, printk(fmt_connection, ip1, cidr1, ip2, cidr2);
bits, node->cidr);
printk(fmt_connection, ip1, node->cidr, ip2,
rcu_dereference_raw(node->bit[1])->cidr);
print_node(rcu_dereference_raw(node->bit[1]), bits);
} }
if (node->bit[0])
print_node(rcu_dereference_raw(node->bit[0]), bits);
if (node->bit[1])
print_node(rcu_dereference_raw(node->bit[1]), bits);
} }
static __init void print_tree(struct allowedips_node __rcu *top, u8 bits) static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
...@@ -121,8 +107,8 @@ static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr) ...@@ -121,8 +107,8 @@ static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr)
{ {
union nf_inet_addr mask; union nf_inet_addr mask;
memset(&mask, 0x00, 128 / 8); memset(&mask, 0, sizeof(mask));
memset(&mask, 0xff, cidr / 8); memset(&mask.all, 0xff, cidr / 8);
if (cidr % 32) if (cidr % 32)
mask.all[cidr / 32] = (__force u32)htonl( mask.all[cidr / 32] = (__force u32)htonl(
(0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
...@@ -149,42 +135,36 @@ horrible_mask_self(struct horrible_allowedips_node *node) ...@@ -149,42 +135,36 @@ horrible_mask_self(struct horrible_allowedips_node *node)
} }
static __init inline bool static __init inline bool
horrible_match_v4(const struct horrible_allowedips_node *node, horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
struct in_addr *ip)
{ {
return (ip->s_addr & node->mask.ip) == node->ip.ip; return (ip->s_addr & node->mask.ip) == node->ip.ip;
} }
static __init inline bool static __init inline bool
horrible_match_v6(const struct horrible_allowedips_node *node, horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
struct in6_addr *ip)
{ {
return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
node->ip.ip6[0] && (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
(ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
node->ip.ip6[1] &&
(ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
node->ip.ip6[2] &&
(ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
} }
static __init void static __init void
horrible_insert_ordered(struct horrible_allowedips *table, horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
struct horrible_allowedips_node *node)
{ {
struct horrible_allowedips_node *other = NULL, *where = NULL; struct horrible_allowedips_node *other = NULL, *where = NULL;
u8 my_cidr = horrible_mask_to_cidr(node->mask); u8 my_cidr = horrible_mask_to_cidr(node->mask);
hlist_for_each_entry(other, &table->head, table) { hlist_for_each_entry(other, &table->head, table) {
if (!memcmp(&other->mask, &node->mask, if (other->ip_version == node->ip_version &&
sizeof(union nf_inet_addr)) && !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
!memcmp(&other->ip, &node->ip, !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) {
sizeof(union nf_inet_addr)) &&
other->ip_version == node->ip_version) {
other->value = node->value; other->value = node->value;
kfree(node); kfree(node);
return; return;
} }
}
hlist_for_each_entry(other, &table->head, table) {
where = other; where = other;
if (horrible_mask_to_cidr(other->mask) <= my_cidr) if (horrible_mask_to_cidr(other->mask) <= my_cidr)
break; break;
...@@ -201,8 +181,7 @@ static __init int ...@@ -201,8 +181,7 @@ static __init int
horrible_allowedips_insert_v4(struct horrible_allowedips *table, horrible_allowedips_insert_v4(struct horrible_allowedips *table,
struct in_addr *ip, u8 cidr, void *value) struct in_addr *ip, u8 cidr, void *value)
{ {
struct horrible_allowedips_node *node = kzalloc(sizeof(*node), struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
GFP_KERNEL);
if (unlikely(!node)) if (unlikely(!node))
return -ENOMEM; return -ENOMEM;
...@@ -219,8 +198,7 @@ static __init int ...@@ -219,8 +198,7 @@ static __init int
horrible_allowedips_insert_v6(struct horrible_allowedips *table, horrible_allowedips_insert_v6(struct horrible_allowedips *table,
struct in6_addr *ip, u8 cidr, void *value) struct in6_addr *ip, u8 cidr, void *value)
{ {
struct horrible_allowedips_node *node = kzalloc(sizeof(*node), struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
GFP_KERNEL);
if (unlikely(!node)) if (unlikely(!node))
return -ENOMEM; return -ENOMEM;
...@@ -234,39 +212,43 @@ horrible_allowedips_insert_v6(struct horrible_allowedips *table, ...@@ -234,39 +212,43 @@ horrible_allowedips_insert_v6(struct horrible_allowedips *table,
} }
static __init void * static __init void *
horrible_allowedips_lookup_v4(struct horrible_allowedips *table, horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
struct in_addr *ip)
{ {
struct horrible_allowedips_node *node; struct horrible_allowedips_node *node;
void *ret = NULL;
hlist_for_each_entry(node, &table->head, table) { hlist_for_each_entry(node, &table->head, table) {
if (node->ip_version != 4) if (node->ip_version == 4 && horrible_match_v4(node, ip))
continue; return node->value;
if (horrible_match_v4(node, ip)) {
ret = node->value;
break;
}
} }
return ret; return NULL;
} }
static __init void * static __init void *
horrible_allowedips_lookup_v6(struct horrible_allowedips *table, horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
struct in6_addr *ip)
{ {
struct horrible_allowedips_node *node; struct horrible_allowedips_node *node;
void *ret = NULL;
hlist_for_each_entry(node, &table->head, table) { hlist_for_each_entry(node, &table->head, table) {
if (node->ip_version != 6) if (node->ip_version == 6 && horrible_match_v6(node, ip))
continue; return node->value;
if (horrible_match_v6(node, ip)) {
ret = node->value;
break;
} }
return NULL;
}
static __init void
horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value)
{
struct horrible_allowedips_node *node;
struct hlist_node *h;
hlist_for_each_entry_safe(node, h, &table->head, table) {
if (node->value != value)
continue;
hlist_del(&node->table);
kfree(node);
} }
return ret;
} }
static __init bool randomized_test(void) static __init bool randomized_test(void)
...@@ -296,6 +278,7 @@ static __init bool randomized_test(void) ...@@ -296,6 +278,7 @@ static __init bool randomized_test(void)
goto free; goto free;
} }
kref_init(&peers[i]->refcount); kref_init(&peers[i]->refcount);
INIT_LIST_HEAD(&peers[i]->allowedips_list);
} }
mutex_lock(&mutex); mutex_lock(&mutex);
...@@ -333,7 +316,7 @@ static __init bool randomized_test(void) ...@@ -333,7 +316,7 @@ static __init bool randomized_test(void)
if (wg_allowedips_insert_v4(&t, if (wg_allowedips_insert_v4(&t,
(struct in_addr *)mutated, (struct in_addr *)mutated,
cidr, peer, &mutex) < 0) { cidr, peer, &mutex) < 0) {
pr_err("allowedips random malloc: FAIL\n"); pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked; goto free_locked;
} }
if (horrible_allowedips_insert_v4(&h, if (horrible_allowedips_insert_v4(&h,
...@@ -396,23 +379,33 @@ static __init bool randomized_test(void) ...@@ -396,23 +379,33 @@ static __init bool randomized_test(void)
print_tree(t.root6, 128); print_tree(t.root6, 128);
} }
for (j = 0;; ++j) {
for (i = 0; i < NUM_QUERIES; ++i) { for (i = 0; i < NUM_QUERIES; ++i) {
prandom_bytes(ip, 4); prandom_bytes(ip, 4);
if (lookup(t.root4, 32, ip) != if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
pr_err("allowedips random self-test: FAIL\n"); pr_err("allowedips random v4 self-test: FAIL\n");
goto free; goto free;
} }
}
for (i = 0; i < NUM_QUERIES; ++i) {
prandom_bytes(ip, 16); prandom_bytes(ip, 16);
if (lookup(t.root6, 128, ip) != if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { pr_err("allowedips random v6 self-test: FAIL\n");
pr_err("allowedips random self-test: FAIL\n");
goto free; goto free;
} }
} }
if (j >= NUM_PEERS)
break;
mutex_lock(&mutex);
wg_allowedips_remove_by_peer(&t, peers[j], &mutex);
mutex_unlock(&mutex);
horrible_allowedips_remove_by_value(&h, peers[j]);
}
if (t.root4 || t.root6) {
pr_err("allowedips random self-test removal: FAIL\n");
goto free;
}
ret = true; ret = true;
free: free:
......
...@@ -430,7 +430,7 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4, ...@@ -430,7 +430,7 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
if (new4) if (new4)
wg->incoming_port = ntohs(inet_sk(new4)->inet_sport); wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
mutex_unlock(&wg->socket_update_lock); mutex_unlock(&wg->socket_update_lock);
synchronize_rcu(); synchronize_net();
sock_free(old4); sock_free(old4);
sock_free(old6); sock_free(old6);
} }
...@@ -363,6 +363,7 @@ ip1 -6 rule add table main suppress_prefixlength 0 ...@@ -363,6 +363,7 @@ ip1 -6 rule add table main suppress_prefixlength 0
ip1 -4 route add default dev wg0 table 51820 ip1 -4 route add default dev wg0 table 51820
ip1 -4 rule add not fwmark 51820 table 51820 ip1 -4 rule add not fwmark 51820 table 51820
ip1 -4 rule add table main suppress_prefixlength 0 ip1 -4 rule add table main suppress_prefixlength 0
n1 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/vethc/rp_filter'
# Flood the pings instead of sending just one, to trigger routing table reference counting bugs. # Flood the pings instead of sending just one, to trigger routing table reference counting bugs.
n1 ping -W 1 -c 100 -f 192.168.99.7 n1 ping -W 1 -c 100 -f 192.168.99.7
n1 ping -W 1 -c 100 -f abab::1111 n1 ping -W 1 -c 100 -f abab::1111
......
...@@ -19,7 +19,6 @@ CONFIG_NETFILTER_XTABLES=y ...@@ -19,7 +19,6 @@ CONFIG_NETFILTER_XTABLES=y
CONFIG_NETFILTER_XT_NAT=y CONFIG_NETFILTER_XT_NAT=y
CONFIG_NETFILTER_XT_MATCH_LENGTH=y CONFIG_NETFILTER_XT_MATCH_LENGTH=y
CONFIG_NETFILTER_XT_MARK=y CONFIG_NETFILTER_XT_MARK=y
CONFIG_NF_CONNTRACK_IPV4=y
CONFIG_NF_NAT_IPV4=y CONFIG_NF_NAT_IPV4=y
CONFIG_IP_NF_IPTABLES=y CONFIG_IP_NF_IPTABLES=y
CONFIG_IP_NF_FILTER=y CONFIG_IP_NF_FILTER=y
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment