Commit 8934ce2f authored by John Fastabend's avatar John Fastabend Committed by Daniel Borkmann

bpf: sockmap redirect ingress support

Add support for the BPF_F_INGRESS flag in sk_msg redirect helper.
To do this add a scatterlist ring for receiving socks to check
before calling into regular recvmsg call path. Additionally, because
the poll wakeup logic only checked the skb recv queue we need to
add a hook in TCP stack (similar to write side) so that we have
a way to wake up polling socks when a scatterlist is redirected
to that sock.

After this all that is needed is for the redirect helper to
push the scatterlist into the psock receive queue.
Signed-off-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parent 22527437
...@@ -521,6 +521,7 @@ struct sk_msg_buff { ...@@ -521,6 +521,7 @@ struct sk_msg_buff {
__u32 key; __u32 key;
__u32 flags; __u32 flags;
struct bpf_map *map; struct bpf_map *map;
struct list_head list;
}; };
/* Compute the linear packet data range [data, data_end) which /* Compute the linear packet data range [data, data_end) which
......
...@@ -1085,6 +1085,7 @@ struct proto { ...@@ -1085,6 +1085,7 @@ struct proto {
#endif #endif
bool (*stream_memory_free)(const struct sock *sk); bool (*stream_memory_free)(const struct sock *sk);
bool (*stream_memory_read)(const struct sock *sk);
/* Memory pressure */ /* Memory pressure */
void (*enter_memory_pressure)(struct sock *sk); void (*enter_memory_pressure)(struct sock *sk);
void (*leave_memory_pressure)(struct sock *sk); void (*leave_memory_pressure)(struct sock *sk);
......
...@@ -41,6 +41,8 @@ ...@@ -41,6 +41,8 @@
#include <linux/mm.h> #include <linux/mm.h>
#include <net/strparser.h> #include <net/strparser.h>
#include <net/tcp.h> #include <net/tcp.h>
#include <linux/ptr_ring.h>
#include <net/inet_common.h>
#define SOCK_CREATE_FLAG_MASK \ #define SOCK_CREATE_FLAG_MASK \
(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY) (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
...@@ -82,6 +84,7 @@ struct smap_psock { ...@@ -82,6 +84,7 @@ struct smap_psock {
int sg_size; int sg_size;
int eval; int eval;
struct sk_msg_buff *cork; struct sk_msg_buff *cork;
struct list_head ingress;
struct strparser strp; struct strparser strp;
struct bpf_prog *bpf_tx_msg; struct bpf_prog *bpf_tx_msg;
...@@ -103,6 +106,8 @@ struct smap_psock { ...@@ -103,6 +106,8 @@ struct smap_psock {
}; };
static void smap_release_sock(struct smap_psock *psock, struct sock *sock); static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len);
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
static int bpf_tcp_sendpage(struct sock *sk, struct page *page, static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags); int offset, size_t size, int flags);
...@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk) ...@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
return rcu_dereference_sk_user_data(sk); return rcu_dereference_sk_user_data(sk);
} }
static bool bpf_tcp_stream_read(const struct sock *sk)
{
struct smap_psock *psock;
bool empty = true;
rcu_read_lock();
psock = smap_psock_sk(sk);
if (unlikely(!psock))
goto out;
empty = list_empty(&psock->ingress);
out:
rcu_read_unlock();
return !empty;
}
static struct proto tcp_bpf_proto; static struct proto tcp_bpf_proto;
static int bpf_tcp_init(struct sock *sk) static int bpf_tcp_init(struct sock *sk)
{ {
...@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk) ...@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk)
if (psock->bpf_tx_msg) { if (psock->bpf_tx_msg) {
tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg; tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
tcp_bpf_proto.sendpage = bpf_tcp_sendpage; tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg;
tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read;
} }
sk->sk_prot = &tcp_bpf_proto; sk->sk_prot = &tcp_bpf_proto;
...@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
{ {
void (*close_fun)(struct sock *sk, long timeout); void (*close_fun)(struct sock *sk, long timeout);
struct smap_psock_map_entry *e, *tmp; struct smap_psock_map_entry *e, *tmp;
struct sk_msg_buff *md, *mtmp;
struct smap_psock *psock; struct smap_psock *psock;
struct sock *osk; struct sock *osk;
...@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
close_fun = psock->save_close; close_fun = psock->save_close;
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
list_del(&md->list);
free_start_sg(psock->sock, md);
kfree(md);
}
list_for_each_entry_safe(e, tmp, &psock->maps, list) { list_for_each_entry_safe(e, tmp, &psock->maps, list) {
osk = cmpxchg(e->entry, sk, NULL); osk = cmpxchg(e->entry, sk, NULL);
if (osk == sk) { if (osk == sk) {
...@@ -468,6 +497,72 @@ static unsigned int smap_do_tx_msg(struct sock *sk, ...@@ -468,6 +497,72 @@ static unsigned int smap_do_tx_msg(struct sock *sk,
return _rc; return _rc;
} }
static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
struct smap_psock *psock,
struct sk_msg_buff *md, int flags)
{
bool apply = apply_bytes;
size_t size, copied = 0;
struct sk_msg_buff *r;
int err = 0, i;
r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
if (unlikely(!r))
return -ENOMEM;
lock_sock(sk);
r->sg_start = md->sg_start;
i = md->sg_start;
do {
r->sg_data[i] = md->sg_data[i];
size = (apply && apply_bytes < md->sg_data[i].length) ?
apply_bytes : md->sg_data[i].length;
if (!sk_wmem_schedule(sk, size)) {
if (!copied)
err = -ENOMEM;
break;
}
sk_mem_charge(sk, size);
r->sg_data[i].length = size;
md->sg_data[i].length -= size;
md->sg_data[i].offset += size;
copied += size;
if (md->sg_data[i].length) {
get_page(sg_page(&r->sg_data[i]));
r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
} else {
i++;
if (i == MAX_SKB_FRAGS)
i = 0;
r->sg_end = i;
}
if (apply) {
apply_bytes -= size;
if (!apply_bytes)
break;
}
} while (i != md->sg_end);
md->sg_start = i;
if (!err) {
list_add_tail(&r->list, &psock->ingress);
sk->sk_data_ready(sk);
} else {
free_start_sg(sk, r);
kfree(r);
}
release_sock(sk);
return err;
}
static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send, static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
struct sk_msg_buff *md, struct sk_msg_buff *md,
int flags) int flags)
...@@ -475,6 +570,7 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send, ...@@ -475,6 +570,7 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
struct smap_psock *psock; struct smap_psock *psock;
struct scatterlist *sg; struct scatterlist *sg;
int i, err, free = 0; int i, err, free = 0;
bool ingress = !!(md->flags & BPF_F_INGRESS);
sg = md->sg_data; sg = md->sg_data;
...@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send, ...@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
goto out_rcu; goto out_rcu;
rcu_read_unlock(); rcu_read_unlock();
if (ingress) {
err = bpf_tcp_ingress(sk, send, psock, md, flags);
} else {
lock_sock(sk); lock_sock(sk);
err = bpf_tcp_push(sk, send, md, flags, false); err = bpf_tcp_push(sk, send, md, flags, false);
release_sock(sk); release_sock(sk);
}
smap_release_sock(psock, sk); smap_release_sock(psock, sk);
if (unlikely(err)) if (unlikely(err))
goto out; goto out;
...@@ -623,6 +724,89 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock, ...@@ -623,6 +724,89 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock,
return err; return err;
} }
static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len)
{
struct iov_iter *iter = &msg->msg_iter;
struct smap_psock *psock;
int copied = 0;
if (unlikely(flags & MSG_ERRQUEUE))
return inet_recv_error(sk, msg, len, addr_len);
rcu_read_lock();
psock = smap_psock_sk(sk);
if (unlikely(!psock))
goto out;
if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
goto out;
rcu_read_unlock();
if (!skb_queue_empty(&sk->sk_receive_queue))
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
lock_sock(sk);
while (copied != len) {
struct scatterlist *sg;
struct sk_msg_buff *md;
int i;
md = list_first_entry_or_null(&psock->ingress,
struct sk_msg_buff, list);
if (unlikely(!md))
break;
i = md->sg_start;
do {
struct page *page;
int n, copy;
sg = &md->sg_data[i];
copy = sg->length;
page = sg_page(sg);
if (copied + copy > len)
copy = len - copied;
n = copy_page_to_iter(page, sg->offset, copy, iter);
if (n != copy) {
md->sg_start = i;
release_sock(sk);
smap_release_sock(psock, sk);
return -EFAULT;
}
copied += copy;
sg->offset += copy;
sg->length -= copy;
sk_mem_uncharge(sk, copy);
if (!sg->length) {
i++;
if (i == MAX_SKB_FRAGS)
i = 0;
put_page(page);
}
if (copied == len)
break;
} while (i != md->sg_end);
md->sg_start = i;
if (!sg->length && md->sg_start == md->sg_end) {
list_del(&md->list);
kfree(md);
}
}
release_sock(sk);
smap_release_sock(psock, sk);
return copied;
out:
rcu_read_unlock();
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
}
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{ {
int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS; int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
...@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab) ...@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab)
static void smap_gc_work(struct work_struct *w) static void smap_gc_work(struct work_struct *w)
{ {
struct smap_psock_map_entry *e, *tmp; struct smap_psock_map_entry *e, *tmp;
struct sk_msg_buff *md, *mtmp;
struct smap_psock *psock; struct smap_psock *psock;
psock = container_of(w, struct smap_psock, gc_work); psock = container_of(w, struct smap_psock, gc_work);
...@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w) ...@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w)
kfree(psock->cork); kfree(psock->cork);
} }
list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
list_del(&md->list);
free_start_sg(psock->sock, md);
kfree(md);
}
list_for_each_entry_safe(e, tmp, &psock->maps, list) { list_for_each_entry_safe(e, tmp, &psock->maps, list) {
list_del(&e->list); list_del(&e->list);
kfree(e); kfree(e);
...@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock, ...@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
INIT_WORK(&psock->tx_work, smap_tx_work); INIT_WORK(&psock->tx_work, smap_tx_work);
INIT_WORK(&psock->gc_work, smap_gc_work); INIT_WORK(&psock->gc_work, smap_gc_work);
INIT_LIST_HEAD(&psock->maps); INIT_LIST_HEAD(&psock->maps);
INIT_LIST_HEAD(&psock->ingress);
refcount_set(&psock->refcnt, 1); refcount_set(&psock->refcnt, 1);
rcu_assign_sk_user_data(sock, psock); rcu_assign_sk_user_data(sock, psock);
......
...@@ -1894,7 +1894,7 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg, ...@@ -1894,7 +1894,7 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
struct bpf_map *, map, u32, key, u64, flags) struct bpf_map *, map, u32, key, u64, flags)
{ {
/* If user passes invalid input drop the packet. */ /* If user passes invalid input drop the packet. */
if (unlikely(flags)) if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP; return SK_DROP;
msg->key = key; msg->key = key;
......
...@@ -485,6 +485,14 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags) ...@@ -485,6 +485,14 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
} }
} }
static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
int target, struct sock *sk)
{
return (tp->rcv_nxt - tp->copied_seq >= target) ||
(sk->sk_prot->stream_memory_read ?
sk->sk_prot->stream_memory_read(sk) : false);
}
/* /*
* Wait for a TCP event. * Wait for a TCP event.
* *
...@@ -554,7 +562,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) ...@@ -554,7 +562,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
tp->urg_data) tp->urg_data)
target++; target++;
if (tp->rcv_nxt - tp->copied_seq >= target) if (tcp_stream_is_readable(tp, target, sk))
mask |= EPOLLIN | EPOLLRDNORM; mask |= EPOLLIN | EPOLLRDNORM;
if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment