Commit 965b57b4 authored by Cong Wang, committed by Daniel Borkmann

net: Introduce a new proto_ops ->read_skb()

Currently both splice() and sockmap use ->read_sock() to
read skbs from the receive queue, but for sockmap we only read
one entire skb at a time, so ->read_sock() is too conservative
to use. Introduce a new proto_ops ->read_skb() which supports
this semantic; with this we can finally pass the ownership of
the skb to recv actors.

For non-TCP protocols, all ->read_sock() can be simply
converted to ->read_skb().
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Reviewed-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20220615162014.89193-3-xiyou.wangcong@gmail.com
parent 04919bed
...@@ -152,6 +152,8 @@ struct module; ...@@ -152,6 +152,8 @@ struct module;
struct sk_buff; struct sk_buff;
typedef int (*sk_read_actor_t)(read_descriptor_t *, struct sk_buff *, typedef int (*sk_read_actor_t)(read_descriptor_t *, struct sk_buff *,
unsigned int, size_t); unsigned int, size_t);
typedef int (*skb_read_actor_t)(struct sock *, struct sk_buff *);
struct proto_ops { struct proto_ops {
int family; int family;
...@@ -214,6 +216,8 @@ struct proto_ops { ...@@ -214,6 +216,8 @@ struct proto_ops {
*/ */
int (*read_sock)(struct sock *sk, read_descriptor_t *desc, int (*read_sock)(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor); sk_read_actor_t recv_actor);
/* This is different from read_sock(), it reads an entire skb at a time. */
int (*read_skb)(struct sock *sk, skb_read_actor_t recv_actor);
int (*sendpage_locked)(struct sock *sk, struct page *page, int (*sendpage_locked)(struct sock *sk, struct page *page,
int offset, size_t size, int flags); int offset, size_t size, int flags);
int (*sendmsg_locked)(struct sock *sk, struct msghdr *msg, int (*sendmsg_locked)(struct sock *sk, struct msghdr *msg,
......
...@@ -672,8 +672,7 @@ void tcp_get_info(struct sock *, struct tcp_info *); ...@@ -672,8 +672,7 @@ void tcp_get_info(struct sock *, struct tcp_info *);
/* Read 'sendfile()'-style from a TCP socket */ /* Read 'sendfile()'-style from a TCP socket */
int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, int tcp_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor); sk_read_actor_t recv_actor);
int tcp_read_skb(struct sock *sk, read_descriptor_t *desc, int tcp_read_skb(struct sock *sk, skb_read_actor_t recv_actor);
sk_read_actor_t recv_actor);
void tcp_initialize_rcv_mss(struct sock *sk); void tcp_initialize_rcv_mss(struct sock *sk);
......
...@@ -306,8 +306,7 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -306,8 +306,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
struct sk_buff *skb); struct sk_buff *skb);
struct sock *udp6_lib_lookup_skb(const struct sk_buff *skb, struct sock *udp6_lib_lookup_skb(const struct sk_buff *skb,
__be16 sport, __be16 dport); __be16 sport, __be16 dport);
int udp_read_sock(struct sock *sk, read_descriptor_t *desc, int udp_read_skb(struct sock *sk, skb_read_actor_t recv_actor);
sk_read_actor_t recv_actor);
/* UDP uses skb->dev_scratch to cache as much information as possible and avoid /* UDP uses skb->dev_scratch to cache as much information as possible and avoid
* possibly multiple cache miss on dequeue() * possibly multiple cache miss on dequeue()
......
...@@ -1160,21 +1160,17 @@ static void sk_psock_done_strp(struct sk_psock *psock) ...@@ -1160,21 +1160,17 @@ static void sk_psock_done_strp(struct sk_psock *psock)
} }
#endif /* CONFIG_BPF_STREAM_PARSER */ #endif /* CONFIG_BPF_STREAM_PARSER */
static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb, static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb)
unsigned int offset, size_t orig_len)
{ {
struct sock *sk = (struct sock *)desc->arg.data;
struct sk_psock *psock; struct sk_psock *psock;
struct bpf_prog *prog; struct bpf_prog *prog;
int ret = __SK_DROP; int ret = __SK_DROP;
int len = orig_len; int len = skb->len;
/* clone here so sk_eat_skb() in tcp_read_sock does not drop our data */ /* clone here so sk_eat_skb() in tcp_read_sock does not drop our data */
skb = skb_clone(skb, GFP_ATOMIC); skb = skb_clone(skb, GFP_ATOMIC);
if (!skb) { if (!skb)
desc->error = -ENOMEM;
return 0; return 0;
}
rcu_read_lock(); rcu_read_lock();
psock = sk_psock(sk); psock = sk_psock(sk);
...@@ -1204,16 +1200,10 @@ static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb, ...@@ -1204,16 +1200,10 @@ static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb,
static void sk_psock_verdict_data_ready(struct sock *sk) static void sk_psock_verdict_data_ready(struct sock *sk)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
read_descriptor_t desc;
if (unlikely(!sock || !sock->ops || !sock->ops->read_sock)) if (unlikely(!sock || !sock->ops || !sock->ops->read_skb))
return; return;
sock->ops->read_skb(sk, sk_psock_verdict_recv);
desc.arg.data = sk;
desc.error = 0;
desc.count = 1;
sock->ops->read_sock(sk, &desc, sk_psock_verdict_recv);
} }
void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock) void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock)
......
...@@ -1040,6 +1040,7 @@ const struct proto_ops inet_stream_ops = { ...@@ -1040,6 +1040,7 @@ const struct proto_ops inet_stream_ops = {
.sendpage = inet_sendpage, .sendpage = inet_sendpage,
.splice_read = tcp_splice_read, .splice_read = tcp_splice_read,
.read_sock = tcp_read_sock, .read_sock = tcp_read_sock,
.read_skb = tcp_read_skb,
.sendmsg_locked = tcp_sendmsg_locked, .sendmsg_locked = tcp_sendmsg_locked,
.sendpage_locked = tcp_sendpage_locked, .sendpage_locked = tcp_sendpage_locked,
.peek_len = tcp_peek_len, .peek_len = tcp_peek_len,
...@@ -1067,7 +1068,7 @@ const struct proto_ops inet_dgram_ops = { ...@@ -1067,7 +1068,7 @@ const struct proto_ops inet_dgram_ops = {
.setsockopt = sock_common_setsockopt, .setsockopt = sock_common_setsockopt,
.getsockopt = sock_common_getsockopt, .getsockopt = sock_common_getsockopt,
.sendmsg = inet_sendmsg, .sendmsg = inet_sendmsg,
.read_sock = udp_read_sock, .read_skb = udp_read_skb,
.recvmsg = inet_recvmsg, .recvmsg = inet_recvmsg,
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = inet_sendpage, .sendpage = inet_sendpage,
......
...@@ -1734,8 +1734,7 @@ int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, ...@@ -1734,8 +1734,7 @@ int tcp_read_sock(struct sock *sk, read_descriptor_t *desc,
} }
EXPORT_SYMBOL(tcp_read_sock); EXPORT_SYMBOL(tcp_read_sock);
int tcp_read_skb(struct sock *sk, read_descriptor_t *desc, int tcp_read_skb(struct sock *sk, skb_read_actor_t recv_actor)
sk_read_actor_t recv_actor)
{ {
struct tcp_sock *tp = tcp_sk(sk); struct tcp_sock *tp = tcp_sk(sk);
u32 seq = tp->copied_seq; u32 seq = tp->copied_seq;
...@@ -1750,7 +1749,7 @@ int tcp_read_skb(struct sock *sk, read_descriptor_t *desc, ...@@ -1750,7 +1749,7 @@ int tcp_read_skb(struct sock *sk, read_descriptor_t *desc,
int used; int used;
__skb_unlink(skb, &sk->sk_receive_queue); __skb_unlink(skb, &sk->sk_receive_queue);
used = recv_actor(desc, skb, 0, skb->len); used = recv_actor(sk, skb);
if (used <= 0) { if (used <= 0) {
if (!copied) if (!copied)
copied = used; copied = used;
...@@ -1765,9 +1764,7 @@ int tcp_read_skb(struct sock *sk, read_descriptor_t *desc, ...@@ -1765,9 +1764,7 @@ int tcp_read_skb(struct sock *sk, read_descriptor_t *desc,
break; break;
} }
consume_skb(skb); consume_skb(skb);
if (!desc->count) break;
break;
WRITE_ONCE(tp->copied_seq, seq);
} }
WRITE_ONCE(tp->copied_seq, seq); WRITE_ONCE(tp->copied_seq, seq);
......
...@@ -1797,8 +1797,7 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, ...@@ -1797,8 +1797,7 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
} }
EXPORT_SYMBOL(__skb_recv_udp); EXPORT_SYMBOL(__skb_recv_udp);
int udp_read_sock(struct sock *sk, read_descriptor_t *desc, int udp_read_skb(struct sock *sk, skb_read_actor_t recv_actor)
sk_read_actor_t recv_actor)
{ {
int copied = 0; int copied = 0;
...@@ -1820,7 +1819,7 @@ int udp_read_sock(struct sock *sk, read_descriptor_t *desc, ...@@ -1820,7 +1819,7 @@ int udp_read_sock(struct sock *sk, read_descriptor_t *desc,
continue; continue;
} }
used = recv_actor(desc, skb, 0, skb->len); used = recv_actor(sk, skb);
if (used <= 0) { if (used <= 0) {
if (!copied) if (!copied)
copied = used; copied = used;
...@@ -1831,13 +1830,12 @@ int udp_read_sock(struct sock *sk, read_descriptor_t *desc, ...@@ -1831,13 +1830,12 @@ int udp_read_sock(struct sock *sk, read_descriptor_t *desc,
} }
kfree_skb(skb); kfree_skb(skb);
if (!desc->count) break;
break;
} }
return copied; return copied;
} }
EXPORT_SYMBOL(udp_read_sock); EXPORT_SYMBOL(udp_read_skb);
/* /*
* This should be easy, if there is something there we * This should be easy, if there is something there we
......
...@@ -702,6 +702,7 @@ const struct proto_ops inet6_stream_ops = { ...@@ -702,6 +702,7 @@ const struct proto_ops inet6_stream_ops = {
.sendpage_locked = tcp_sendpage_locked, .sendpage_locked = tcp_sendpage_locked,
.splice_read = tcp_splice_read, .splice_read = tcp_splice_read,
.read_sock = tcp_read_sock, .read_sock = tcp_read_sock,
.read_skb = tcp_read_skb,
.peek_len = tcp_peek_len, .peek_len = tcp_peek_len,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_ioctl = inet6_compat_ioctl, .compat_ioctl = inet6_compat_ioctl,
...@@ -727,7 +728,7 @@ const struct proto_ops inet6_dgram_ops = { ...@@ -727,7 +728,7 @@ const struct proto_ops inet6_dgram_ops = {
.getsockopt = sock_common_getsockopt, /* ok */ .getsockopt = sock_common_getsockopt, /* ok */
.sendmsg = inet6_sendmsg, /* retpoline's sake */ .sendmsg = inet6_sendmsg, /* retpoline's sake */
.recvmsg = inet6_recvmsg, /* retpoline's sake */ .recvmsg = inet6_recvmsg, /* retpoline's sake */
.read_sock = udp_read_sock, .read_skb = udp_read_skb,
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = sock_no_sendpage, .sendpage = sock_no_sendpage,
.set_peek_off = sk_set_peek_off, .set_peek_off = sk_set_peek_off,
......
...@@ -741,10 +741,8 @@ static ssize_t unix_stream_splice_read(struct socket *, loff_t *ppos, ...@@ -741,10 +741,8 @@ static ssize_t unix_stream_splice_read(struct socket *, loff_t *ppos,
unsigned int flags); unsigned int flags);
static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t); static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t);
static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int); static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int);
static int unix_read_sock(struct sock *sk, read_descriptor_t *desc, static int unix_read_skb(struct sock *sk, skb_read_actor_t recv_actor);
sk_read_actor_t recv_actor); static int unix_stream_read_skb(struct sock *sk, skb_read_actor_t recv_actor);
static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
sk_read_actor_t recv_actor);
static int unix_dgram_connect(struct socket *, struct sockaddr *, static int unix_dgram_connect(struct socket *, struct sockaddr *,
int, int); int, int);
static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t); static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t);
...@@ -798,7 +796,7 @@ static const struct proto_ops unix_stream_ops = { ...@@ -798,7 +796,7 @@ static const struct proto_ops unix_stream_ops = {
.shutdown = unix_shutdown, .shutdown = unix_shutdown,
.sendmsg = unix_stream_sendmsg, .sendmsg = unix_stream_sendmsg,
.recvmsg = unix_stream_recvmsg, .recvmsg = unix_stream_recvmsg,
.read_sock = unix_stream_read_sock, .read_skb = unix_stream_read_skb,
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = unix_stream_sendpage, .sendpage = unix_stream_sendpage,
.splice_read = unix_stream_splice_read, .splice_read = unix_stream_splice_read,
...@@ -823,7 +821,7 @@ static const struct proto_ops unix_dgram_ops = { ...@@ -823,7 +821,7 @@ static const struct proto_ops unix_dgram_ops = {
.listen = sock_no_listen, .listen = sock_no_listen,
.shutdown = unix_shutdown, .shutdown = unix_shutdown,
.sendmsg = unix_dgram_sendmsg, .sendmsg = unix_dgram_sendmsg,
.read_sock = unix_read_sock, .read_skb = unix_read_skb,
.recvmsg = unix_dgram_recvmsg, .recvmsg = unix_dgram_recvmsg,
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = sock_no_sendpage, .sendpage = sock_no_sendpage,
...@@ -2487,8 +2485,7 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t si ...@@ -2487,8 +2485,7 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t si
return __unix_dgram_recvmsg(sk, msg, size, flags); return __unix_dgram_recvmsg(sk, msg, size, flags);
} }
static int unix_read_sock(struct sock *sk, read_descriptor_t *desc, static int unix_read_skb(struct sock *sk, skb_read_actor_t recv_actor)
sk_read_actor_t recv_actor)
{ {
int copied = 0; int copied = 0;
...@@ -2503,7 +2500,7 @@ static int unix_read_sock(struct sock *sk, read_descriptor_t *desc, ...@@ -2503,7 +2500,7 @@ static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
if (!skb) if (!skb)
return err; return err;
used = recv_actor(desc, skb, 0, skb->len); used = recv_actor(sk, skb);
if (used <= 0) { if (used <= 0) {
if (!copied) if (!copied)
copied = used; copied = used;
...@@ -2514,8 +2511,7 @@ static int unix_read_sock(struct sock *sk, read_descriptor_t *desc, ...@@ -2514,8 +2511,7 @@ static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
} }
kfree_skb(skb); kfree_skb(skb);
if (!desc->count) break;
break;
} }
return copied; return copied;
...@@ -2650,13 +2646,12 @@ static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk, ...@@ -2650,13 +2646,12 @@ static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
} }
#endif #endif
static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc, static int unix_stream_read_skb(struct sock *sk, skb_read_actor_t recv_actor)
sk_read_actor_t recv_actor)
{ {
if (unlikely(sk->sk_state != TCP_ESTABLISHED)) if (unlikely(sk->sk_state != TCP_ESTABLISHED))
return -ENOTCONN; return -ENOTCONN;
return unix_read_sock(sk, desc, recv_actor); return unix_read_skb(sk, recv_actor);
} }
static int unix_stream_read_generic(struct unix_stream_read_state *state, static int unix_stream_read_generic(struct unix_stream_read_state *state,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment