Commit 7dcade39 authored by David S. Miller's avatar David S. Miller

Merge branch 'net_get_random_once'

Hannes Frederic Sowa says:

====================
This series implements support for delaying the initialization of secret
keys, e.g. used for hashing, for as long as possible. This functionality
is implemented by a new macro, net_get_random_bytes.

I already used it to protect the socket hashes, the syncookie secret
(most important) and the tcp_fastopen secrets.

Changelog:
v2) Use static_keys in net_get_random_once to have as minimal impact to
    the fast-path as possible.
v3) added patch "static_key: WARN on usage before jump_label_init was called":
    Patch "x86/jump_label: expect default_nop if static_key gets enabled
    on boot-up" relaxes the checks for using static_key primitives before
    jump_label_init. So tighten them first.
v4) Update changelog on the patch "static_key: WARN on usage before
    jump_label_init was called"

Included patches:
 ipv4: split inet_ehashfn to hash functions per compilation unit
 ipv6: split inet6_ehashfn to hash functions per compilation unit
 static_key: WARN on usage before jump_label_init was called
 x86/jump_label: expect default_nop if static_key gets enabled on boot-up
 net: introduce new macro net_get_random_once
 inet: split syncookie keys for ipv4 and ipv6 and initialize with net_get_random_once
 inet: convert inet_ehash_secret and ipv6_hash_secret to net_get_random_once
 tcp: switch tcp_fastopen key generation to net_get_random_once
 net: switch net_secret key generation to net_get_random_once
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 53481da3 e34c9a69
...@@ -42,15 +42,27 @@ static void __jump_label_transform(struct jump_entry *entry, ...@@ -42,15 +42,27 @@ static void __jump_label_transform(struct jump_entry *entry,
int init) int init)
{ {
union jump_code_union code; union jump_code_union code;
const unsigned char default_nop[] = { STATIC_KEY_INIT_NOP };
const unsigned char *ideal_nop = ideal_nops[NOP_ATOMIC5]; const unsigned char *ideal_nop = ideal_nops[NOP_ATOMIC5];
if (type == JUMP_LABEL_ENABLE) { if (type == JUMP_LABEL_ENABLE) {
if (init) {
/*
* Jump label is enabled for the first time.
* So we expect a default_nop...
*/
if (unlikely(memcmp((void *)entry->code, default_nop, 5)
!= 0))
bug_at((void *)entry->code, __LINE__);
} else {
/* /*
* We are enabling this jump label. If it is not a nop * ...otherwise expect an ideal_nop. Otherwise
* then something must have gone wrong. * something went horribly wrong.
*/ */
if (unlikely(memcmp((void *)entry->code, ideal_nop, 5) != 0)) if (unlikely(memcmp((void *)entry->code, ideal_nop, 5)
!= 0))
bug_at((void *)entry->code, __LINE__); bug_at((void *)entry->code, __LINE__);
}
code.jump = 0xe9; code.jump = 0xe9;
code.offset = entry->target - code.offset = entry->target -
...@@ -63,7 +75,6 @@ static void __jump_label_transform(struct jump_entry *entry, ...@@ -63,7 +75,6 @@ static void __jump_label_transform(struct jump_entry *entry,
* are converting the default nop to the ideal nop. * are converting the default nop to the ideal nop.
*/ */
if (init) { if (init) {
const unsigned char default_nop[] = { STATIC_KEY_INIT_NOP };
if (unlikely(memcmp((void *)entry->code, default_nop, 5) != 0)) if (unlikely(memcmp((void *)entry->code, default_nop, 5) != 0))
bug_at((void *)entry->code, __LINE__); bug_at((void *)entry->code, __LINE__);
} else { } else {
......
...@@ -48,6 +48,13 @@ ...@@ -48,6 +48,13 @@
#include <linux/types.h> #include <linux/types.h>
#include <linux/compiler.h> #include <linux/compiler.h>
#include <linux/bug.h>
extern bool static_key_initialized;
#define STATIC_KEY_CHECK_USE() WARN(!static_key_initialized, \
"%s used before call to jump_label_init", \
__func__)
#if defined(CC_HAVE_ASM_GOTO) && defined(CONFIG_JUMP_LABEL) #if defined(CC_HAVE_ASM_GOTO) && defined(CONFIG_JUMP_LABEL)
...@@ -128,6 +135,7 @@ struct static_key { ...@@ -128,6 +135,7 @@ struct static_key {
static __always_inline void jump_label_init(void) static __always_inline void jump_label_init(void)
{ {
static_key_initialized = true;
} }
static __always_inline bool static_key_false(struct static_key *key) static __always_inline bool static_key_false(struct static_key *key)
...@@ -146,11 +154,13 @@ static __always_inline bool static_key_true(struct static_key *key) ...@@ -146,11 +154,13 @@ static __always_inline bool static_key_true(struct static_key *key)
static inline void static_key_slow_inc(struct static_key *key) static inline void static_key_slow_inc(struct static_key *key)
{ {
STATIC_KEY_CHECK_USE();
atomic_inc(&key->enabled); atomic_inc(&key->enabled);
} }
static inline void static_key_slow_dec(struct static_key *key) static inline void static_key_slow_dec(struct static_key *key)
{ {
STATIC_KEY_CHECK_USE();
atomic_dec(&key->enabled); atomic_dec(&key->enabled);
} }
......
...@@ -23,12 +23,14 @@ struct static_key_deferred { ...@@ -23,12 +23,14 @@ struct static_key_deferred {
}; };
static inline void static_key_slow_dec_deferred(struct static_key_deferred *key) static inline void static_key_slow_dec_deferred(struct static_key_deferred *key)
{ {
STATIC_KEY_CHECK_USE();
static_key_slow_dec(&key->key); static_key_slow_dec(&key->key);
} }
static inline void static inline void
jump_label_rate_limit(struct static_key_deferred *key, jump_label_rate_limit(struct static_key_deferred *key,
unsigned long rl) unsigned long rl)
{ {
STATIC_KEY_CHECK_USE();
} }
#endif /* HAVE_JUMP_LABEL */ #endif /* HAVE_JUMP_LABEL */
#endif /* _LINUX_JUMP_LABEL_RATELIMIT_H */ #endif /* _LINUX_JUMP_LABEL_RATELIMIT_H */
...@@ -239,6 +239,31 @@ do { \ ...@@ -239,6 +239,31 @@ do { \
#define net_random() prandom_u32() #define net_random() prandom_u32()
#define net_srandom(seed) prandom_seed((__force u32)(seed)) #define net_srandom(seed) prandom_seed((__force u32)(seed))
bool __net_get_random_once(void *buf, int nbytes, bool *done,
struct static_key *done_key);
#ifdef HAVE_JUMP_LABEL
#define ___NET_RANDOM_STATIC_KEY_INIT ((struct static_key) \
{ .enabled = ATOMIC_INIT(0), .entries = (void *)1 })
#else /* !HAVE_JUMP_LABEL */
#define ___NET_RANDOM_STATIC_KEY_INIT STATIC_KEY_INIT_FALSE
#endif /* HAVE_JUMP_LABEL */
/* BE CAREFUL: this function is not interrupt safe */
#define net_get_random_once(buf, nbytes) \
({ \
bool ___ret = false; \
static bool ___done = false; \
static struct static_key ___done_key = \
___NET_RANDOM_STATIC_KEY_INIT; \
if (!static_key_true(&___done_key)) \
___ret = __net_get_random_once(buf, \
nbytes, \
&___done, \
&___done_key); \
___ret; \
})
int kernel_sendmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, int kernel_sendmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec,
size_t num, size_t len); size_t num, size_t len);
int kernel_recvmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, int kernel_recvmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec,
......
...@@ -28,28 +28,14 @@ ...@@ -28,28 +28,14 @@
struct inet_hashinfo; struct inet_hashinfo;
static inline unsigned int inet6_ehashfn(struct net *net, static inline unsigned int __inet6_ehashfn(const u32 lhash,
const struct in6_addr *laddr, const u16 lport, const u16 lport,
const struct in6_addr *faddr, const __be16 fport) const u32 fhash,
const __be16 fport,
const u32 initval)
{ {
u32 ports = (((u32)lport) << 16) | (__force u32)fport; const u32 ports = (((u32)lport) << 16) | (__force u32)fport;
return jhash_3words(lhash, fhash, ports, initval);
return jhash_3words((__force u32)laddr->s6_addr32[3],
ipv6_addr_jhash(faddr),
ports,
inet_ehash_secret + net_hash_mix(net));
}
static inline int inet6_sk_ehashfn(const struct sock *sk)
{
const struct inet_sock *inet = inet_sk(sk);
const struct in6_addr *laddr = &sk->sk_v6_rcv_saddr;
const struct in6_addr *faddr = &sk->sk_v6_daddr;
const __u16 lport = inet->inet_num;
const __be16 fport = inet->inet_dport;
struct net *net = sock_net(sk);
return inet6_ehashfn(net, laddr, lport, faddr, fport);
} }
int __inet6_hash(struct sock *sk, struct inet_timewait_sock *twp); int __inet6_hash(struct sock *sk, struct inet_timewait_sock *twp);
......
...@@ -204,30 +204,16 @@ static inline void inet_sk_copy_descendant(struct sock *sk_to, ...@@ -204,30 +204,16 @@ static inline void inet_sk_copy_descendant(struct sock *sk_to,
int inet_sk_rebuild_header(struct sock *sk); int inet_sk_rebuild_header(struct sock *sk);
extern u32 inet_ehash_secret; static inline unsigned int __inet_ehashfn(const __be32 laddr,
extern u32 ipv6_hash_secret; const __u16 lport,
void build_ehash_secret(void); const __be32 faddr,
const __be16 fport,
static inline unsigned int inet_ehashfn(struct net *net, u32 initval)
const __be32 laddr, const __u16 lport,
const __be32 faddr, const __be16 fport)
{ {
return jhash_3words((__force __u32) laddr, return jhash_3words((__force __u32) laddr,
(__force __u32) faddr, (__force __u32) faddr,
((__u32) lport) << 16 | (__force __u32)fport, ((__u32) lport) << 16 | (__force __u32)fport,
inet_ehash_secret + net_hash_mix(net)); initval);
}
static inline int inet_sk_ehashfn(const struct sock *sk)
{
const struct inet_sock *inet = inet_sk(sk);
const __be32 laddr = inet->inet_rcv_saddr;
const __u16 lport = inet->inet_num;
const __be32 faddr = inet->inet_daddr;
const __be16 fport = inet->inet_dport;
struct net *net = sock_net(sk);
return inet_ehashfn(net, laddr, lport, faddr, fport);
} }
static inline struct request_sock *inet_reqsk_alloc(struct request_sock_ops *ops) static inline struct request_sock *inet_reqsk_alloc(struct request_sock_ops *ops)
......
...@@ -539,14 +539,14 @@ static inline u32 ipv6_addr_hash(const struct in6_addr *a) ...@@ -539,14 +539,14 @@ static inline u32 ipv6_addr_hash(const struct in6_addr *a)
} }
/* more secured version of ipv6_addr_hash() */ /* more secured version of ipv6_addr_hash() */
static inline u32 ipv6_addr_jhash(const struct in6_addr *a) static inline u32 __ipv6_addr_jhash(const struct in6_addr *a, const u32 initval)
{ {
u32 v = (__force u32)a->s6_addr32[0] ^ (__force u32)a->s6_addr32[1]; u32 v = (__force u32)a->s6_addr32[0] ^ (__force u32)a->s6_addr32[1];
return jhash_3words(v, return jhash_3words(v,
(__force u32)a->s6_addr32[2], (__force u32)a->s6_addr32[2],
(__force u32)a->s6_addr32[3], (__force u32)a->s6_addr32[3],
ipv6_hash_secret); initval);
} }
static inline bool ipv6_addr_loopback(const struct in6_addr *a) static inline bool ipv6_addr_loopback(const struct in6_addr *a)
......
...@@ -475,7 +475,6 @@ int tcp_send_rcvq(struct sock *sk, struct msghdr *msg, size_t size); ...@@ -475,7 +475,6 @@ int tcp_send_rcvq(struct sock *sk, struct msghdr *msg, size_t size);
void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb); void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb);
/* From syncookies.c */ /* From syncookies.c */
extern __u32 syncookie_secret[2][16-4+SHA_DIGEST_WORDS];
int __cookie_v4_check(const struct iphdr *iph, const struct tcphdr *th, int __cookie_v4_check(const struct iphdr *iph, const struct tcphdr *th,
u32 cookie); u32 cookie);
struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb, struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb,
...@@ -1323,7 +1322,7 @@ extern struct tcp_fastopen_context __rcu *tcp_fastopen_ctx; ...@@ -1323,7 +1322,7 @@ extern struct tcp_fastopen_context __rcu *tcp_fastopen_ctx;
int tcp_fastopen_reset_cipher(void *key, unsigned int len); int tcp_fastopen_reset_cipher(void *key, unsigned int len);
void tcp_fastopen_cookie_gen(__be32 src, __be32 dst, void tcp_fastopen_cookie_gen(__be32 src, __be32 dst,
struct tcp_fastopen_cookie *foc); struct tcp_fastopen_cookie *foc);
void tcp_fastopen_init_key_once(bool publish);
#define TCP_FASTOPEN_KEY_LENGTH 16 #define TCP_FASTOPEN_KEY_LENGTH 16
/* Fastopen key context */ /* Fastopen key context */
......
...@@ -135,6 +135,13 @@ static char *static_command_line; ...@@ -135,6 +135,13 @@ static char *static_command_line;
static char *execute_command; static char *execute_command;
static char *ramdisk_execute_command; static char *ramdisk_execute_command;
/*
* Used to generate warnings if static_key manipulation functions are used
* before jump_label_init is called.
*/
bool static_key_initialized __read_mostly = false;
EXPORT_SYMBOL_GPL(static_key_initialized);
/* /*
* If set, this is an indication to the drivers that reset the underlying * If set, this is an indication to the drivers that reset the underlying
* device before going ahead with the initialization otherwise driver might * device before going ahead with the initialization otherwise driver might
......
...@@ -58,6 +58,7 @@ static void jump_label_update(struct static_key *key, int enable); ...@@ -58,6 +58,7 @@ static void jump_label_update(struct static_key *key, int enable);
void static_key_slow_inc(struct static_key *key) void static_key_slow_inc(struct static_key *key)
{ {
STATIC_KEY_CHECK_USE();
if (atomic_inc_not_zero(&key->enabled)) if (atomic_inc_not_zero(&key->enabled))
return; return;
...@@ -103,12 +104,14 @@ static void jump_label_update_timeout(struct work_struct *work) ...@@ -103,12 +104,14 @@ static void jump_label_update_timeout(struct work_struct *work)
void static_key_slow_dec(struct static_key *key) void static_key_slow_dec(struct static_key *key)
{ {
STATIC_KEY_CHECK_USE();
__static_key_slow_dec(key, 0, NULL); __static_key_slow_dec(key, 0, NULL);
} }
EXPORT_SYMBOL_GPL(static_key_slow_dec); EXPORT_SYMBOL_GPL(static_key_slow_dec);
void static_key_slow_dec_deferred(struct static_key_deferred *key) void static_key_slow_dec_deferred(struct static_key_deferred *key)
{ {
STATIC_KEY_CHECK_USE();
__static_key_slow_dec(&key->key, key->timeout, &key->work); __static_key_slow_dec(&key->key, key->timeout, &key->work);
} }
EXPORT_SYMBOL_GPL(static_key_slow_dec_deferred); EXPORT_SYMBOL_GPL(static_key_slow_dec_deferred);
...@@ -116,6 +119,7 @@ EXPORT_SYMBOL_GPL(static_key_slow_dec_deferred); ...@@ -116,6 +119,7 @@ EXPORT_SYMBOL_GPL(static_key_slow_dec_deferred);
void jump_label_rate_limit(struct static_key_deferred *key, void jump_label_rate_limit(struct static_key_deferred *key,
unsigned long rl) unsigned long rl)
{ {
STATIC_KEY_CHECK_USE();
key->timeout = rl; key->timeout = rl;
INIT_DELAYED_WORK(&key->work, jump_label_update_timeout); INIT_DELAYED_WORK(&key->work, jump_label_update_timeout);
} }
...@@ -212,6 +216,7 @@ void __init jump_label_init(void) ...@@ -212,6 +216,7 @@ void __init jump_label_init(void)
key->next = NULL; key->next = NULL;
#endif #endif
} }
static_key_initialized = true;
jump_label_unlock(); jump_label_unlock();
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <linux/hrtimer.h> #include <linux/hrtimer.h>
#include <linux/ktime.h> #include <linux/ktime.h>
#include <linux/string.h> #include <linux/string.h>
#include <linux/net.h>
#include <net/secure_seq.h> #include <net/secure_seq.h>
...@@ -16,18 +17,7 @@ static u32 net_secret[NET_SECRET_SIZE] ____cacheline_aligned; ...@@ -16,18 +17,7 @@ static u32 net_secret[NET_SECRET_SIZE] ____cacheline_aligned;
static void net_secret_init(void) static void net_secret_init(void)
{ {
u32 tmp; net_get_random_once(net_secret, sizeof(net_secret));
int i;
if (likely(net_secret[0]))
return;
for (i = NET_SECRET_SIZE; i > 0;) {
do {
get_random_bytes(&tmp, sizeof(tmp));
} while (!tmp);
cmpxchg(&net_secret[--i], 0, tmp);
}
} }
#ifdef CONFIG_INET #ifdef CONFIG_INET
......
...@@ -338,3 +338,51 @@ void inet_proto_csum_replace16(__sum16 *sum, struct sk_buff *skb, ...@@ -338,3 +338,51 @@ void inet_proto_csum_replace16(__sum16 *sum, struct sk_buff *skb,
csum_unfold(*sum))); csum_unfold(*sum)));
} }
EXPORT_SYMBOL(inet_proto_csum_replace16); EXPORT_SYMBOL(inet_proto_csum_replace16);
struct __net_random_once_work {
struct work_struct work;
struct static_key *key;
};
static void __net_random_once_deferred(struct work_struct *w)
{
struct __net_random_once_work *work =
container_of(w, struct __net_random_once_work, work);
if (!static_key_enabled(work->key))
static_key_slow_inc(work->key);
kfree(work);
}
static void __net_random_once_disable_jump(struct static_key *key)
{
struct __net_random_once_work *w;
w = kmalloc(sizeof(*w), GFP_ATOMIC);
if (!w)
return;
INIT_WORK(&w->work, __net_random_once_deferred);
w->key = key;
schedule_work(&w->work);
}
bool __net_get_random_once(void *buf, int nbytes, bool *done,
struct static_key *done_key)
{
static DEFINE_SPINLOCK(lock);
spin_lock_bh(&lock);
if (*done) {
spin_unlock_bh(&lock);
return false;
}
get_random_bytes(buf, nbytes);
*done = true;
spin_unlock_bh(&lock);
__net_random_once_disable_jump(done_key);
return true;
}
EXPORT_SYMBOL(__net_get_random_once);
...@@ -245,29 +245,6 @@ int inet_listen(struct socket *sock, int backlog) ...@@ -245,29 +245,6 @@ int inet_listen(struct socket *sock, int backlog)
} }
EXPORT_SYMBOL(inet_listen); EXPORT_SYMBOL(inet_listen);
u32 inet_ehash_secret __read_mostly;
EXPORT_SYMBOL(inet_ehash_secret);
u32 ipv6_hash_secret __read_mostly;
EXPORT_SYMBOL(ipv6_hash_secret);
/*
* inet_ehash_secret must be set exactly once, and to a non nul value
* ipv6_hash_secret must be set exactly once.
*/
void build_ehash_secret(void)
{
u32 rnd;
do {
get_random_bytes(&rnd, sizeof(rnd));
} while (rnd == 0);
if (cmpxchg(&inet_ehash_secret, 0, rnd) == 0)
get_random_bytes(&ipv6_hash_secret, sizeof(ipv6_hash_secret));
}
EXPORT_SYMBOL(build_ehash_secret);
/* /*
* Create an inet socket. * Create an inet socket.
*/ */
...@@ -284,10 +261,6 @@ static int inet_create(struct net *net, struct socket *sock, int protocol, ...@@ -284,10 +261,6 @@ static int inet_create(struct net *net, struct socket *sock, int protocol,
int try_loading_module = 0; int try_loading_module = 0;
int err; int err;
if (unlikely(!inet_ehash_secret))
if (sock->type != SOCK_RAW && sock->type != SOCK_DGRAM)
build_ehash_secret();
sock->state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED;
/* Look for the requested type/protocol pair. */ /* Look for the requested type/protocol pair. */
......
...@@ -24,6 +24,31 @@ ...@@ -24,6 +24,31 @@
#include <net/secure_seq.h> #include <net/secure_seq.h>
#include <net/ip.h> #include <net/ip.h>
static unsigned int inet_ehashfn(struct net *net, const __be32 laddr,
const __u16 lport, const __be32 faddr,
const __be16 fport)
{
static u32 inet_ehash_secret __read_mostly;
net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret));
return __inet_ehashfn(laddr, lport, faddr, fport,
inet_ehash_secret + net_hash_mix(net));
}
static unsigned int inet_sk_ehashfn(const struct sock *sk)
{
const struct inet_sock *inet = inet_sk(sk);
const __be32 laddr = inet->inet_rcv_saddr;
const __u16 lport = inet->inet_num;
const __be32 faddr = inet->inet_daddr;
const __be16 fport = inet->inet_dport;
struct net *net = sock_net(sk);
return inet_ehashfn(net, laddr, lport, faddr, fport);
}
/* /*
* Allocate and initialize a new local port bind bucket. * Allocate and initialize a new local port bind bucket.
* The bindhash mutex for snum's hash chain must be held here. * The bindhash mutex for snum's hash chain must be held here.
......
...@@ -25,15 +25,7 @@ ...@@ -25,15 +25,7 @@
extern int sysctl_tcp_syncookies; extern int sysctl_tcp_syncookies;
__u32 syncookie_secret[2][16-4+SHA_DIGEST_WORDS]; static u32 syncookie_secret[2][16-4+SHA_DIGEST_WORDS];
EXPORT_SYMBOL(syncookie_secret);
static __init int init_syncookies(void)
{
get_random_bytes(syncookie_secret, sizeof(syncookie_secret));
return 0;
}
__initcall(init_syncookies);
#define COOKIEBITS 24 /* Upper bits store count */ #define COOKIEBITS 24 /* Upper bits store count */
#define COOKIEMASK (((__u32)1 << COOKIEBITS) - 1) #define COOKIEMASK (((__u32)1 << COOKIEBITS) - 1)
...@@ -44,8 +36,11 @@ static DEFINE_PER_CPU(__u32 [16 + 5 + SHA_WORKSPACE_WORDS], ...@@ -44,8 +36,11 @@ static DEFINE_PER_CPU(__u32 [16 + 5 + SHA_WORKSPACE_WORDS],
static u32 cookie_hash(__be32 saddr, __be32 daddr, __be16 sport, __be16 dport, static u32 cookie_hash(__be32 saddr, __be32 daddr, __be16 sport, __be16 dport,
u32 count, int c) u32 count, int c)
{ {
__u32 *tmp = __get_cpu_var(ipv4_cookie_scratch); __u32 *tmp;
net_get_random_once(syncookie_secret, sizeof(syncookie_secret));
tmp = __get_cpu_var(ipv4_cookie_scratch);
memcpy(tmp + 4, syncookie_secret[c], sizeof(syncookie_secret[c])); memcpy(tmp + 4, syncookie_secret[c], sizeof(syncookie_secret[c]));
tmp[0] = (__force u32)saddr; tmp[0] = (__force u32)saddr;
tmp[1] = (__force u32)daddr; tmp[1] = (__force u32)daddr;
......
...@@ -274,6 +274,11 @@ static int proc_tcp_fastopen_key(struct ctl_table *ctl, int write, ...@@ -274,6 +274,11 @@ static int proc_tcp_fastopen_key(struct ctl_table *ctl, int write,
ret = -EINVAL; ret = -EINVAL;
goto bad_key; goto bad_key;
} }
/* Generate a dummy secret but don't publish it. This
* is needed so we don't regenerate a new key on the
* first invocation of tcp_fastopen_cookie_gen
*/
tcp_fastopen_init_key_once(false);
tcp_fastopen_reset_cipher(user_key, TCP_FASTOPEN_KEY_LENGTH); tcp_fastopen_reset_cipher(user_key, TCP_FASTOPEN_KEY_LENGTH);
} }
......
...@@ -14,6 +14,20 @@ struct tcp_fastopen_context __rcu *tcp_fastopen_ctx; ...@@ -14,6 +14,20 @@ struct tcp_fastopen_context __rcu *tcp_fastopen_ctx;
static DEFINE_SPINLOCK(tcp_fastopen_ctx_lock); static DEFINE_SPINLOCK(tcp_fastopen_ctx_lock);
void tcp_fastopen_init_key_once(bool publish)
{
static u8 key[TCP_FASTOPEN_KEY_LENGTH];
/* tcp_fastopen_reset_cipher publishes the new context
* atomically, so we allow this race happening here.
*
* All call sites of tcp_fastopen_cookie_gen also check
* for a valid cookie, so this is an acceptable risk.
*/
if (net_get_random_once(key, sizeof(key)) && publish)
tcp_fastopen_reset_cipher(key, sizeof(key));
}
static void tcp_fastopen_ctx_free(struct rcu_head *head) static void tcp_fastopen_ctx_free(struct rcu_head *head)
{ {
struct tcp_fastopen_context *ctx = struct tcp_fastopen_context *ctx =
...@@ -70,6 +84,8 @@ void tcp_fastopen_cookie_gen(__be32 src, __be32 dst, ...@@ -70,6 +84,8 @@ void tcp_fastopen_cookie_gen(__be32 src, __be32 dst,
__be32 path[4] = { src, dst, 0, 0 }; __be32 path[4] = { src, dst, 0, 0 };
struct tcp_fastopen_context *ctx; struct tcp_fastopen_context *ctx;
tcp_fastopen_init_key_once(true);
rcu_read_lock(); rcu_read_lock();
ctx = rcu_dereference(tcp_fastopen_ctx); ctx = rcu_dereference(tcp_fastopen_ctx);
if (ctx) { if (ctx) {
...@@ -78,14 +94,3 @@ void tcp_fastopen_cookie_gen(__be32 src, __be32 dst, ...@@ -78,14 +94,3 @@ void tcp_fastopen_cookie_gen(__be32 src, __be32 dst,
} }
rcu_read_unlock(); rcu_read_unlock();
} }
static int __init tcp_fastopen_init(void)
{
__u8 key[TCP_FASTOPEN_KEY_LENGTH];
get_random_bytes(key, sizeof(key));
tcp_fastopen_reset_cipher(key, sizeof(key));
return 0;
}
late_initcall(tcp_fastopen_init);
...@@ -407,6 +407,18 @@ static inline int compute_score2(struct sock *sk, struct net *net, ...@@ -407,6 +407,18 @@ static inline int compute_score2(struct sock *sk, struct net *net,
return score; return score;
} }
static unsigned int udp_ehashfn(struct net *net, const __be32 laddr,
const __u16 lport, const __be32 faddr,
const __be16 fport)
{
static u32 udp_ehash_secret __read_mostly;
net_get_random_once(&udp_ehash_secret, sizeof(udp_ehash_secret));
return __inet_ehashfn(laddr, lport, faddr, fport,
udp_ehash_secret + net_hash_mix(net));
}
/* called with read_rcu_lock() */ /* called with read_rcu_lock() */
static struct sock *udp4_lib_lookup2(struct net *net, static struct sock *udp4_lib_lookup2(struct net *net,
...@@ -430,7 +442,7 @@ static struct sock *udp4_lib_lookup2(struct net *net, ...@@ -430,7 +442,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
badness = score; badness = score;
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
hash = inet_ehashfn(net, daddr, hnum, hash = udp_ehashfn(net, daddr, hnum,
saddr, sport); saddr, sport);
matches = 1; matches = 1;
} }
...@@ -511,7 +523,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -511,7 +523,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
badness = score; badness = score;
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
hash = inet_ehashfn(net, daddr, hnum, hash = udp_ehashfn(net, daddr, hnum,
saddr, sport); saddr, sport);
matches = 1; matches = 1;
} }
......
...@@ -110,11 +110,6 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol, ...@@ -110,11 +110,6 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol,
int try_loading_module = 0; int try_loading_module = 0;
int err; int err;
if (sock->type != SOCK_RAW &&
sock->type != SOCK_DGRAM &&
!inet_ehash_secret)
build_ehash_secret();
/* Look for the requested type/protocol pair. */ /* Look for the requested type/protocol pair. */
lookup_protocol: lookup_protocol:
err = -ESOCKTNOSUPPORT; err = -ESOCKTNOSUPPORT;
......
...@@ -23,6 +23,39 @@ ...@@ -23,6 +23,39 @@
#include <net/secure_seq.h> #include <net/secure_seq.h>
#include <net/ip.h> #include <net/ip.h>
static unsigned int inet6_ehashfn(struct net *net,
const struct in6_addr *laddr,
const u16 lport,
const struct in6_addr *faddr,
const __be16 fport)
{
static u32 inet6_ehash_secret __read_mostly;
static u32 ipv6_hash_secret __read_mostly;
u32 lhash, fhash;
net_get_random_once(&inet6_ehash_secret, sizeof(inet6_ehash_secret));
net_get_random_once(&ipv6_hash_secret, sizeof(ipv6_hash_secret));
lhash = (__force u32)laddr->s6_addr32[3];
fhash = __ipv6_addr_jhash(faddr, ipv6_hash_secret);
return __inet6_ehashfn(lhash, lport, fhash, fport,
inet6_ehash_secret + net_hash_mix(net));
}
static int inet6_sk_ehashfn(const struct sock *sk)
{
const struct inet_sock *inet = inet_sk(sk);
const struct in6_addr *laddr = &sk->sk_v6_rcv_saddr;
const struct in6_addr *faddr = &sk->sk_v6_daddr;
const __u16 lport = inet->inet_num;
const __be16 fport = inet->inet_dport;
struct net *net = sock_net(sk);
return inet6_ehashfn(net, laddr, lport, faddr, fport);
}
int __inet6_hash(struct sock *sk, struct inet_timewait_sock *tw) int __inet6_hash(struct sock *sk, struct inet_timewait_sock *tw)
{ {
struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#define COOKIEBITS 24 /* Upper bits store count */ #define COOKIEBITS 24 /* Upper bits store count */
#define COOKIEMASK (((__u32)1 << COOKIEBITS) - 1) #define COOKIEMASK (((__u32)1 << COOKIEBITS) - 1)
static u32 syncookie6_secret[2][16-4+SHA_DIGEST_WORDS];
/* RFC 2460, Section 8.3: /* RFC 2460, Section 8.3:
* [ipv6 tcp] MSS must be computed as the maximum packet size minus 60 [..] * [ipv6 tcp] MSS must be computed as the maximum packet size minus 60 [..]
* *
...@@ -61,14 +63,18 @@ static DEFINE_PER_CPU(__u32 [16 + 5 + SHA_WORKSPACE_WORDS], ...@@ -61,14 +63,18 @@ static DEFINE_PER_CPU(__u32 [16 + 5 + SHA_WORKSPACE_WORDS],
static u32 cookie_hash(const struct in6_addr *saddr, const struct in6_addr *daddr, static u32 cookie_hash(const struct in6_addr *saddr, const struct in6_addr *daddr,
__be16 sport, __be16 dport, u32 count, int c) __be16 sport, __be16 dport, u32 count, int c)
{ {
__u32 *tmp = __get_cpu_var(ipv6_cookie_scratch); __u32 *tmp;
net_get_random_once(syncookie6_secret, sizeof(syncookie6_secret));
tmp = __get_cpu_var(ipv6_cookie_scratch);
/* /*
* we have 320 bits of information to hash, copy in the remaining * we have 320 bits of information to hash, copy in the remaining
* 192 bits required for sha_transform, from the syncookie_secret * 192 bits required for sha_transform, from the syncookie6_secret
* and overwrite the digest with the secret * and overwrite the digest with the secret
*/ */
memcpy(tmp + 10, syncookie_secret[c], 44); memcpy(tmp + 10, syncookie6_secret[c], 44);
memcpy(tmp, saddr, 16); memcpy(tmp, saddr, 16);
memcpy(tmp + 4, daddr, 16); memcpy(tmp + 4, daddr, 16);
tmp[8] = ((__force u32)sport << 16) + (__force u32)dport; tmp[8] = ((__force u32)sport << 16) + (__force u32)dport;
......
...@@ -53,6 +53,29 @@ ...@@ -53,6 +53,29 @@
#include <trace/events/skb.h> #include <trace/events/skb.h>
#include "udp_impl.h" #include "udp_impl.h"
static unsigned int udp6_ehashfn(struct net *net,
const struct in6_addr *laddr,
const u16 lport,
const struct in6_addr *faddr,
const __be16 fport)
{
static u32 udp6_ehash_secret __read_mostly;
static u32 udp_ipv6_hash_secret __read_mostly;
u32 lhash, fhash;
net_get_random_once(&udp6_ehash_secret,
sizeof(udp6_ehash_secret));
net_get_random_once(&udp_ipv6_hash_secret,
sizeof(udp_ipv6_hash_secret));
lhash = (__force u32)laddr->s6_addr32[3];
fhash = __ipv6_addr_jhash(faddr, udp_ipv6_hash_secret);
return __inet6_ehashfn(lhash, lport, fhash, fport,
udp_ipv6_hash_secret + net_hash_mix(net));
}
int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2) int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
{ {
const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2); const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
...@@ -214,7 +237,7 @@ static struct sock *udp6_lib_lookup2(struct net *net, ...@@ -214,7 +237,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
badness = score; badness = score;
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
hash = inet6_ehashfn(net, daddr, hnum, hash = udp6_ehashfn(net, daddr, hnum,
saddr, sport); saddr, sport);
matches = 1; matches = 1;
} else if (score == SCORE2_MAX) } else if (score == SCORE2_MAX)
...@@ -295,7 +318,7 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -295,7 +318,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
badness = score; badness = score;
reuseport = sk->sk_reuseport; reuseport = sk->sk_reuseport;
if (reuseport) { if (reuseport) {
hash = inet6_ehashfn(net, daddr, hnum, hash = udp6_ehashfn(net, daddr, hnum,
saddr, sport); saddr, sport);
matches = 1; matches = 1;
} }
......
...@@ -51,10 +51,16 @@ static struct kmem_cache *rds_conn_slab; ...@@ -51,10 +51,16 @@ static struct kmem_cache *rds_conn_slab;
static struct hlist_head *rds_conn_bucket(__be32 laddr, __be32 faddr) static struct hlist_head *rds_conn_bucket(__be32 laddr, __be32 faddr)
{ {
static u32 rds_hash_secret __read_mostly;
unsigned long hash;
net_get_random_once(&rds_hash_secret, sizeof(rds_hash_secret));
/* Pass NULL, don't need struct net for hash */ /* Pass NULL, don't need struct net for hash */
unsigned long hash = inet_ehashfn(NULL, hash = __inet_ehashfn(be32_to_cpu(laddr), 0,
be32_to_cpu(laddr), 0, be32_to_cpu(faddr), 0,
be32_to_cpu(faddr), 0); rds_hash_secret);
return &rds_conn_hash[hash & RDS_CONNECTION_HASH_MASK]; return &rds_conn_hash[hash & RDS_CONNECTION_HASH_MASK];
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment