Commit 9f2470fb authored by Cong Wang's avatar Cong Wang Committed by Daniel Borkmann

skmsg: Improve udp_bpf_recvmsg() accuracy

I tried to reuse sk_msg_wait_data() for different protocols,
but it turns out it can not be simply reused. For example,
UDP actually uses two queues to receive skb:
udp_sk(sk)->reader_queue and sk->sk_receive_queue. So we have
to check both of them to know whether we have received any
packet.

Also, UDP does not lock the sock during BH Rx path, it makes
no sense for its ->recvmsg() to lock the sock. It is always
possible for ->recvmsg() to be called before packets actually
arrive in the receive queue, we just use best effort to make
it accurate here.

Fixes: 1f5be6b3 ("udp: Implement udp_bpf_recvmsg() for sockmap")
Signed-off-by: default avatarCong Wang <cong.wang@bytedance.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Acked-by: default avatarJakub Sitnicki <jakub@cloudflare.com>
Link: https://lore.kernel.org/bpf/20210615021342.7416-2-xiyou.wangcong@gmail.com
parent 61e8aeda
...@@ -126,8 +126,6 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, ...@@ -126,8 +126,6 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes); struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes); struct sk_msg *msg, u32 bytes);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags); int len, int flags);
......
...@@ -399,29 +399,6 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, ...@@ -399,29 +399,6 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
} }
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = sk_wait_event(sk, &timeo,
!list_empty(&psock->ingress_msg) ||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_wait_data);
/* Receive sk_msg from psock->ingress_msg to @msg. */ /* Receive sk_msg from psock->ingress_msg to @msg. */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags) int len, int flags)
......
...@@ -163,6 +163,28 @@ static bool tcp_bpf_stream_read(const struct sock *sk) ...@@ -163,6 +163,28 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
return !empty; return !empty;
} }
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = sk_wait_event(sk, &timeo,
!list_empty(&psock->ingress_msg) ||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len) int nonblock, int flags, int *addr_len)
{ {
...@@ -188,7 +210,7 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -188,7 +210,7 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
long timeo; long timeo;
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
data = sk_msg_wait_data(sk, psock, flags, timeo, &err); data = tcp_msg_wait_data(sk, psock, flags, timeo, &err);
if (data) { if (data) {
if (!sk_psock_queue_empty(psock)) if (!sk_psock_queue_empty(psock))
goto msg_bytes_ready; goto msg_bytes_ready;
......
...@@ -21,6 +21,45 @@ static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -21,6 +21,45 @@ static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len); return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
} }
static bool udp_sk_has_data(struct sock *sk)
{
return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
!skb_queue_empty(&sk->sk_receive_queue);
}
static bool psock_has_data(struct sk_psock *psock)
{
return !skb_queue_empty(&psock->ingress_skb) ||
!sk_psock_queue_empty(psock);
}
#define udp_msg_has_data(__sk, __psock) \
({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = udp_msg_has_data(sk, psock);
if (!ret) {
wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
ret = udp_msg_has_data(sk, psock);
}
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len) int nonblock, int flags, int *addr_len)
{ {
...@@ -34,8 +73,7 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -34,8 +73,7 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
if (unlikely(!psock)) if (unlikely(!psock))
return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
lock_sock(sk); if (!psock_has_data(psock)) {
if (sk_psock_queue_empty(psock)) {
ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
goto out; goto out;
} }
...@@ -47,9 +85,9 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -47,9 +85,9 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
long timeo; long timeo;
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
data = sk_msg_wait_data(sk, psock, flags, timeo, &err); data = udp_msg_wait_data(sk, psock, flags, timeo, &err);
if (data) { if (data) {
if (!sk_psock_queue_empty(psock)) if (psock_has_data(psock))
goto msg_bytes_ready; goto msg_bytes_ready;
ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
goto out; goto out;
...@@ -62,7 +100,6 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -62,7 +100,6 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
} }
ret = copied; ret = copied;
out: out:
release_sock(sk);
sk_psock_put(sk, psock); sk_psock_put(sk, psock);
return ret; return ret;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment