Commit 8a74ad60 authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

net: fix lock_sock_bh/unlock_sock_bh

This new sock lock primitive was introduced to speedup some user context
socket manipulation. But it is unsafe to protect two threads, one using
regular lock_sock/release_sock, one using lock_sock_bh/unlock_sock_bh

This patch changes lock_sock_bh to be careful against 'owned' state.
If owned is found to be set, we must take the slow path.
lock_sock_bh() now returns a boolean to say if the slow path was taken,
and this boolean is used at unlock_sock_bh time to call the appropriate
unlock function.

After this change, BH are either disabled or enabled during the
lock_sock_bh/unlock_sock_bh protected section. This might be misleading,
so we rename these functions to lock_sock_fast()/unlock_sock_fast().
Reported-by: default avatarAnton Blanchard <anton@samba.org>
Signed-off-by: default avatarEric Dumazet <eric.dumazet@gmail.com>
Tested-by: default avatarAnton Blanchard <anton@samba.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent a56635a5
...@@ -1026,15 +1026,23 @@ extern void release_sock(struct sock *sk); ...@@ -1026,15 +1026,23 @@ extern void release_sock(struct sock *sk);
SINGLE_DEPTH_NESTING) SINGLE_DEPTH_NESTING)
#define bh_unlock_sock(__sk) spin_unlock(&((__sk)->sk_lock.slock)) #define bh_unlock_sock(__sk) spin_unlock(&((__sk)->sk_lock.slock))
static inline void lock_sock_bh(struct sock *sk) extern bool lock_sock_fast(struct sock *sk);
/**
* unlock_sock_fast - complement of lock_sock_fast
* @sk: socket
* @slow: slow mode
*
* fast unlock socket for user context.
* If slow mode is on, we call regular release_sock()
*/
static inline void unlock_sock_fast(struct sock *sk, bool slow)
{ {
spin_lock_bh(&sk->sk_lock.slock); if (slow)
release_sock(sk);
else
spin_unlock_bh(&sk->sk_lock.slock);
} }
static inline void unlock_sock_bh(struct sock *sk)
{
spin_unlock_bh(&sk->sk_lock.slock);
}
extern struct sock *sk_alloc(struct net *net, int family, extern struct sock *sk_alloc(struct net *net, int family,
gfp_t priority, gfp_t priority,
......
...@@ -229,15 +229,17 @@ EXPORT_SYMBOL(skb_free_datagram); ...@@ -229,15 +229,17 @@ EXPORT_SYMBOL(skb_free_datagram);
void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb) void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb)
{ {
bool slow;
if (likely(atomic_read(&skb->users) == 1)) if (likely(atomic_read(&skb->users) == 1))
smp_rmb(); smp_rmb();
else if (likely(!atomic_dec_and_test(&skb->users))) else if (likely(!atomic_dec_and_test(&skb->users)))
return; return;
lock_sock_bh(sk); slow = lock_sock_fast(sk);
skb_orphan(skb); skb_orphan(skb);
sk_mem_reclaim_partial(sk); sk_mem_reclaim_partial(sk);
unlock_sock_bh(sk); unlock_sock_fast(sk, slow);
/* skb is now orphaned, can be freed outside of locked section */ /* skb is now orphaned, can be freed outside of locked section */
__kfree_skb(skb); __kfree_skb(skb);
......
...@@ -2007,6 +2007,39 @@ void release_sock(struct sock *sk) ...@@ -2007,6 +2007,39 @@ void release_sock(struct sock *sk)
} }
EXPORT_SYMBOL(release_sock); EXPORT_SYMBOL(release_sock);
/**
* lock_sock_fast - fast version of lock_sock
* @sk: socket
*
* This version should be used for very small section, where process wont block
* return false if fast path is taken
* sk_lock.slock locked, owned = 0, BH disabled
* return true if slow path is taken
* sk_lock.slock unlocked, owned = 1, BH enabled
*/
bool lock_sock_fast(struct sock *sk)
{
might_sleep();
spin_lock_bh(&sk->sk_lock.slock);
if (!sk->sk_lock.owned)
/*
* Note : We must disable BH
*/
return false;
__lock_sock(sk);
sk->sk_lock.owned = 1;
spin_unlock(&sk->sk_lock.slock);
/*
* The sk_lock has mutex_lock() semantics here:
*/
mutex_acquire(&sk->sk_lock.dep_map, 0, 0, _RET_IP_);
local_bh_enable();
return true;
}
EXPORT_SYMBOL(lock_sock_fast);
int sock_get_timestamp(struct sock *sk, struct timeval __user *userstamp) int sock_get_timestamp(struct sock *sk, struct timeval __user *userstamp)
{ {
struct timeval tv; struct timeval tv;
......
...@@ -1063,10 +1063,11 @@ static unsigned int first_packet_length(struct sock *sk) ...@@ -1063,10 +1063,11 @@ static unsigned int first_packet_length(struct sock *sk)
spin_unlock_bh(&rcvq->lock); spin_unlock_bh(&rcvq->lock);
if (!skb_queue_empty(&list_kill)) { if (!skb_queue_empty(&list_kill)) {
lock_sock_bh(sk); bool slow = lock_sock_fast(sk);
__skb_queue_purge(&list_kill); __skb_queue_purge(&list_kill);
sk_mem_reclaim_partial(sk); sk_mem_reclaim_partial(sk);
unlock_sock_bh(sk); unlock_sock_fast(sk, slow);
} }
return res; return res;
} }
...@@ -1123,6 +1124,7 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg, ...@@ -1123,6 +1124,7 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
int peeked; int peeked;
int err; int err;
int is_udplite = IS_UDPLITE(sk); int is_udplite = IS_UDPLITE(sk);
bool slow;
/* /*
* Check any passed addresses * Check any passed addresses
...@@ -1197,10 +1199,10 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg, ...@@ -1197,10 +1199,10 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
return err; return err;
csum_copy_err: csum_copy_err:
lock_sock_bh(sk); slow = lock_sock_fast(sk);
if (!skb_kill_datagram(sk, skb, flags)) if (!skb_kill_datagram(sk, skb, flags))
UDP_INC_STATS_USER(sock_net(sk), UDP_MIB_INERRORS, is_udplite); UDP_INC_STATS_USER(sock_net(sk), UDP_MIB_INERRORS, is_udplite);
unlock_sock_bh(sk); unlock_sock_fast(sk, slow);
if (noblock) if (noblock)
return -EAGAIN; return -EAGAIN;
...@@ -1625,9 +1627,9 @@ int udp_rcv(struct sk_buff *skb) ...@@ -1625,9 +1627,9 @@ int udp_rcv(struct sk_buff *skb)
void udp_destroy_sock(struct sock *sk) void udp_destroy_sock(struct sock *sk)
{ {
lock_sock_bh(sk); bool slow = lock_sock_fast(sk);
udp_flush_pending_frames(sk); udp_flush_pending_frames(sk);
unlock_sock_bh(sk); unlock_sock_fast(sk, slow);
} }
/* /*
......
...@@ -328,6 +328,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk, ...@@ -328,6 +328,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk,
int err; int err;
int is_udplite = IS_UDPLITE(sk); int is_udplite = IS_UDPLITE(sk);
int is_udp4; int is_udp4;
bool slow;
if (addr_len) if (addr_len)
*addr_len=sizeof(struct sockaddr_in6); *addr_len=sizeof(struct sockaddr_in6);
...@@ -424,7 +425,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk, ...@@ -424,7 +425,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk,
return err; return err;
csum_copy_err: csum_copy_err:
lock_sock_bh(sk); slow = lock_sock_fast(sk);
if (!skb_kill_datagram(sk, skb, flags)) { if (!skb_kill_datagram(sk, skb, flags)) {
if (is_udp4) if (is_udp4)
UDP_INC_STATS_USER(sock_net(sk), UDP_INC_STATS_USER(sock_net(sk),
...@@ -433,7 +434,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk, ...@@ -433,7 +434,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk,
UDP6_INC_STATS_USER(sock_net(sk), UDP6_INC_STATS_USER(sock_net(sk),
UDP_MIB_INERRORS, is_udplite); UDP_MIB_INERRORS, is_udplite);
} }
unlock_sock_bh(sk); unlock_sock_fast(sk, slow);
if (flags & MSG_DONTWAIT) if (flags & MSG_DONTWAIT)
return -EAGAIN; return -EAGAIN;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment