Commit 32c1da70 authored by Stephen Hemminger's avatar Stephen Hemminger Committed by David S. Miller

[UDP]: Randomize port selection.

This patch causes UDP port allocation to be randomized like TCP.
The earlier code would always choose same port (ie first empty list).
Signed-off-by: default avatarStephen Hemminger <shemminger@linux-foundation.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 356f89e1
...@@ -113,9 +113,8 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_statistics) __read_mostly; ...@@ -113,9 +113,8 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_statistics) __read_mostly;
struct hlist_head udp_hash[UDP_HTABLE_SIZE]; struct hlist_head udp_hash[UDP_HTABLE_SIZE];
DEFINE_RWLOCK(udp_hash_lock); DEFINE_RWLOCK(udp_hash_lock);
static int udp_port_rover; static inline int __udp_lib_lport_inuse(__u16 num,
const struct hlist_head udptable[])
static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[])
{ {
struct sock *sk; struct sock *sk;
struct hlist_node *node; struct hlist_node *node;
...@@ -132,11 +131,10 @@ static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[]) ...@@ -132,11 +131,10 @@ static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[])
* @sk: socket struct in question * @sk: socket struct in question
* @snum: port number to look up * @snum: port number to look up
* @udptable: hash list table, must be of UDP_HTABLE_SIZE * @udptable: hash list table, must be of UDP_HTABLE_SIZE
* @port_rover: pointer to record of last unallocated port
* @saddr_comp: AF-dependent comparison of bound local IP addresses * @saddr_comp: AF-dependent comparison of bound local IP addresses
*/ */
int __udp_lib_get_port(struct sock *sk, unsigned short snum, int __udp_lib_get_port(struct sock *sk, unsigned short snum,
struct hlist_head udptable[], int *port_rover, struct hlist_head udptable[],
int (*saddr_comp)(const struct sock *sk1, int (*saddr_comp)(const struct sock *sk1,
const struct sock *sk2 ) ) const struct sock *sk2 ) )
{ {
...@@ -146,49 +144,56 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -146,49 +144,56 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
int error = 1; int error = 1;
write_lock_bh(&udp_hash_lock); write_lock_bh(&udp_hash_lock);
if (snum == 0) {
int best_size_so_far, best, result, i; if (!snum) {
int i;
if (*port_rover > sysctl_local_port_range[1] || int low = sysctl_local_port_range[0];
*port_rover < sysctl_local_port_range[0]) int high = sysctl_local_port_range[1];
*port_rover = sysctl_local_port_range[0]; unsigned rover, best, best_size_so_far;
best_size_so_far = 32767;
best = result = *port_rover; best_size_so_far = UINT_MAX;
for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { best = rover = net_random() % (high - low) + low;
int size;
/* 1st pass: look for empty (or shortest) hash chain */
head = &udptable[result & (UDP_HTABLE_SIZE - 1)]; for (i = 0; i < UDP_HTABLE_SIZE; i++) {
if (hlist_empty(head)) { int size = 0;
if (result > sysctl_local_port_range[1])
result = sysctl_local_port_range[0] + head = &udptable[rover & (UDP_HTABLE_SIZE - 1)];
((result - sysctl_local_port_range[0]) & if (hlist_empty(head))
(UDP_HTABLE_SIZE - 1));
goto gotit; goto gotit;
}
size = 0;
sk_for_each(sk2, node, head) { sk_for_each(sk2, node, head) {
if (++size >= best_size_so_far) if (++size >= best_size_so_far)
goto next; goto next;
} }
best_size_so_far = size; best_size_so_far = size;
best = result; best = rover;
next: next:
; /* fold back if end of range */
if (++rover > high)
rover = low + ((rover - low)
& (UDP_HTABLE_SIZE - 1));
} }
result = best;
for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; /* 2nd pass: find hole in shortest hash chain */
i++, result += UDP_HTABLE_SIZE) { rover = best;
if (result > sysctl_local_port_range[1]) for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++) {
result = sysctl_local_port_range[0] if (! __udp_lib_lport_inuse(rover, udptable))
+ ((result - sysctl_local_port_range[0]) & goto gotit;
(UDP_HTABLE_SIZE - 1)); rover += UDP_HTABLE_SIZE;
if (! __udp_lib_lport_inuse(result, udptable)) if (rover > high)
break; rover = low + ((rover - low)
& (UDP_HTABLE_SIZE - 1));
} }
if (i >= (1 << 16) / UDP_HTABLE_SIZE)
goto fail;
/* All ports in use! */
goto fail;
gotit: gotit:
*port_rover = snum = result; snum = rover;
} else { } else {
head = &udptable[snum & (UDP_HTABLE_SIZE - 1)]; head = &udptable[snum & (UDP_HTABLE_SIZE - 1)];
...@@ -201,6 +206,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -201,6 +206,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
(*saddr_comp)(sk, sk2) ) (*saddr_comp)(sk, sk2) )
goto fail; goto fail;
} }
inet_sk(sk)->num = snum; inet_sk(sk)->num = snum;
sk->sk_hash = snum; sk->sk_hash = snum;
if (sk_unhashed(sk)) { if (sk_unhashed(sk)) {
...@@ -217,7 +223,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -217,7 +223,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
int udp_get_port(struct sock *sk, unsigned short snum, int udp_get_port(struct sock *sk, unsigned short snum,
int (*scmp)(const struct sock *, const struct sock *)) int (*scmp)(const struct sock *, const struct sock *))
{ {
return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp); return __udp_lib_get_port(sk, snum, udp_hash, scmp);
} }
int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
......
...@@ -9,7 +9,7 @@ extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int ); ...@@ -9,7 +9,7 @@ extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int );
extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []); extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []);
extern int __udp_lib_get_port(struct sock *sk, unsigned short snum, extern int __udp_lib_get_port(struct sock *sk, unsigned short snum,
struct hlist_head udptable[], int *port_rover, struct hlist_head udptable[],
int (*)(const struct sock*,const struct sock*)); int (*)(const struct sock*,const struct sock*));
extern int ipv4_rcv_saddr_equal(const struct sock *, const struct sock *); extern int ipv4_rcv_saddr_equal(const struct sock *, const struct sock *);
......
...@@ -16,12 +16,11 @@ ...@@ -16,12 +16,11 @@
DEFINE_SNMP_STAT(struct udp_mib, udplite_statistics) __read_mostly; DEFINE_SNMP_STAT(struct udp_mib, udplite_statistics) __read_mostly;
struct hlist_head udplite_hash[UDP_HTABLE_SIZE]; struct hlist_head udplite_hash[UDP_HTABLE_SIZE];
static int udplite_port_rover;
int udplite_get_port(struct sock *sk, unsigned short p, int udplite_get_port(struct sock *sk, unsigned short p,
int (*c)(const struct sock *, const struct sock *)) int (*c)(const struct sock *, const struct sock *))
{ {
return __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c); return __udp_lib_get_port(sk, p, udplite_hash, c);
} }
static int udplite_v4_get_port(struct sock *sk, unsigned short snum) static int udplite_v4_get_port(struct sock *sk, unsigned short snum)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment