Commit d40ce48c authored by Jakub Kicinski's avatar Jakub Kicinski

Merge branch 'af_unix-replace-unix_table_lock-with-per-hash-locks'

Kuniyuki Iwashima says:

====================
af_unix: Replace unix_table_lock with per-hash locks.

The hash table of AF_UNIX sockets is protected by a single big lock,
unix_table_lock.  This series replaces it with small per-hash locks.

1st -  2nd : Misc refactoring
3rd -  8th : Separate BSD/abstract address logics
9th - 11th : Prep to save a hash in each socket
12th       : Replace the big lock
13th       : Speed up autobind()

Note to maintainers:
The 12th patch adds two kinds of Sparse warnings on patchwork:

  about unix_table_double_lock/unlock()
    We can avoid this by adding two apparent acquires/releases annotations,
    but there are the same kinds of warnings about unix_state_double_lock().

  about unix_next_socket() and unix_seq_stop() (/proc/net/unix)
    This is because Sparse does not understand logic in unix_next_socket(),
    which leaves a spin lock held until it returns NULL.
    Also, tcp_seq_stop() causes a warning for the same reason.

These warnings seem reasonable, but let me know if there is any better way.
Please see [0] for details.

[0]: https://lore.kernel.org/netdev/20211117001611.74123-1-kuniyu@amazon.co.jp/
====================

Link: https://lore.kernel.org/r/20211124021431.48956-1-kuniyu@amazon.co.jpSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 442b03c3 9acbc584
...@@ -20,13 +20,12 @@ struct sock *unix_peer_get(struct sock *sk); ...@@ -20,13 +20,12 @@ struct sock *unix_peer_get(struct sock *sk);
#define UNIX_HASH_BITS 8 #define UNIX_HASH_BITS 8
extern unsigned int unix_tot_inflight; extern unsigned int unix_tot_inflight;
extern spinlock_t unix_table_lock; extern spinlock_t unix_table_locks[2 * UNIX_HASH_SIZE];
extern struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE]; extern struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE];
struct unix_address { struct unix_address {
refcount_t refcnt; refcount_t refcnt;
int len; int len;
unsigned int hash;
struct sockaddr_un name[]; struct sockaddr_un name[];
}; };
......
...@@ -117,24 +117,64 @@ ...@@ -117,24 +117,64 @@
#include "scm.h" #include "scm.h"
spinlock_t unix_table_locks[2 * UNIX_HASH_SIZE];
EXPORT_SYMBOL_GPL(unix_table_locks);
struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE]; struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE];
EXPORT_SYMBOL_GPL(unix_socket_table); EXPORT_SYMBOL_GPL(unix_socket_table);
DEFINE_SPINLOCK(unix_table_lock);
EXPORT_SYMBOL_GPL(unix_table_lock);
static atomic_long_t unix_nr_socks; static atomic_long_t unix_nr_socks;
/* SMP locking strategy:
* hash table is protected with spinlock unix_table_locks
* each socket state is protected by separate spin lock.
*/
static struct hlist_head *unix_sockets_unbound(void *addr) static unsigned int unix_unbound_hash(struct sock *sk)
{ {
unsigned long hash = (unsigned long)addr; unsigned long hash = (unsigned long)sk;
hash ^= hash >> 16; hash ^= hash >> 16;
hash ^= hash >> 8; hash ^= hash >> 8;
hash %= UNIX_HASH_SIZE; hash ^= sk->sk_type;
return &unix_socket_table[UNIX_HASH_SIZE + hash];
return UNIX_HASH_SIZE + (hash & (UNIX_HASH_SIZE - 1));
}
static unsigned int unix_bsd_hash(struct inode *i)
{
return i->i_ino & (UNIX_HASH_SIZE - 1);
}
static unsigned int unix_abstract_hash(struct sockaddr_un *sunaddr,
int addr_len, int type)
{
__wsum csum = csum_partial(sunaddr, addr_len, 0);
unsigned int hash;
hash = (__force unsigned int)csum_fold(csum);
hash ^= hash >> 8;
hash ^= type;
return hash & (UNIX_HASH_SIZE - 1);
}
static void unix_table_double_lock(unsigned int hash1, unsigned int hash2)
{
/* hash1 and hash2 is never the same because
* one is between 0 and UNIX_HASH_SIZE - 1, and
* another is between UNIX_HASH_SIZE and UNIX_HASH_SIZE * 2.
*/
if (hash1 > hash2)
swap(hash1, hash2);
spin_lock(&unix_table_locks[hash1]);
spin_lock_nested(&unix_table_locks[hash2], SINGLE_DEPTH_NESTING);
} }
#define UNIX_ABSTRACT(sk) (unix_sk(sk)->addr->hash < UNIX_HASH_SIZE) static void unix_table_double_unlock(unsigned int hash1, unsigned int hash2)
{
spin_unlock(&unix_table_locks[hash1]);
spin_unlock(&unix_table_locks[hash2]);
}
#ifdef CONFIG_SECURITY_NETWORK #ifdef CONFIG_SECURITY_NETWORK
static void unix_get_secdata(struct scm_cookie *scm, struct sk_buff *skb) static void unix_get_secdata(struct scm_cookie *scm, struct sk_buff *skb)
...@@ -164,20 +204,6 @@ static inline bool unix_secdata_eq(struct scm_cookie *scm, struct sk_buff *skb) ...@@ -164,20 +204,6 @@ static inline bool unix_secdata_eq(struct scm_cookie *scm, struct sk_buff *skb)
} }
#endif /* CONFIG_SECURITY_NETWORK */ #endif /* CONFIG_SECURITY_NETWORK */
/*
* SMP locking strategy:
* hash table is protected with spinlock unix_table_lock
* each socket state is protected by separate spin lock.
*/
static inline unsigned int unix_hash_fold(__wsum n)
{
unsigned int hash = (__force unsigned int)csum_fold(n);
hash ^= hash>>8;
return hash&(UNIX_HASH_SIZE-1);
}
#define unix_peer(sk) (unix_sk(sk)->peer) #define unix_peer(sk) (unix_sk(sk)->peer)
static inline int unix_our_peer(struct sock *sk, struct sock *osk) static inline int unix_our_peer(struct sock *sk, struct sock *osk)
...@@ -214,6 +240,22 @@ struct sock *unix_peer_get(struct sock *s) ...@@ -214,6 +240,22 @@ struct sock *unix_peer_get(struct sock *s)
} }
EXPORT_SYMBOL_GPL(unix_peer_get); EXPORT_SYMBOL_GPL(unix_peer_get);
static struct unix_address *unix_create_addr(struct sockaddr_un *sunaddr,
int addr_len)
{
struct unix_address *addr;
addr = kmalloc(sizeof(*addr) + addr_len, GFP_KERNEL);
if (!addr)
return NULL;
refcount_set(&addr->refcnt, 1);
addr->len = addr_len;
memcpy(addr->name, sunaddr, addr_len);
return addr;
}
static inline void unix_release_addr(struct unix_address *addr) static inline void unix_release_addr(struct unix_address *addr)
{ {
if (refcount_dec_and_test(&addr->refcnt)) if (refcount_dec_and_test(&addr->refcnt))
...@@ -227,29 +269,29 @@ static inline void unix_release_addr(struct unix_address *addr) ...@@ -227,29 +269,29 @@ static inline void unix_release_addr(struct unix_address *addr)
* - if started by zero, it is abstract name. * - if started by zero, it is abstract name.
*/ */
static int unix_mkname(struct sockaddr_un *sunaddr, int len, unsigned int *hashp) static int unix_validate_addr(struct sockaddr_un *sunaddr, int addr_len)
{ {
*hashp = 0; if (addr_len <= offsetof(struct sockaddr_un, sun_path) ||
addr_len > sizeof(*sunaddr))
if (len <= sizeof(short) || len > sizeof(*sunaddr))
return -EINVAL; return -EINVAL;
if (!sunaddr || sunaddr->sun_family != AF_UNIX)
if (sunaddr->sun_family != AF_UNIX)
return -EINVAL; return -EINVAL;
if (sunaddr->sun_path[0]) {
/* return 0;
* This may look like an off by one error but it is a bit more }
static void unix_mkname_bsd(struct sockaddr_un *sunaddr, int addr_len)
{
/* This may look like an off by one error but it is a bit more
* subtle. 108 is the longest valid AF_UNIX path for a binding. * subtle. 108 is the longest valid AF_UNIX path for a binding.
* sun_path[108] doesn't as such exist. However in kernel space * sun_path[108] doesn't as such exist. However in kernel space
* we are guaranteed that it is a valid memory location in our * we are guaranteed that it is a valid memory location in our
* kernel address buffer. * kernel address buffer because syscall functions always pass
* a pointer of struct sockaddr_storage which has a bigger buffer
* than 108.
*/ */
((char *)sunaddr)[len] = 0; ((char *)sunaddr)[addr_len] = 0;
len = strlen(sunaddr->sun_path)+1+sizeof(short);
return len;
}
*hashp = unix_hash_fold(csum_partial(sunaddr, len, 0));
return len;
} }
static void __unix_remove_socket(struct sock *sk) static void __unix_remove_socket(struct sock *sk)
...@@ -257,32 +299,34 @@ static void __unix_remove_socket(struct sock *sk) ...@@ -257,32 +299,34 @@ static void __unix_remove_socket(struct sock *sk)
sk_del_node_init(sk); sk_del_node_init(sk);
} }
static void __unix_insert_socket(struct hlist_head *list, struct sock *sk) static void __unix_insert_socket(struct sock *sk)
{ {
WARN_ON(!sk_unhashed(sk)); WARN_ON(!sk_unhashed(sk));
sk_add_node(sk, list); sk_add_node(sk, &unix_socket_table[sk->sk_hash]);
} }
static void __unix_set_addr(struct sock *sk, struct unix_address *addr, static void __unix_set_addr_hash(struct sock *sk, struct unix_address *addr,
unsigned hash) unsigned int hash)
{ {
__unix_remove_socket(sk); __unix_remove_socket(sk);
smp_store_release(&unix_sk(sk)->addr, addr); smp_store_release(&unix_sk(sk)->addr, addr);
__unix_insert_socket(&unix_socket_table[hash], sk);
sk->sk_hash = hash;
__unix_insert_socket(sk);
} }
static inline void unix_remove_socket(struct sock *sk) static void unix_remove_socket(struct sock *sk)
{ {
spin_lock(&unix_table_lock); spin_lock(&unix_table_locks[sk->sk_hash]);
__unix_remove_socket(sk); __unix_remove_socket(sk);
spin_unlock(&unix_table_lock); spin_unlock(&unix_table_locks[sk->sk_hash]);
} }
static inline void unix_insert_socket(struct hlist_head *list, struct sock *sk) static void unix_insert_unbound_socket(struct sock *sk)
{ {
spin_lock(&unix_table_lock); spin_lock(&unix_table_locks[sk->sk_hash]);
__unix_insert_socket(list, sk); __unix_insert_socket(sk);
spin_unlock(&unix_table_lock); spin_unlock(&unix_table_locks[sk->sk_hash]);
} }
static struct sock *__unix_find_socket_byname(struct net *net, static struct sock *__unix_find_socket_byname(struct net *net,
...@@ -310,32 +354,31 @@ static inline struct sock *unix_find_socket_byname(struct net *net, ...@@ -310,32 +354,31 @@ static inline struct sock *unix_find_socket_byname(struct net *net,
{ {
struct sock *s; struct sock *s;
spin_lock(&unix_table_lock); spin_lock(&unix_table_locks[hash]);
s = __unix_find_socket_byname(net, sunname, len, hash); s = __unix_find_socket_byname(net, sunname, len, hash);
if (s) if (s)
sock_hold(s); sock_hold(s);
spin_unlock(&unix_table_lock); spin_unlock(&unix_table_locks[hash]);
return s; return s;
} }
static struct sock *unix_find_socket_byinode(struct inode *i) static struct sock *unix_find_socket_byinode(struct inode *i)
{ {
unsigned int hash = unix_bsd_hash(i);
struct sock *s; struct sock *s;
spin_lock(&unix_table_lock); spin_lock(&unix_table_locks[hash]);
sk_for_each(s, sk_for_each(s, &unix_socket_table[hash]) {
&unix_socket_table[i->i_ino & (UNIX_HASH_SIZE - 1)]) {
struct dentry *dentry = unix_sk(s)->path.dentry; struct dentry *dentry = unix_sk(s)->path.dentry;
if (dentry && d_backing_inode(dentry) == i) { if (dentry && d_backing_inode(dentry) == i) {
sock_hold(s); sock_hold(s);
goto found; spin_unlock(&unix_table_locks[hash]);
return s;
} }
} }
s = NULL; spin_unlock(&unix_table_locks[hash]);
found: return NULL;
spin_unlock(&unix_table_lock);
return s;
} }
/* Support code for asymmetrically connected dgram sockets /* Support code for asymmetrically connected dgram sockets
...@@ -870,6 +913,7 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, ...@@ -870,6 +913,7 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern,
sock_init_data(sock, sk); sock_init_data(sock, sk);
sk->sk_hash = unix_unbound_hash(sk);
sk->sk_allocation = GFP_KERNEL_ACCOUNT; sk->sk_allocation = GFP_KERNEL_ACCOUNT;
sk->sk_write_space = unix_write_space; sk->sk_write_space = unix_write_space;
sk->sk_max_ack_backlog = net->unx.sysctl_max_dgram_qlen; sk->sk_max_ack_backlog = net->unx.sysctl_max_dgram_qlen;
...@@ -885,7 +929,7 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, ...@@ -885,7 +929,7 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern,
init_waitqueue_head(&u->peer_wait); init_waitqueue_head(&u->peer_wait);
init_waitqueue_func_entry(&u->peer_wake, unix_dgram_peer_wake_relay); init_waitqueue_func_entry(&u->peer_wake, unix_dgram_peer_wake_relay);
memset(&u->scm_stat, 0, sizeof(struct scm_stat)); memset(&u->scm_stat, 0, sizeof(struct scm_stat));
unix_insert_socket(unix_sockets_unbound(sk), sk); unix_insert_unbound_socket(sk);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
...@@ -948,15 +992,90 @@ static int unix_release(struct socket *sock) ...@@ -948,15 +992,90 @@ static int unix_release(struct socket *sock)
return 0; return 0;
} }
static int unix_autobind(struct socket *sock) static struct sock *unix_find_bsd(struct net *net, struct sockaddr_un *sunaddr,
int addr_len, int type)
{ {
struct sock *sk = sock->sk; struct inode *inode;
struct net *net = sock_net(sk); struct path path;
struct sock *sk;
int err;
unix_mkname_bsd(sunaddr, addr_len);
err = kern_path(sunaddr->sun_path, LOOKUP_FOLLOW, &path);
if (err)
goto fail;
err = path_permission(&path, MAY_WRITE);
if (err)
goto path_put;
err = -ECONNREFUSED;
inode = d_backing_inode(path.dentry);
if (!S_ISSOCK(inode->i_mode))
goto path_put;
sk = unix_find_socket_byinode(inode);
if (!sk)
goto path_put;
err = -EPROTOTYPE;
if (sk->sk_type == type)
touch_atime(&path);
else
goto sock_put;
path_put(&path);
return sk;
sock_put:
sock_put(sk);
path_put:
path_put(&path);
fail:
return ERR_PTR(err);
}
static struct sock *unix_find_abstract(struct net *net,
struct sockaddr_un *sunaddr,
int addr_len, int type)
{
unsigned int hash = unix_abstract_hash(sunaddr, addr_len, type);
struct dentry *dentry;
struct sock *sk;
sk = unix_find_socket_byname(net, sunaddr, addr_len, hash);
if (!sk)
return ERR_PTR(-ECONNREFUSED);
dentry = unix_sk(sk)->path.dentry;
if (dentry)
touch_atime(&unix_sk(sk)->path);
return sk;
}
static struct sock *unix_find_other(struct net *net,
struct sockaddr_un *sunaddr,
int addr_len, int type)
{
struct sock *sk;
if (sunaddr->sun_path[0])
sk = unix_find_bsd(net, sunaddr, addr_len, type);
else
sk = unix_find_abstract(net, sunaddr, addr_len, type);
return sk;
}
static int unix_autobind(struct sock *sk)
{
unsigned int new_hash, old_hash = sk->sk_hash;
struct unix_sock *u = unix_sk(sk); struct unix_sock *u = unix_sk(sk);
static u32 ordernum = 1;
struct unix_address *addr; struct unix_address *addr;
u32 lastnum, ordernum;
int err; int err;
unsigned int retries = 0;
err = mutex_lock_interruptible(&u->bindlock); err = mutex_lock_interruptible(&u->bindlock);
if (err) if (err)
...@@ -966,141 +1085,103 @@ static int unix_autobind(struct socket *sock) ...@@ -966,141 +1085,103 @@ static int unix_autobind(struct socket *sock)
goto out; goto out;
err = -ENOMEM; err = -ENOMEM;
addr = kzalloc(sizeof(*addr) + sizeof(short) + 16, GFP_KERNEL); addr = kzalloc(sizeof(*addr) +
offsetof(struct sockaddr_un, sun_path) + 16, GFP_KERNEL);
if (!addr) if (!addr)
goto out; goto out;
addr->len = offsetof(struct sockaddr_un, sun_path) + 6;
addr->name->sun_family = AF_UNIX; addr->name->sun_family = AF_UNIX;
refcount_set(&addr->refcnt, 1); refcount_set(&addr->refcnt, 1);
ordernum = prandom_u32();
lastnum = ordernum & 0xFFFFF;
retry: retry:
addr->len = sprintf(addr->name->sun_path+1, "%05x", ordernum) + 1 + sizeof(short); ordernum = (ordernum + 1) & 0xFFFFF;
addr->hash = unix_hash_fold(csum_partial(addr->name, addr->len, 0)); sprintf(addr->name->sun_path + 1, "%05x", ordernum);
addr->hash ^= sk->sk_type;
spin_lock(&unix_table_lock); new_hash = unix_abstract_hash(addr->name, addr->len, sk->sk_type);
ordernum = (ordernum+1)&0xFFFFF; unix_table_double_lock(old_hash, new_hash);
if (__unix_find_socket_byname(net, addr->name, addr->len, addr->hash)) { if (__unix_find_socket_byname(sock_net(sk), addr->name, addr->len,
spin_unlock(&unix_table_lock); new_hash)) {
/* unix_table_double_unlock(old_hash, new_hash);
* __unix_find_socket_byname() may take long time if many names
/* __unix_find_socket_byname() may take long time if many names
* are already in use. * are already in use.
*/ */
cond_resched(); cond_resched();
if (ordernum == lastnum) {
/* Give up if all names seems to be in use. */ /* Give up if all names seems to be in use. */
if (retries++ == 0xFFFFF) {
err = -ENOSPC; err = -ENOSPC;
kfree(addr); unix_release_addr(addr);
goto out; goto out;
} }
goto retry; goto retry;
} }
__unix_set_addr(sk, addr, addr->hash); __unix_set_addr_hash(sk, addr, new_hash);
spin_unlock(&unix_table_lock); unix_table_double_unlock(old_hash, new_hash);
err = 0; err = 0;
out: mutex_unlock(&u->bindlock); out: mutex_unlock(&u->bindlock);
return err; return err;
} }
static struct sock *unix_find_other(struct net *net, static int unix_bind_bsd(struct sock *sk, struct sockaddr_un *sunaddr,
struct sockaddr_un *sunname, int len, int addr_len)
int type, unsigned int hash, int *error)
{ {
struct sock *u;
struct path path;
int err = 0;
if (sunname->sun_path[0]) {
struct inode *inode;
err = kern_path(sunname->sun_path, LOOKUP_FOLLOW, &path);
if (err)
goto fail;
inode = d_backing_inode(path.dentry);
err = path_permission(&path, MAY_WRITE);
if (err)
goto put_fail;
err = -ECONNREFUSED;
if (!S_ISSOCK(inode->i_mode))
goto put_fail;
u = unix_find_socket_byinode(inode);
if (!u)
goto put_fail;
if (u->sk_type == type)
touch_atime(&path);
path_put(&path);
err = -EPROTOTYPE;
if (u->sk_type != type) {
sock_put(u);
goto fail;
}
} else {
err = -ECONNREFUSED;
u = unix_find_socket_byname(net, sunname, len, type ^ hash);
if (u) {
struct dentry *dentry;
dentry = unix_sk(u)->path.dentry;
if (dentry)
touch_atime(&unix_sk(u)->path);
} else
goto fail;
}
return u;
put_fail:
path_put(&path);
fail:
*error = err;
return NULL;
}
static int unix_bind_bsd(struct sock *sk, struct unix_address *addr)
{
struct unix_sock *u = unix_sk(sk);
umode_t mode = S_IFSOCK | umode_t mode = S_IFSOCK |
(SOCK_INODE(sk->sk_socket)->i_mode & ~current_umask()); (SOCK_INODE(sk->sk_socket)->i_mode & ~current_umask());
unsigned int new_hash, old_hash = sk->sk_hash;
struct unix_sock *u = unix_sk(sk);
struct user_namespace *ns; // barf... struct user_namespace *ns; // barf...
struct path parent; struct unix_address *addr;
struct dentry *dentry; struct dentry *dentry;
unsigned int hash; struct path parent;
int err; int err;
unix_mkname_bsd(sunaddr, addr_len);
addr_len = strlen(sunaddr->sun_path) +
offsetof(struct sockaddr_un, sun_path) + 1;
addr = unix_create_addr(sunaddr, addr_len);
if (!addr)
return -ENOMEM;
/* /*
* Get the parent directory, calculate the hash for last * Get the parent directory, calculate the hash for last
* component. * component.
*/ */
dentry = kern_path_create(AT_FDCWD, addr->name->sun_path, &parent, 0); dentry = kern_path_create(AT_FDCWD, addr->name->sun_path, &parent, 0);
if (IS_ERR(dentry)) if (IS_ERR(dentry)) {
return PTR_ERR(dentry); err = PTR_ERR(dentry);
ns = mnt_user_ns(parent.mnt); goto out;
}
/* /*
* All right, let's create it. * All right, let's create it.
*/ */
ns = mnt_user_ns(parent.mnt);
err = security_path_mknod(&parent, dentry, mode, 0); err = security_path_mknod(&parent, dentry, mode, 0);
if (!err) if (!err)
err = vfs_mknod(ns, d_inode(parent.dentry), dentry, mode, 0); err = vfs_mknod(ns, d_inode(parent.dentry), dentry, mode, 0);
if (err) if (err)
goto out; goto out_path;
err = mutex_lock_interruptible(&u->bindlock); err = mutex_lock_interruptible(&u->bindlock);
if (err) if (err)
goto out_unlink; goto out_unlink;
if (u->addr) if (u->addr)
goto out_unlock; goto out_unlock;
addr->hash = UNIX_HASH_SIZE; new_hash = unix_bsd_hash(d_backing_inode(dentry));
hash = d_backing_inode(dentry)->i_ino & (UNIX_HASH_SIZE - 1); unix_table_double_lock(old_hash, new_hash);
spin_lock(&unix_table_lock);
u->path.mnt = mntget(parent.mnt); u->path.mnt = mntget(parent.mnt);
u->path.dentry = dget(dentry); u->path.dentry = dget(dentry);
__unix_set_addr(sk, addr, hash); __unix_set_addr_hash(sk, addr, new_hash);
spin_unlock(&unix_table_lock); unix_table_double_unlock(old_hash, new_hash);
mutex_unlock(&u->bindlock); mutex_unlock(&u->bindlock);
done_path_create(&parent, dentry); done_path_create(&parent, dentry);
return 0; return 0;
...@@ -1111,74 +1192,76 @@ static int unix_bind_bsd(struct sock *sk, struct unix_address *addr) ...@@ -1111,74 +1192,76 @@ static int unix_bind_bsd(struct sock *sk, struct unix_address *addr)
out_unlink: out_unlink:
/* failed after successful mknod? unlink what we'd created... */ /* failed after successful mknod? unlink what we'd created... */
vfs_unlink(ns, d_inode(parent.dentry), dentry, NULL); vfs_unlink(ns, d_inode(parent.dentry), dentry, NULL);
out: out_path:
done_path_create(&parent, dentry); done_path_create(&parent, dentry);
return err; out:
unix_release_addr(addr);
return err == -EEXIST ? -EADDRINUSE : err;
} }
static int unix_bind_abstract(struct sock *sk, struct unix_address *addr) static int unix_bind_abstract(struct sock *sk, struct sockaddr_un *sunaddr,
int addr_len)
{ {
unsigned int new_hash, old_hash = sk->sk_hash;
struct unix_sock *u = unix_sk(sk); struct unix_sock *u = unix_sk(sk);
struct unix_address *addr;
int err; int err;
addr = unix_create_addr(sunaddr, addr_len);
if (!addr)
return -ENOMEM;
err = mutex_lock_interruptible(&u->bindlock); err = mutex_lock_interruptible(&u->bindlock);
if (err) if (err)
return err; goto out;
if (u->addr) { if (u->addr) {
mutex_unlock(&u->bindlock); err = -EINVAL;
return -EINVAL; goto out_mutex;
} }
spin_lock(&unix_table_lock); new_hash = unix_abstract_hash(addr->name, addr->len, sk->sk_type);
unix_table_double_lock(old_hash, new_hash);
if (__unix_find_socket_byname(sock_net(sk), addr->name, addr->len, if (__unix_find_socket_byname(sock_net(sk), addr->name, addr->len,
addr->hash)) { new_hash))
spin_unlock(&unix_table_lock); goto out_spin;
mutex_unlock(&u->bindlock);
return -EADDRINUSE; __unix_set_addr_hash(sk, addr, new_hash);
} unix_table_double_unlock(old_hash, new_hash);
__unix_set_addr(sk, addr, addr->hash);
spin_unlock(&unix_table_lock);
mutex_unlock(&u->bindlock); mutex_unlock(&u->bindlock);
return 0; return 0;
out_spin:
unix_table_double_unlock(old_hash, new_hash);
err = -EADDRINUSE;
out_mutex:
mutex_unlock(&u->bindlock);
out:
unix_release_addr(addr);
return err;
} }
static int unix_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) static int unix_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{ {
struct sock *sk = sock->sk;
struct sockaddr_un *sunaddr = (struct sockaddr_un *)uaddr; struct sockaddr_un *sunaddr = (struct sockaddr_un *)uaddr;
char *sun_path = sunaddr->sun_path; struct sock *sk = sock->sk;
int err; int err;
unsigned int hash;
struct unix_address *addr;
if (addr_len < offsetofend(struct sockaddr_un, sun_family) ||
sunaddr->sun_family != AF_UNIX)
return -EINVAL;
if (addr_len == sizeof(short)) if (addr_len == offsetof(struct sockaddr_un, sun_path) &&
return unix_autobind(sock); sunaddr->sun_family == AF_UNIX)
return unix_autobind(sk);
err = unix_mkname(sunaddr, addr_len, &hash); err = unix_validate_addr(sunaddr, addr_len);
if (err < 0) if (err)
return err; return err;
addr_len = err;
addr = kmalloc(sizeof(*addr)+addr_len, GFP_KERNEL);
if (!addr)
return -ENOMEM;
memcpy(addr->name, sunaddr, addr_len); if (sunaddr->sun_path[0])
addr->len = addr_len; err = unix_bind_bsd(sk, sunaddr, addr_len);
addr->hash = hash ^ sk->sk_type;
refcount_set(&addr->refcnt, 1);
if (sun_path[0])
err = unix_bind_bsd(sk, addr);
else else
err = unix_bind_abstract(sk, addr); err = unix_bind_abstract(sk, sunaddr, addr_len);
if (err)
unix_release_addr(addr); return err;
return err == -EEXIST ? -EADDRINUSE : err;
} }
static void unix_state_double_lock(struct sock *sk1, struct sock *sk2) static void unix_state_double_lock(struct sock *sk1, struct sock *sk2)
...@@ -1213,7 +1296,6 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr, ...@@ -1213,7 +1296,6 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
struct sockaddr_un *sunaddr = (struct sockaddr_un *)addr; struct sockaddr_un *sunaddr = (struct sockaddr_un *)addr;
struct sock *other; struct sock *other;
unsigned int hash;
int err; int err;
err = -EINVAL; err = -EINVAL;
...@@ -1221,19 +1303,23 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr, ...@@ -1221,19 +1303,23 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
goto out; goto out;
if (addr->sa_family != AF_UNSPEC) { if (addr->sa_family != AF_UNSPEC) {
err = unix_mkname(sunaddr, alen, &hash); err = unix_validate_addr(sunaddr, alen);
if (err < 0) if (err)
goto out; goto out;
alen = err;
if (test_bit(SOCK_PASSCRED, &sock->flags) && if (test_bit(SOCK_PASSCRED, &sock->flags) &&
!unix_sk(sk)->addr && (err = unix_autobind(sock)) != 0) !unix_sk(sk)->addr) {
err = unix_autobind(sk);
if (err)
goto out; goto out;
}
restart: restart:
other = unix_find_other(net, sunaddr, alen, sock->type, hash, &err); other = unix_find_other(net, sunaddr, alen, sock->type);
if (!other) if (IS_ERR(other)) {
err = PTR_ERR(other);
goto out; goto out;
}
unix_state_double_lock(sk, other); unix_state_double_lock(sk, other);
...@@ -1323,19 +1409,19 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1323,19 +1409,19 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
struct sock *newsk = NULL; struct sock *newsk = NULL;
struct sock *other = NULL; struct sock *other = NULL;
struct sk_buff *skb = NULL; struct sk_buff *skb = NULL;
unsigned int hash;
int st; int st;
int err; int err;
long timeo; long timeo;
err = unix_mkname(sunaddr, addr_len, &hash); err = unix_validate_addr(sunaddr, addr_len);
if (err < 0) if (err)
goto out; goto out;
addr_len = err;
if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr && if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) {
(err = unix_autobind(sock)) != 0) err = unix_autobind(sk);
if (err)
goto out; goto out;
}
timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); timeo = sock_sndtimeo(sk, flags & O_NONBLOCK);
...@@ -1361,9 +1447,12 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1361,9 +1447,12 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
restart: restart:
/* Find listening sock. */ /* Find listening sock. */
other = unix_find_other(net, sunaddr, addr_len, sk->sk_type, hash, &err); other = unix_find_other(net, sunaddr, addr_len, sk->sk_type);
if (!other) if (IS_ERR(other)) {
err = PTR_ERR(other);
other = NULL;
goto out; goto out;
}
/* Latch state of peer */ /* Latch state of peer */
unix_state_lock(other); unix_state_lock(other);
...@@ -1451,9 +1540,9 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1451,9 +1540,9 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
* *
* The contents of *(otheru->addr) and otheru->path * The contents of *(otheru->addr) and otheru->path
* are seen fully set up here, since we have found * are seen fully set up here, since we have found
* otheru in hash under unix_table_lock. Insertion * otheru in hash under unix_table_locks. Insertion
* into the hash chain we'd found it in had been done * into the hash chain we'd found it in had been done
* in an earlier critical area protected by unix_table_lock, * in an earlier critical area protected by unix_table_locks,
* the same one where we'd set *(otheru->addr) contents, * the same one where we'd set *(otheru->addr) contents,
* as well as otheru->path and otheru->addr itself. * as well as otheru->path and otheru->addr itself.
* *
...@@ -1600,7 +1689,7 @@ static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int peer) ...@@ -1600,7 +1689,7 @@ static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int peer)
if (!addr) { if (!addr) {
sunaddr->sun_family = AF_UNIX; sunaddr->sun_family = AF_UNIX;
sunaddr->sun_path[0] = 0; sunaddr->sun_path[0] = 0;
err = sizeof(short); err = offsetof(struct sockaddr_un, sun_path);
} else { } else {
err = addr->len; err = addr->len;
memcpy(sunaddr, addr->name, addr->len); memcpy(sunaddr, addr->name, addr->len);
...@@ -1756,9 +1845,7 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1756,9 +1845,7 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
struct unix_sock *u = unix_sk(sk); struct unix_sock *u = unix_sk(sk);
DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, msg->msg_name); DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, msg->msg_name);
struct sock *other = NULL; struct sock *other = NULL;
int namelen = 0; /* fake GCC */
int err; int err;
unsigned int hash;
struct sk_buff *skb; struct sk_buff *skb;
long timeo; long timeo;
struct scm_cookie scm; struct scm_cookie scm;
...@@ -1775,10 +1862,9 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1775,10 +1862,9 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out; goto out;
if (msg->msg_namelen) { if (msg->msg_namelen) {
err = unix_mkname(sunaddr, msg->msg_namelen, &hash); err = unix_validate_addr(sunaddr, msg->msg_namelen);
if (err < 0) if (err)
goto out; goto out;
namelen = err;
} else { } else {
sunaddr = NULL; sunaddr = NULL;
err = -ENOTCONN; err = -ENOTCONN;
...@@ -1787,9 +1873,11 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1787,9 +1873,11 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out; goto out;
} }
if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) {
&& (err = unix_autobind(sock)) != 0) err = unix_autobind(sk);
if (err)
goto out; goto out;
}
err = -EMSGSIZE; err = -EMSGSIZE;
if (len > sk->sk_sndbuf - 32) if (len > sk->sk_sndbuf - 32)
...@@ -1829,11 +1917,14 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1829,11 +1917,14 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
if (sunaddr == NULL) if (sunaddr == NULL)
goto out_free; goto out_free;
other = unix_find_other(net, sunaddr, namelen, sk->sk_type, other = unix_find_other(net, sunaddr, msg->msg_namelen,
hash, &err); sk->sk_type);
if (other == NULL) if (IS_ERR(other)) {
err = PTR_ERR(other);
other = NULL;
goto out_free; goto out_free;
} }
}
if (sk_filter(other, skb) < 0) { if (sk_filter(other, skb) < 0) {
/* Toss the packet but do not return any error to the sender */ /* Toss the packet but do not return any error to the sender */
...@@ -3128,7 +3219,7 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock, ...@@ -3128,7 +3219,7 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock,
#define BUCKET_SPACE (BITS_PER_LONG - (UNIX_HASH_BITS + 1) - 1) #define BUCKET_SPACE (BITS_PER_LONG - (UNIX_HASH_BITS + 1) - 1)
#define get_bucket(x) ((x) >> BUCKET_SPACE) #define get_bucket(x) ((x) >> BUCKET_SPACE)
#define get_offset(x) ((x) & ((1L << BUCKET_SPACE) - 1)) #define get_offset(x) ((x) & ((1UL << BUCKET_SPACE) - 1))
#define set_bucket_offset(b, o) ((b) << BUCKET_SPACE | (o)) #define set_bucket_offset(b, o) ((b) << BUCKET_SPACE | (o))
static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos) static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos)
...@@ -3152,7 +3243,7 @@ static struct sock *unix_next_socket(struct seq_file *seq, ...@@ -3152,7 +3243,7 @@ static struct sock *unix_next_socket(struct seq_file *seq,
struct sock *sk, struct sock *sk,
loff_t *pos) loff_t *pos)
{ {
unsigned long bucket; unsigned long bucket = get_bucket(*pos);
while (sk > (struct sock *)SEQ_START_TOKEN) { while (sk > (struct sock *)SEQ_START_TOKEN) {
sk = sk_next(sk); sk = sk_next(sk);
...@@ -3163,12 +3254,13 @@ static struct sock *unix_next_socket(struct seq_file *seq, ...@@ -3163,12 +3254,13 @@ static struct sock *unix_next_socket(struct seq_file *seq,
} }
do { do {
spin_lock(&unix_table_locks[bucket]);
sk = unix_from_bucket(seq, pos); sk = unix_from_bucket(seq, pos);
if (sk) if (sk)
return sk; return sk;
next_bucket: next_bucket:
bucket = get_bucket(*pos) + 1; spin_unlock(&unix_table_locks[bucket++]);
*pos = set_bucket_offset(bucket, 1); *pos = set_bucket_offset(bucket, 1);
} while (bucket < ARRAY_SIZE(unix_socket_table)); } while (bucket < ARRAY_SIZE(unix_socket_table));
...@@ -3176,10 +3268,7 @@ static struct sock *unix_next_socket(struct seq_file *seq, ...@@ -3176,10 +3268,7 @@ static struct sock *unix_next_socket(struct seq_file *seq,
} }
static void *unix_seq_start(struct seq_file *seq, loff_t *pos) static void *unix_seq_start(struct seq_file *seq, loff_t *pos)
__acquires(unix_table_lock)
{ {
spin_lock(&unix_table_lock);
if (!*pos) if (!*pos)
return SEQ_START_TOKEN; return SEQ_START_TOKEN;
...@@ -3196,9 +3285,11 @@ static void *unix_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -3196,9 +3285,11 @@ static void *unix_seq_next(struct seq_file *seq, void *v, loff_t *pos)
} }
static void unix_seq_stop(struct seq_file *seq, void *v) static void unix_seq_stop(struct seq_file *seq, void *v)
__releases(unix_table_lock)
{ {
spin_unlock(&unix_table_lock); struct sock *sk = v;
if (sk)
spin_unlock(&unix_table_locks[sk->sk_hash]);
} }
static int unix_seq_show(struct seq_file *seq, void *v) static int unix_seq_show(struct seq_file *seq, void *v)
...@@ -3223,15 +3314,16 @@ static int unix_seq_show(struct seq_file *seq, void *v) ...@@ -3223,15 +3314,16 @@ static int unix_seq_show(struct seq_file *seq, void *v)
(s->sk_state == TCP_ESTABLISHED ? SS_CONNECTING : SS_DISCONNECTING), (s->sk_state == TCP_ESTABLISHED ? SS_CONNECTING : SS_DISCONNECTING),
sock_i_ino(s)); sock_i_ino(s));
if (u->addr) { // under unix_table_lock here if (u->addr) { // under unix_table_locks here
int i, len; int i, len;
seq_putc(seq, ' '); seq_putc(seq, ' ');
i = 0; i = 0;
len = u->addr->len - sizeof(short); len = u->addr->len -
if (!UNIX_ABSTRACT(s)) offsetof(struct sockaddr_un, sun_path);
if (u->addr->name->sun_path[0]) {
len--; len--;
else { } else {
seq_putc(seq, '@'); seq_putc(seq, '@');
i++; i++;
} }
...@@ -3381,10 +3473,13 @@ static void __init bpf_iter_register(void) ...@@ -3381,10 +3473,13 @@ static void __init bpf_iter_register(void)
static int __init af_unix_init(void) static int __init af_unix_init(void)
{ {
int rc = -1; int i, rc = -1;
BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb)); BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb));
for (i = 0; i < 2 * UNIX_HASH_SIZE; i++)
spin_lock_init(&unix_table_locks[i]);
rc = proto_register(&unix_dgram_proto, 1); rc = proto_register(&unix_dgram_proto, 1);
if (rc != 0) { if (rc != 0) {
pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__); pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
......
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
static int sk_diag_dump_name(struct sock *sk, struct sk_buff *nlskb) static int sk_diag_dump_name(struct sock *sk, struct sk_buff *nlskb)
{ {
/* might or might not have unix_table_lock */ /* might or might not have unix_table_locks */
struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr); struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr);
if (!addr) if (!addr)
return 0; return 0;
return nla_put(nlskb, UNIX_DIAG_NAME, addr->len - sizeof(short), return nla_put(nlskb, UNIX_DIAG_NAME,
addr->len - offsetof(struct sockaddr_un, sun_path),
addr->name->sun_path); addr->name->sun_path);
} }
...@@ -203,13 +204,13 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -203,13 +204,13 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
s_slot = cb->args[0]; s_slot = cb->args[0];
num = s_num = cb->args[1]; num = s_num = cb->args[1];
spin_lock(&unix_table_lock);
for (slot = s_slot; for (slot = s_slot;
slot < ARRAY_SIZE(unix_socket_table); slot < ARRAY_SIZE(unix_socket_table);
s_num = 0, slot++) { s_num = 0, slot++) {
struct sock *sk; struct sock *sk;
num = 0; num = 0;
spin_lock(&unix_table_locks[slot]);
sk_for_each(sk, &unix_socket_table[slot]) { sk_for_each(sk, &unix_socket_table[slot]) {
if (!net_eq(sock_net(sk), net)) if (!net_eq(sock_net(sk), net))
continue; continue;
...@@ -220,14 +221,16 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -220,14 +221,16 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
if (sk_diag_dump(sk, skb, req, if (sk_diag_dump(sk, skb, req,
NETLINK_CB(cb->skb).portid, NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, cb->nlh->nlmsg_seq,
NLM_F_MULTI) < 0) NLM_F_MULTI) < 0) {
spin_unlock(&unix_table_locks[slot]);
goto done; goto done;
}
next: next:
num++; num++;
} }
spin_unlock(&unix_table_locks[slot]);
} }
done: done:
spin_unlock(&unix_table_lock);
cb->args[0] = slot; cb->args[0] = slot;
cb->args[1] = num; cb->args[1] = num;
...@@ -236,21 +239,19 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -236,21 +239,19 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
static struct sock *unix_lookup_by_ino(unsigned int ino) static struct sock *unix_lookup_by_ino(unsigned int ino)
{ {
int i;
struct sock *sk; struct sock *sk;
int i;
spin_lock(&unix_table_lock);
for (i = 0; i < ARRAY_SIZE(unix_socket_table); i++) { for (i = 0; i < ARRAY_SIZE(unix_socket_table); i++) {
spin_lock(&unix_table_locks[i]);
sk_for_each(sk, &unix_socket_table[i]) sk_for_each(sk, &unix_socket_table[i])
if (ino == sock_i_ino(sk)) { if (ino == sock_i_ino(sk)) {
sock_hold(sk); sock_hold(sk);
spin_unlock(&unix_table_lock); spin_unlock(&unix_table_locks[i]);
return sk; return sk;
} }
spin_unlock(&unix_table_locks[i]);
} }
spin_unlock(&unix_table_lock);
return NULL; return NULL;
} }
......
...@@ -49,7 +49,7 @@ int dump_unix(struct bpf_iter__unix *ctx) ...@@ -49,7 +49,7 @@ int dump_unix(struct bpf_iter__unix *ctx)
sock_i_ino(sk)); sock_i_ino(sk));
if (unix_sk->addr) { if (unix_sk->addr) {
if (!UNIX_ABSTRACT(unix_sk)) { if (unix_sk->addr->name->sun_path[0]) {
BPF_SEQ_PRINTF(seq, " %s", unix_sk->addr->name->sun_path); BPF_SEQ_PRINTF(seq, " %s", unix_sk->addr->name->sun_path);
} else { } else {
/* The name of the abstract UNIX domain socket starts /* The name of the abstract UNIX domain socket starts
......
...@@ -6,8 +6,6 @@ ...@@ -6,8 +6,6 @@
#define AF_INET6 10 #define AF_INET6 10
#define __SO_ACCEPTCON (1 << 16) #define __SO_ACCEPTCON (1 << 16)
#define UNIX_HASH_SIZE 256
#define UNIX_ABSTRACT(unix_sk) (unix_sk->addr->hash < UNIX_HASH_SIZE)
#define SOL_TCP 6 #define SOL_TCP 6
#define TCP_CONGESTION 13 #define TCP_CONGESTION 13
......
...@@ -23,7 +23,7 @@ int BPF_PROG(unix_listen, struct socket *sock, int backlog) ...@@ -23,7 +23,7 @@ int BPF_PROG(unix_listen, struct socket *sock, int backlog)
if (!unix_sk) if (!unix_sk)
return 0; return 0;
if (!UNIX_ABSTRACT(unix_sk)) if (unix_sk->addr->name->sun_path[0])
return 0; return 0;
len = unix_sk->addr->len - sizeof(short); len = unix_sk->addr->len - sizeof(short);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment