Commit 7e7c714a authored by Paolo Abeni's avatar Paolo Abeni

Merge branch 'af_unix-remove-spin_lock_nested-and-convert-to-lock_cmp_fn'

Kuniyuki Iwashima says:

====================
af_unix: Remove spin_lock_nested() and convert to lock_cmp_fn.

This series removes spin_lock_nested() in AF_UNIX and instead
defines the locking orders as functions tied to each lock by
lockdep_set_lock_cmp_fn().

When the defined function returns a negative value, lockdep
considers it will not cause deadlock.  (See ->cmp_fn() in
check_deadlock() and check_prev_add().)

When we cannot define the total ordering, we return -1 for
the allowed ordering and otherwise 0 as undefined. [0]

[0]: https://lore.kernel.org/netdev/thzkgbuwuo3knevpipu4rzsh5qgmwhklihypdgziiruabvh46f@uwdkpcfxgloo/

Changes:
  v4:
    * Patch 4
      * Make unix_state_lock_cmp_fn() symmetric.

  v3: https://lore.kernel.org/netdev/20240614200715.93150-1-kuniyu@amazon.com/
    * Patch 3
      * Cache sk->sk_state
      * s/unix_state_lock()/unix_state_unlock()/
    * Patch 8
      * Add embryo -> listener locking order

  v2: https://lore.kernel.org/netdev/20240611222905.34695-1-kuniyu@amazon.com/
   * Patch 1 & 2
      * Use (((l) > (r)) - ((l) < (r))) for comparison

  v1: https://lore.kernel.org/netdev/20240610223501.73191-1-kuniyu@amazon.com/
====================

Link: https://lore.kernel.org/r/20240620205623.60139-1-kuniyu@amazon.comSigned-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
parents bf2468f9 22e5751b
...@@ -96,20 +96,6 @@ struct unix_sock { ...@@ -96,20 +96,6 @@ struct unix_sock {
#define unix_state_lock(s) spin_lock(&unix_sk(s)->lock) #define unix_state_lock(s) spin_lock(&unix_sk(s)->lock)
#define unix_state_unlock(s) spin_unlock(&unix_sk(s)->lock) #define unix_state_unlock(s) spin_unlock(&unix_sk(s)->lock)
enum unix_socket_lock_class {
U_LOCK_NORMAL,
U_LOCK_SECOND, /* for double locking, see unix_state_double_lock(). */
U_LOCK_DIAG, /* used while dumping icons, see sk_diag_dump_icons(). */
U_LOCK_GC_LISTENER, /* used for listening socket while determining gc
* candidates to close a small race window.
*/
};
static inline void unix_state_lock_nested(struct sock *sk,
enum unix_socket_lock_class subclass)
{
spin_lock_nested(&unix_sk(sk)->lock, subclass);
}
#define peer_wait peer_wq.wait #define peer_wait peer_wq.wait
......
...@@ -126,6 +126,81 @@ static spinlock_t bsd_socket_locks[UNIX_HASH_SIZE / 2]; ...@@ -126,6 +126,81 @@ static spinlock_t bsd_socket_locks[UNIX_HASH_SIZE / 2];
* hash table is protected with spinlock. * hash table is protected with spinlock.
* each socket state is protected by separate spinlock. * each socket state is protected by separate spinlock.
*/ */
#ifdef CONFIG_PROVE_LOCKING
#define cmp_ptr(l, r) (((l) > (r)) - ((l) < (r)))
static int unix_table_lock_cmp_fn(const struct lockdep_map *a,
const struct lockdep_map *b)
{
return cmp_ptr(a, b);
}
static int unix_state_lock_cmp_fn(const struct lockdep_map *_a,
const struct lockdep_map *_b)
{
const struct unix_sock *a, *b;
a = container_of(_a, struct unix_sock, lock.dep_map);
b = container_of(_b, struct unix_sock, lock.dep_map);
if (a->sk.sk_state == TCP_LISTEN) {
/* unix_stream_connect(): Before the 2nd unix_state_lock(),
*
* 1. a is TCP_LISTEN.
* 2. b is not a.
* 3. concurrent connect(b -> a) must fail.
*
* Except for 2. & 3., the b's state can be any possible
* value due to concurrent connect() or listen().
*
* 2. is detected in debug_spin_lock_before(), and 3. cannot
* be expressed as lock_cmp_fn.
*/
switch (b->sk.sk_state) {
case TCP_CLOSE:
case TCP_ESTABLISHED:
case TCP_LISTEN:
return -1;
default:
/* Invalid case. */
return 0;
}
}
/* Should never happen. Just to be symmetric. */
if (b->sk.sk_state == TCP_LISTEN) {
switch (b->sk.sk_state) {
case TCP_CLOSE:
case TCP_ESTABLISHED:
return 1;
default:
return 0;
}
}
/* unix_state_double_lock(): ascending address order. */
return cmp_ptr(a, b);
}
static int unix_recvq_lock_cmp_fn(const struct lockdep_map *_a,
const struct lockdep_map *_b)
{
const struct sock *a, *b;
a = container_of(_a, struct sock, sk_receive_queue.lock.dep_map);
b = container_of(_b, struct sock, sk_receive_queue.lock.dep_map);
/* unix_collect_skb(): listener -> embryo order. */
if (a->sk_state == TCP_LISTEN && unix_sk(b)->listener == a)
return -1;
/* Should never happen. Just to be symmetric. */
if (b->sk_state == TCP_LISTEN && unix_sk(a)->listener == b)
return 1;
return 0;
}
#endif
static unsigned int unix_unbound_hash(struct sock *sk) static unsigned int unix_unbound_hash(struct sock *sk)
{ {
...@@ -168,7 +243,7 @@ static void unix_table_double_lock(struct net *net, ...@@ -168,7 +243,7 @@ static void unix_table_double_lock(struct net *net,
swap(hash1, hash2); swap(hash1, hash2);
spin_lock(&net->unx.table.locks[hash1]); spin_lock(&net->unx.table.locks[hash1]);
spin_lock_nested(&net->unx.table.locks[hash2], SINGLE_DEPTH_NESTING); spin_lock(&net->unx.table.locks[hash2]);
} }
static void unix_table_double_unlock(struct net *net, static void unix_table_double_unlock(struct net *net,
...@@ -675,6 +750,12 @@ static void unix_release_sock(struct sock *sk, int embrion) ...@@ -675,6 +750,12 @@ static void unix_release_sock(struct sock *sk, int embrion)
} }
static void init_peercred(struct sock *sk) static void init_peercred(struct sock *sk)
{
sk->sk_peer_pid = get_pid(task_tgid(current));
sk->sk_peer_cred = get_current_cred();
}
static void update_peercred(struct sock *sk)
{ {
const struct cred *old_cred; const struct cred *old_cred;
struct pid *old_pid; struct pid *old_pid;
...@@ -682,8 +763,7 @@ static void init_peercred(struct sock *sk) ...@@ -682,8 +763,7 @@ static void init_peercred(struct sock *sk)
spin_lock(&sk->sk_peer_lock); spin_lock(&sk->sk_peer_lock);
old_pid = sk->sk_peer_pid; old_pid = sk->sk_peer_pid;
old_cred = sk->sk_peer_cred; old_cred = sk->sk_peer_cred;
sk->sk_peer_pid = get_pid(task_tgid(current)); init_peercred(sk);
sk->sk_peer_cred = get_current_cred();
spin_unlock(&sk->sk_peer_lock); spin_unlock(&sk->sk_peer_lock);
put_pid(old_pid); put_pid(old_pid);
...@@ -692,26 +772,12 @@ static void init_peercred(struct sock *sk) ...@@ -692,26 +772,12 @@ static void init_peercred(struct sock *sk)
static void copy_peercred(struct sock *sk, struct sock *peersk) static void copy_peercred(struct sock *sk, struct sock *peersk)
{ {
const struct cred *old_cred; lockdep_assert_held(&unix_sk(peersk)->lock);
struct pid *old_pid;
if (sk < peersk) { spin_lock(&sk->sk_peer_lock);
spin_lock(&sk->sk_peer_lock); sk->sk_peer_pid = get_pid(peersk->sk_peer_pid);
spin_lock_nested(&peersk->sk_peer_lock, SINGLE_DEPTH_NESTING);
} else {
spin_lock(&peersk->sk_peer_lock);
spin_lock_nested(&sk->sk_peer_lock, SINGLE_DEPTH_NESTING);
}
old_pid = sk->sk_peer_pid;
old_cred = sk->sk_peer_cred;
sk->sk_peer_pid = get_pid(peersk->sk_peer_pid);
sk->sk_peer_cred = get_cred(peersk->sk_peer_cred); sk->sk_peer_cred = get_cred(peersk->sk_peer_cred);
spin_unlock(&sk->sk_peer_lock); spin_unlock(&sk->sk_peer_lock);
spin_unlock(&peersk->sk_peer_lock);
put_pid(old_pid);
put_cred(old_cred);
} }
static int unix_listen(struct socket *sock, int backlog) static int unix_listen(struct socket *sock, int backlog)
...@@ -735,7 +801,7 @@ static int unix_listen(struct socket *sock, int backlog) ...@@ -735,7 +801,7 @@ static int unix_listen(struct socket *sock, int backlog)
WRITE_ONCE(sk->sk_state, TCP_LISTEN); WRITE_ONCE(sk->sk_state, TCP_LISTEN);
/* set credentials so connect can copy them */ /* set credentials so connect can copy them */
init_peercred(sk); update_peercred(sk);
err = 0; err = 0;
out_unlock: out_unlock:
...@@ -972,12 +1038,15 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, ...@@ -972,12 +1038,15 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern,
sk->sk_write_space = unix_write_space; sk->sk_write_space = unix_write_space;
sk->sk_max_ack_backlog = READ_ONCE(net->unx.sysctl_max_dgram_qlen); sk->sk_max_ack_backlog = READ_ONCE(net->unx.sysctl_max_dgram_qlen);
sk->sk_destruct = unix_sock_destructor; sk->sk_destruct = unix_sock_destructor;
lock_set_cmp_fn(&sk->sk_receive_queue.lock, unix_recvq_lock_cmp_fn, NULL);
u = unix_sk(sk); u = unix_sk(sk);
u->listener = NULL; u->listener = NULL;
u->vertex = NULL; u->vertex = NULL;
u->path.dentry = NULL; u->path.dentry = NULL;
u->path.mnt = NULL; u->path.mnt = NULL;
spin_lock_init(&u->lock); spin_lock_init(&u->lock);
lock_set_cmp_fn(&u->lock, unix_state_lock_cmp_fn, NULL);
mutex_init(&u->iolock); /* single task reading lock */ mutex_init(&u->iolock); /* single task reading lock */
mutex_init(&u->bindlock); /* single task binding lock */ mutex_init(&u->bindlock); /* single task binding lock */
init_waitqueue_head(&u->peer_wait); init_waitqueue_head(&u->peer_wait);
...@@ -1326,11 +1395,12 @@ static void unix_state_double_lock(struct sock *sk1, struct sock *sk2) ...@@ -1326,11 +1395,12 @@ static void unix_state_double_lock(struct sock *sk1, struct sock *sk2)
unix_state_lock(sk1); unix_state_lock(sk1);
return; return;
} }
if (sk1 > sk2) if (sk1 > sk2)
swap(sk1, sk2); swap(sk1, sk2);
unix_state_lock(sk1); unix_state_lock(sk1);
unix_state_lock_nested(sk2, U_LOCK_SECOND); unix_state_lock(sk2);
} }
static void unix_state_double_unlock(struct sock *sk1, struct sock *sk2) static void unix_state_double_unlock(struct sock *sk1, struct sock *sk2)
...@@ -1473,6 +1543,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1473,6 +1543,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
struct unix_sock *u = unix_sk(sk), *newu, *otheru; struct unix_sock *u = unix_sk(sk), *newu, *otheru;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
struct sk_buff *skb = NULL; struct sk_buff *skb = NULL;
unsigned char state;
long timeo; long timeo;
int err; int err;
...@@ -1523,7 +1594,6 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1523,7 +1594,6 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto out; goto out;
} }
/* Latch state of peer */
unix_state_lock(other); unix_state_lock(other);
/* Apparently VFS overslept socket death. Retry. */ /* Apparently VFS overslept socket death. Retry. */
...@@ -1553,37 +1623,21 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1553,37 +1623,21 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto restart; goto restart;
} }
/* Latch our state. /* self connect and simultaneous connect are eliminated
* by rejecting TCP_LISTEN socket to avoid deadlock.
It is tricky place. We need to grab our state lock and cannot
drop lock on peer. It is dangerous because deadlock is
possible. Connect to self case and simultaneous
attempt to connect are eliminated by checking socket
state. other is TCP_LISTEN, if sk is TCP_LISTEN we
check this before attempt to grab lock.
Well, and we have to recheck the state after socket locked.
*/ */
switch (READ_ONCE(sk->sk_state)) { state = READ_ONCE(sk->sk_state);
case TCP_CLOSE: if (unlikely(state != TCP_CLOSE)) {
/* This is ok... continue with connect */ err = state == TCP_ESTABLISHED ? -EISCONN : -EINVAL;
break;
case TCP_ESTABLISHED:
/* Socket is already connected */
err = -EISCONN;
goto out_unlock;
default:
err = -EINVAL;
goto out_unlock; goto out_unlock;
} }
unix_state_lock_nested(sk, U_LOCK_SECOND); unix_state_lock(sk);
if (sk->sk_state != TCP_CLOSE) { if (unlikely(sk->sk_state != TCP_CLOSE)) {
err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EINVAL;
unix_state_unlock(sk); unix_state_unlock(sk);
unix_state_unlock(other); goto out_unlock;
sock_put(other);
goto restart;
} }
err = security_unix_stream_connect(sk, other, newsk); err = security_unix_stream_connect(sk, other, newsk);
...@@ -3578,6 +3632,7 @@ static int __net_init unix_net_init(struct net *net) ...@@ -3578,6 +3632,7 @@ static int __net_init unix_net_init(struct net *net)
for (i = 0; i < UNIX_HASH_SIZE; i++) { for (i = 0; i < UNIX_HASH_SIZE; i++) {
spin_lock_init(&net->unx.table.locks[i]); spin_lock_init(&net->unx.table.locks[i]);
lock_set_cmp_fn(&net->unx.table.locks[i], unix_table_lock_cmp_fn, NULL);
INIT_HLIST_HEAD(&net->unx.table.buckets[i]); INIT_HLIST_HEAD(&net->unx.table.buckets[i]);
} }
......
...@@ -47,9 +47,7 @@ static int sk_diag_dump_peer(struct sock *sk, struct sk_buff *nlskb) ...@@ -47,9 +47,7 @@ static int sk_diag_dump_peer(struct sock *sk, struct sk_buff *nlskb)
peer = unix_peer_get(sk); peer = unix_peer_get(sk);
if (peer) { if (peer) {
unix_state_lock(peer);
ino = sock_i_ino(peer); ino = sock_i_ino(peer);
unix_state_unlock(peer);
sock_put(peer); sock_put(peer);
return nla_put_u32(nlskb, UNIX_DIAG_PEER, ino); return nla_put_u32(nlskb, UNIX_DIAG_PEER, ino);
...@@ -75,20 +73,9 @@ static int sk_diag_dump_icons(struct sock *sk, struct sk_buff *nlskb) ...@@ -75,20 +73,9 @@ static int sk_diag_dump_icons(struct sock *sk, struct sk_buff *nlskb)
buf = nla_data(attr); buf = nla_data(attr);
i = 0; i = 0;
skb_queue_walk(&sk->sk_receive_queue, skb) { skb_queue_walk(&sk->sk_receive_queue, skb)
struct sock *req, *peer; buf[i++] = sock_i_ino(unix_peer(skb->sk));
req = skb->sk;
/*
* The state lock is outer for the same sk's
* queue lock. With the other's queue locked it's
* OK to lock the state.
*/
unix_state_lock_nested(req, U_LOCK_DIAG);
peer = unix_sk(req)->peer;
buf[i++] = (peer ? sock_i_ino(peer) : 0);
unix_state_unlock(req);
}
spin_unlock(&sk->sk_receive_queue.lock); spin_unlock(&sk->sk_receive_queue.lock);
} }
...@@ -180,22 +167,6 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_r ...@@ -180,22 +167,6 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_r
return -EMSGSIZE; return -EMSGSIZE;
} }
static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req,
struct user_namespace *user_ns,
u32 portid, u32 seq, u32 flags)
{
int sk_ino;
unix_state_lock(sk);
sk_ino = sock_i_ino(sk);
unix_state_unlock(sk);
if (!sk_ino)
return 0;
return sk_diag_fill(sk, skb, req, user_ns, portid, seq, flags, sk_ino);
}
static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{ {
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
...@@ -213,14 +184,22 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -213,14 +184,22 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
num = 0; num = 0;
spin_lock(&net->unx.table.locks[slot]); spin_lock(&net->unx.table.locks[slot]);
sk_for_each(sk, &net->unx.table.buckets[slot]) { sk_for_each(sk, &net->unx.table.buckets[slot]) {
int sk_ino;
if (num < s_num) if (num < s_num)
goto next; goto next;
if (!(req->udiag_states & (1 << READ_ONCE(sk->sk_state)))) if (!(req->udiag_states & (1 << READ_ONCE(sk->sk_state))))
goto next; goto next;
if (sk_diag_dump(sk, skb, req, sk_user_ns(skb->sk),
sk_ino = sock_i_ino(sk);
if (!sk_ino)
goto next;
if (sk_diag_fill(sk, skb, req, sk_user_ns(skb->sk),
NETLINK_CB(cb->skb).portid, NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, cb->nlh->nlmsg_seq,
NLM_F_MULTI) < 0) { NLM_F_MULTI, sk_ino) < 0) {
spin_unlock(&net->unx.table.locks[slot]); spin_unlock(&net->unx.table.locks[slot]);
goto done; goto done;
} }
......
...@@ -337,11 +337,6 @@ static bool unix_vertex_dead(struct unix_vertex *vertex) ...@@ -337,11 +337,6 @@ static bool unix_vertex_dead(struct unix_vertex *vertex)
return true; return true;
} }
enum unix_recv_queue_lock_class {
U_RECVQ_LOCK_NORMAL,
U_RECVQ_LOCK_EMBRYO,
};
static void unix_collect_queue(struct unix_sock *u, struct sk_buff_head *hitlist) static void unix_collect_queue(struct unix_sock *u, struct sk_buff_head *hitlist)
{ {
skb_queue_splice_init(&u->sk.sk_receive_queue, hitlist); skb_queue_splice_init(&u->sk.sk_receive_queue, hitlist);
...@@ -375,8 +370,7 @@ static void unix_collect_skb(struct list_head *scc, struct sk_buff_head *hitlist ...@@ -375,8 +370,7 @@ static void unix_collect_skb(struct list_head *scc, struct sk_buff_head *hitlist
skb_queue_walk(queue, skb) { skb_queue_walk(queue, skb) {
struct sk_buff_head *embryo_queue = &skb->sk->sk_receive_queue; struct sk_buff_head *embryo_queue = &skb->sk->sk_receive_queue;
/* listener -> embryo order, the inversion never happens. */ spin_lock(&embryo_queue->lock);
spin_lock_nested(&embryo_queue->lock, U_RECVQ_LOCK_EMBRYO);
unix_collect_queue(unix_sk(skb->sk), hitlist); unix_collect_queue(unix_sk(skb->sk), hitlist);
spin_unlock(&embryo_queue->lock); spin_unlock(&embryo_queue->lock);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment