Commit 440ffcdd authored by Jakub Kicinski

Merge https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf

Daniel Borkmann says:

====================
pull-request: bpf 2021-10-26

We've added 12 non-merge commits during the last 7 day(s) which contain
a total of 23 files changed, 118 insertions(+), 98 deletions(-).

The main changes are:

1) Fix potential race window in BPF tail call compatibility check, from Toke Høiland-Jørgensen.

2) Fix memory leak in cgroup fs due to missing cgroup_bpf_offline(), from Quanyang Wang.

3) Fix file descriptor reference counting in generic_map_update_batch(), from Xu Kuohai.

4) Cap the bpf_jit_limit knob at the maximum limit supported by the arch's JIT, from Lorenz Bauer.

5) Fix BPF sockmap ->poll callbacks for UDP and AF_UNIX sockets, from Cong Wang and Yucong Sun.

6) Fix BPF sockmap concurrency issue in TCP on non-blocking sendmsg calls, from Liu Jian.

7) Fix build failure of INODE_STORAGE and TASK_STORAGE maps on !CONFIG_NET, from Tejun Heo.

* https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf:
  bpf: Fix potential race in tail call compatibility check
  bpf: Move BPF_MAP_TYPE for INODE_STORAGE and TASK_STORAGE outside of CONFIG_NET
  selftests/bpf: Use recv_timeout() instead of retries
  net: Implement ->sock_is_readable() for UDP and AF_UNIX
  skmsg: Extract and reuse sk_msg_is_readable()
  net: Rename ->stream_memory_read to ->sock_is_readable
  tcp_bpf: Fix one concurrency problem in the tcp_bpf_send_verdict function
  cgroup: Fix memory leak caused by missing cgroup_bpf_offline
  bpf: Fix error usage of map_fd and fdget() in generic_map_update_batch()
  bpf: Prevent increasing bpf_jit_limit above max
  bpf: Define bpf_jit_alloc_exec_limit for arm64 JIT
  bpf: Define bpf_jit_alloc_exec_limit for riscv JIT
====================

Link: https://lore.kernel.org/r/20211026201920.11296-1-daniel@iogearbox.net
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parents 19fa0887 54713c85
@@ -1136,6 +1136,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	return prog;
 }
 
+u64 bpf_jit_alloc_exec_limit(void)
+{
+	return BPF_JIT_REGION_SIZE;
+}
+
 void *bpf_jit_alloc_exec(unsigned long size)
 {
 	return __vmalloc_node_range(size, PAGE_SIZE, BPF_JIT_REGION_START,
...
@@ -166,6 +166,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	return prog;
 }
 
+u64 bpf_jit_alloc_exec_limit(void)
+{
+	return BPF_JIT_REGION_SIZE;
+}
+
 void *bpf_jit_alloc_exec(unsigned long size)
 {
 	return __vmalloc_node_range(size, PAGE_SIZE, BPF_JIT_REGION_START,
...
@@ -929,8 +929,11 @@ struct bpf_array_aux {
 	 * stored in the map to make sure that all callers and callees have
 	 * the same prog type and JITed flag.
 	 */
-	enum bpf_prog_type type;
-	bool jited;
+	struct {
+		spinlock_t lock;
+		enum bpf_prog_type type;
+		bool jited;
+	} owner;
 	/* Programs with direct jumps into programs part of this array. */
 	struct list_head poke_progs;
 	struct bpf_map *map;
...
@@ -101,14 +101,14 @@ BPF_MAP_TYPE(BPF_MAP_TYPE_STACK_TRACE, stack_trace_map_ops)
 #endif
 BPF_MAP_TYPE(BPF_MAP_TYPE_ARRAY_OF_MAPS, array_of_maps_map_ops)
 BPF_MAP_TYPE(BPF_MAP_TYPE_HASH_OF_MAPS, htab_of_maps_map_ops)
-#ifdef CONFIG_NET
-BPF_MAP_TYPE(BPF_MAP_TYPE_DEVMAP, dev_map_ops)
-BPF_MAP_TYPE(BPF_MAP_TYPE_DEVMAP_HASH, dev_map_hash_ops)
-BPF_MAP_TYPE(BPF_MAP_TYPE_SK_STORAGE, sk_storage_map_ops)
 #ifdef CONFIG_BPF_LSM
 BPF_MAP_TYPE(BPF_MAP_TYPE_INODE_STORAGE, inode_storage_map_ops)
 #endif
 BPF_MAP_TYPE(BPF_MAP_TYPE_TASK_STORAGE, task_storage_map_ops)
+#ifdef CONFIG_NET
+BPF_MAP_TYPE(BPF_MAP_TYPE_DEVMAP, dev_map_ops)
+BPF_MAP_TYPE(BPF_MAP_TYPE_DEVMAP_HASH, dev_map_hash_ops)
+BPF_MAP_TYPE(BPF_MAP_TYPE_SK_STORAGE, sk_storage_map_ops)
 BPF_MAP_TYPE(BPF_MAP_TYPE_CPUMAP, cpu_map_ops)
 #if defined(CONFIG_XDP_SOCKETS)
 BPF_MAP_TYPE(BPF_MAP_TYPE_XSKMAP, xsk_map_ops)
...
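Note: bpf_types.h is consumed as an X-macro list — each includer defines BPF_MAP_TYPE() to something different before including the header, so an entry sitting under the wrong #ifdef vanishes from every expansion at once. That is why INODE_STORAGE and TASK_STORAGE failed to build on !CONFIG_NET, and why moving them above the CONFIG_NET block fixes it. A self-contained illustration of the technique (illustrative macros, not the kernel's):

    #include <stdio.h>

    #define MAP_TYPES(X)   \
            X(ARRAY)       \
            X(HASH)        \
            X(TASK_STORAGE)

    enum map_type {
    #define AS_ENUM(name) MAP_TYPE_##name,
            MAP_TYPES(AS_ENUM)
    #undef AS_ENUM
            NR_MAP_TYPES
    };

    static const char *map_type_names[] = {
    #define AS_STRING(name) #name,
            MAP_TYPES(AS_STRING)
    #undef AS_STRING
    };

    int main(void)
    {
            int i;

            /* Both expansions stay in sync because they come from one list. */
            for (i = 0; i < NR_MAP_TYPES; i++)
                    printf("%d: %s\n", i, map_type_names[i]);
            return 0;
    }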
@@ -1051,6 +1051,7 @@ extern int bpf_jit_enable;
 extern int bpf_jit_harden;
 extern int bpf_jit_kallsyms;
 extern long bpf_jit_limit;
+extern long bpf_jit_limit_max;
 
 typedef void (*bpf_jit_fill_hole_t)(void *area, unsigned int size);
...
@@ -128,6 +128,7 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
 			     struct sk_msg *msg, u32 bytes);
 int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
 		   int len, int flags);
+bool sk_msg_is_readable(struct sock *sk);
 
 static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
 {
...
@@ -1208,7 +1208,7 @@ struct proto {
 #endif
 	bool			(*stream_memory_free)(const struct sock *sk, int wake);
-	bool			(*stream_memory_read)(const struct sock *sk);
+	bool			(*sock_is_readable)(struct sock *sk);
 	/* Memory pressure */
 	void			(*enter_memory_pressure)(struct sock *sk);
 	void			(*leave_memory_pressure)(struct sock *sk);
@@ -2820,4 +2820,10 @@ void sock_set_sndtimeo(struct sock *sk, s64 secs);
 
 int sock_bind_add(struct sock *sk, struct sockaddr *addr, int addr_len);
 
+static inline bool sk_is_readable(struct sock *sk)
+{
+	if (sk->sk_prot->sock_is_readable)
+		return sk->sk_prot->sock_is_readable(sk);
+	return false;
+}
 #endif /* _SOCK_H */
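Note: sk_is_readable() is the one place that checks whether a protocol supplies the optional ->sock_is_readable() hook, with a conservative default when it does not, so callers (tcp_stream_is_readable(), udp_poll(), unix_poll() further down) no longer open-code the NULL test. A minimal sketch of this optional-callback pattern, with hypothetical types:

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdio.h>

    struct fake_proto {
            bool (*sock_is_readable)(void *sk); /* optional hook */
    };

    struct fake_sock {
            const struct fake_proto *sk_prot;
            int queued;
    };

    static bool psock_queue_not_empty(void *sk)
    {
            return ((struct fake_sock *)sk)->queued > 0;
    }

    static bool sk_is_readable(struct fake_sock *sk)
    {
            if (sk->sk_prot->sock_is_readable)
                    return sk->sk_prot->sock_is_readable(sk);
            return false; /* no hook registered: default to "not readable" */
    }

    int main(void)
    {
            const struct fake_proto plain = { NULL };
            const struct fake_proto hooked = { psock_queue_not_empty };
            struct fake_sock a = { &plain, 3 }, b = { &hooked, 3 };

            printf("plain: %d, hooked: %d\n", sk_is_readable(&a), sk_is_readable(&b));
            return 0;
    }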
@@ -375,7 +375,7 @@ void tls_sw_release_resources_rx(struct sock *sk);
 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
 int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		   int nonblock, int flags, int *addr_len);
-bool tls_sw_stream_read(const struct sock *sk);
+bool tls_sw_sock_is_readable(struct sock *sk);
 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 			   struct pipe_inode_info *pipe,
 			   size_t len, unsigned int flags);
...
@@ -1072,6 +1072,7 @@ static struct bpf_map *prog_array_map_alloc(union bpf_attr *attr)
 	INIT_WORK(&aux->work, prog_array_map_clear_deferred);
 	INIT_LIST_HEAD(&aux->poke_progs);
 	mutex_init(&aux->poke_mutex);
+	spin_lock_init(&aux->owner.lock);
 
 	map = array_map_alloc(attr);
 	if (IS_ERR(map)) {
...
@@ -524,6 +524,7 @@ int bpf_jit_enable   __read_mostly = IS_BUILTIN(CONFIG_BPF_JIT_DEFAULT_ON);
 int bpf_jit_kallsyms __read_mostly = IS_BUILTIN(CONFIG_BPF_JIT_DEFAULT_ON);
 int bpf_jit_harden   __read_mostly;
 long bpf_jit_limit   __read_mostly;
+long bpf_jit_limit_max __read_mostly;
 
 static void
 bpf_prog_ksym_set_addr(struct bpf_prog *prog)
@@ -817,7 +818,8 @@ u64 __weak bpf_jit_alloc_exec_limit(void)
 static int __init bpf_jit_charge_init(void)
 {
 	/* Only used as heuristic here to derive limit. */
-	bpf_jit_limit = min_t(u64, round_up(bpf_jit_alloc_exec_limit() >> 2,
+	bpf_jit_limit_max = bpf_jit_alloc_exec_limit();
+	bpf_jit_limit = min_t(u64, round_up(bpf_jit_limit_max >> 2,
 					    PAGE_SIZE), LONG_MAX);
 	return 0;
 }
@@ -1821,20 +1823,26 @@ static unsigned int __bpf_prog_ret0_warn(const void *ctx,
 bool bpf_prog_array_compatible(struct bpf_array *array,
 			       const struct bpf_prog *fp)
 {
+	bool ret;
+
 	if (fp->kprobe_override)
 		return false;
 
-	if (!array->aux->type) {
+	spin_lock(&array->aux->owner.lock);
+
+	if (!array->aux->owner.type) {
 		/* There's no owner yet where we could check for
 		 * compatibility.
 		 */
-		array->aux->type  = fp->type;
-		array->aux->jited = fp->jited;
-		return true;
+		array->aux->owner.type  = fp->type;
+		array->aux->owner.jited = fp->jited;
+		ret = true;
+	} else {
+		ret = array->aux->owner.type  == fp->type &&
+		      array->aux->owner.jited == fp->jited;
 	}
-
-	return array->aux->type  == fp->type &&
-	       array->aux->jited == fp->jited;
+	spin_unlock(&array->aux->owner.lock);
+	return ret;
 }
 
 static int bpf_check_tail_call(const struct bpf_prog *fp)
...
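Note: the old bpf_prog_array_compatible() was an unlocked check-then-set — two programs loaded in parallel could both observe an unowned prog array and both succeed with conflicting type/JITed expectations; owner.lock closes that window. A userspace analogy of the fixed pattern, with a pthread mutex standing in for the spinlock (illustrative names; single-threaded driver just to show the semantics):

    #include <pthread.h>
    #include <stdbool.h>
    #include <stdio.h>

    static pthread_mutex_t owner_lock = PTHREAD_MUTEX_INITIALIZER;
    static int owner_type; /* 0 == array not owned yet */

    /* Check-then-set made atomic: the first caller claims ownership,
     * later callers are compatible only if they match the recorded type. */
    static bool claim_compatible(int prog_type)
    {
            bool ret;

            pthread_mutex_lock(&owner_lock);
            if (!owner_type) {
                    owner_type = prog_type;
                    ret = true;
            } else {
                    ret = owner_type == prog_type;
            }
            pthread_mutex_unlock(&owner_lock);
            return ret;
    }

    int main(void)
    {
            printf("type 1: %d\n", claim_compatible(1)); /* claims ownership */
            printf("type 2: %d\n", claim_compatible(2)); /* rejected */
            return 0;
    }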
@@ -543,8 +543,10 @@ static void bpf_map_show_fdinfo(struct seq_file *m, struct file *filp)
 	if (map->map_type == BPF_MAP_TYPE_PROG_ARRAY) {
 		array = container_of(map, struct bpf_array, map);
-		type  = array->aux->type;
-		jited = array->aux->jited;
+		spin_lock(&array->aux->owner.lock);
+		type  = array->aux->owner.type;
+		jited = array->aux->owner.jited;
+		spin_unlock(&array->aux->owner.lock);
 	}
 
 	seq_printf(m,
@@ -1337,12 +1339,11 @@ int generic_map_update_batch(struct bpf_map *map,
 	void __user *values = u64_to_user_ptr(attr->batch.values);
 	void __user *keys = u64_to_user_ptr(attr->batch.keys);
 	u32 value_size, cp, max_count;
-	int ufd = attr->map_fd;
+	int ufd = attr->batch.map_fd;
 	void *key, *value;
 	struct fd f;
 	int err = 0;
 
-	f = fdget(ufd);
 	if (attr->batch.elem_flags & ~BPF_F_LOCK)
 		return -EINVAL;
@@ -1367,6 +1368,7 @@ int generic_map_update_batch(struct bpf_map *map,
 		return -ENOMEM;
 	}
 
+	f = fdget(ufd); /* bpf_map_do_batch() guarantees ufd is valid */
 	for (cp = 0; cp < max_count; cp++) {
 		err = -EFAULT;
 		if (copy_from_user(key, keys + cp * map->key_size,
@@ -1386,6 +1388,7 @@ int generic_map_update_batch(struct bpf_map *map,
 
 	kvfree(value);
 	kvfree(key);
+	fdput(f);
 	return err;
 }
...
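Note: generic_map_update_batch() had two bugs — it read the fd from attr->map_fd instead of attr->batch.map_fd, and it called fdget() before sanity checks whose early returns skipped fdput(), leaking a file reference on every failed call. The acquire-late/release-once discipline, as a plain userspace sketch (hypothetical helper, for illustration only):

    #include <fcntl.h>
    #include <unistd.h>

    static int process_file(const char *path, int flags)
    {
            int fd;

            if (flags != 0)                /* validate first: nothing held yet */
                    return -1;

            fd = open(path, O_RDONLY);     /* acquire after the early returns */
            if (fd < 0)
                    return -1;

            /* ... do the batched work ... */

            close(fd);                     /* single release on the one exit */
            return 0;
    }

    int main(void)
    {
            return process_file("/etc/hostname", 0) ? 1 : 0;
    }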
@@ -2187,8 +2187,10 @@ static void cgroup_kill_sb(struct super_block *sb)
 	 * And don't kill the default root.
 	 */
 	if (list_empty(&root->cgrp.self.children) && root != &cgrp_dfl_root &&
-	    !percpu_ref_is_dying(&root->cgrp.self.refcnt))
+	    !percpu_ref_is_dying(&root->cgrp.self.refcnt)) {
+		cgroup_bpf_offline(&root->cgrp);
 		percpu_ref_kill(&root->cgrp.self.refcnt);
+	}
 	cgroup_put(&root->cgrp);
 	kernfs_kill_sb(sb);
 }
...
@@ -474,6 +474,20 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
 }
 EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
 
+bool sk_msg_is_readable(struct sock *sk)
+{
+	struct sk_psock *psock;
+	bool empty = true;
+
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (likely(psock))
+		empty = list_empty(&psock->ingress_msg);
+	rcu_read_unlock();
+
+	return !empty;
+}
+EXPORT_SYMBOL_GPL(sk_msg_is_readable);
+
 static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
 						  struct sk_buff *skb)
 {
...
@@ -419,7 +419,7 @@ static struct ctl_table net_core_table[] = {
 		.mode		= 0600,
 		.proc_handler	= proc_dolongvec_minmax_bpf_restricted,
 		.extra1		= &long_one,
-		.extra2		= &long_max,
+		.extra2		= &bpf_jit_limit_max,
 	},
 #endif
 	{
...
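Note: the sysctl previously accepted anything up to LONG_MAX; it is now capped at bpf_jit_limit_max, which bpf_jit_charge_init() seeds from the arch's bpf_jit_alloc_exec_limit(). A standalone arithmetic sketch, assuming a 128 MiB JIT region as on arm64 (other arches differ):

    #include <stdio.h>

    int main(void)
    {
            long long region = 128LL << 20; /* bpf_jit_alloc_exec_limit() */
            long long dflt = region >> 2;   /* default bpf_jit_limit: a quarter */

            printf("writable ceiling (bpf_jit_limit_max): %lld MiB\n", region >> 20);
            printf("default bpf_jit_limit:                %lld MiB\n", dflt >> 20);
            return 0;
    }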
@@ -486,10 +486,7 @@ static bool tcp_stream_is_readable(struct sock *sk, int target)
 {
 	if (tcp_epollin_ready(sk, target))
 		return true;
-
-	if (sk->sk_prot->stream_memory_read)
-		return sk->sk_prot->stream_memory_read(sk);
-	return false;
+	return sk_is_readable(sk);
 }
 
 /*
...
@@ -150,19 +150,6 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
 EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
 
 #ifdef CONFIG_BPF_SYSCALL
-static bool tcp_bpf_stream_read(const struct sock *sk)
-{
-	struct sk_psock *psock;
-	bool empty = true;
-
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (likely(psock))
-		empty = list_empty(&psock->ingress_msg);
-	rcu_read_unlock();
-	return !empty;
-}
-
 static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
 			     long timeo)
 {
@@ -232,6 +219,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 	bool cork = false, enospc = sk_msg_full(msg);
 	struct sock *sk_redir;
 	u32 tosend, delta = 0;
+	u32 eval = __SK_NONE;
 	int ret;
 
 more_data:
@@ -275,13 +263,24 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 	case __SK_REDIRECT:
 		sk_redir = psock->sk_redir;
 		sk_msg_apply_bytes(psock, tosend);
+		if (!psock->apply_bytes) {
+			/* Clean up before releasing the sock lock. */
+			eval = psock->eval;
+			psock->eval = __SK_NONE;
+			psock->sk_redir = NULL;
+		}
 		if (psock->cork) {
 			cork = true;
 			psock->cork = NULL;
 		}
 		sk_msg_return(sk, msg, tosend);
 		release_sock(sk);
+
 		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
+
+		if (eval == __SK_REDIRECT)
+			sock_put(sk_redir);
+
 		lock_sock(sk);
 		if (unlikely(ret < 0)) {
 			int free = sk_msg_free_nocharge(sk, msg);
@@ -479,7 +478,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 	prot[TCP_BPF_BASE].unhash	= sock_map_unhash;
 	prot[TCP_BPF_BASE].close	= sock_map_close;
 	prot[TCP_BPF_BASE].recvmsg	= tcp_bpf_recvmsg;
-	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;
+	prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;
 
 	prot[TCP_BPF_TX]		= prot[TCP_BPF_BASE];
 	prot[TCP_BPF_TX].sendmsg	= tcp_bpf_sendmsg;
...
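Note: tcp_bpf_send_verdict() drops the socket lock around tcp_bpf_sendmsg_redir(), so a concurrent non-blocking sendmsg could repoint or free psock->sk_redir in that window. The fix snapshots eval and sk_redir while still under the lock (once apply_bytes is exhausted) and puts the redirect socket's reference only after use. The snapshot-under-lock idiom, as a userspace sketch (illustrative names):

    #include <pthread.h>
    #include <stdio.h>

    static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
    static const char *shared_redir = "peer-socket"; /* stands in for psock->sk_redir */

    static void send_redirected(void)
    {
            const char *redir;

            pthread_mutex_lock(&lock);
            redir = shared_redir;  /* snapshot under the lock... */
            shared_redir = NULL;   /* ...and clear it, as the fix does */
            pthread_mutex_unlock(&lock);

            /* Lock dropped here; peers cannot yank redir away anymore. */
            if (redir)
                    printf("sending via %s\n", redir);
            else
                    printf("no redirect state left\n");
    }

    int main(void)
    {
            send_redirected();
            send_redirected();
            return 0;
    }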
@@ -2867,6 +2867,9 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait)
 	    !(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1)
 		mask &= ~(EPOLLIN | EPOLLRDNORM);
 
+	/* psock ingress_msg queue should not contain any bad checksum frames */
+	if (sk_is_readable(sk))
+		mask |= EPOLLIN | EPOLLRDNORM;
 	return mask;
 }
...
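Note: with the hook wired up, poll() on a sockmap-managed UDP socket reports readability for data parked in the psock ingress_msg queue, not only the regular receive queue. A plain-socket demo of the userspace-visible check (no BPF involved; loopback self-send):

    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <poll.h>
    #include <stdio.h>
    #include <sys/socket.h>
    #include <unistd.h>

    int main(void)
    {
            struct sockaddr_in addr = {
                    .sin_family = AF_INET,
                    .sin_addr.s_addr = htonl(INADDR_LOOPBACK),
            };
            socklen_t len = sizeof(addr);
            struct pollfd pfd;
            int s = socket(AF_INET, SOCK_DGRAM, 0);

            bind(s, (struct sockaddr *)&addr, sizeof(addr)); /* ephemeral port */
            getsockname(s, (struct sockaddr *)&addr, &len);
            sendto(s, "x", 1, 0, (struct sockaddr *)&addr, sizeof(addr));

            pfd.fd = s;
            pfd.events = POLLIN;
            if (poll(&pfd, 1, 1000) > 0 && (pfd.revents & POLLIN))
                    printf("socket is readable\n");
            close(s);
            return 0;
    }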
@@ -114,6 +114,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
 	*prot        = *base;
 	prot->close  = sock_map_close;
 	prot->recvmsg = udp_bpf_recvmsg;
+	prot->sock_is_readable = sk_msg_is_readable;
 }
 
 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
...
@@ -681,12 +681,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
 	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
-	prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
+	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
 	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;
 
 	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
 	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
-	prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read;
+	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
 	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;
 
 #ifdef CONFIG_TLS_DEVICE
...
@@ -2026,7 +2026,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 	return copied ? : err;
 }
 
-bool tls_sw_stream_read(const struct sock *sk)
+bool tls_sw_sock_is_readable(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
...
@@ -3052,6 +3052,8 @@ static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wait)
 	/* readable? */
 	if (!skb_queue_empty_lockless(&sk->sk_receive_queue))
 		mask |= EPOLLIN | EPOLLRDNORM;
+	if (sk_is_readable(sk))
+		mask |= EPOLLIN | EPOLLRDNORM;
 
 	/* Connection-based need to check for termination and startup */
 	if ((sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) &&
@@ -3091,6 +3093,8 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock,
 	/* readable? */
 	if (!skb_queue_empty_lockless(&sk->sk_receive_queue))
 		mask |= EPOLLIN | EPOLLRDNORM;
+	if (sk_is_readable(sk))
+		mask |= EPOLLIN | EPOLLRDNORM;
 
 	/* Connection-based need to check for termination and startup */
 	if (sk->sk_type == SOCK_SEQPACKET) {
...
@@ -102,6 +102,7 @@ static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
 	*prot        = *base;
 	prot->close  = sock_map_close;
 	prot->recvmsg = unix_bpf_recvmsg;
+	prot->sock_is_readable = sk_msg_is_readable;
 }
 
 static void unix_stream_bpf_rebuild_protos(struct proto *prot,
@@ -110,6 +111,7 @@ static void unix_stream_bpf_rebuild_protos(struct proto *prot,
 	*prot        = *base;
 	prot->close  = sock_map_close;
 	prot->recvmsg = unix_bpf_recvmsg;
+	prot->sock_is_readable = sk_msg_is_readable;
 	prot->unhash = sock_map_unhash;
 }
...
@@ -949,7 +949,6 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd,
 	int err, n;
 	u32 key;
 	char b;
-	int retries = 100;
 
 	zero_verdict_count(verd_mapfd);
@@ -1002,17 +1001,11 @@ static void redir_to_connected(int family, int sotype, int sock_mapfd,
 		goto close_peer1;
 	if (pass != 1)
 		FAIL("%s: want pass count 1, have %d", log_prefix, pass);
-again:
-	n = read(c0, &b, 1);
-	if (n < 0) {
-		if (errno == EAGAIN && retries--) {
-			usleep(1000);
-			goto again;
-		}
-		FAIL_ERRNO("%s: read", log_prefix);
-	}
+	n = recv_timeout(c0, &b, 1, 0, IO_TIMEOUT_SEC);
+	if (n < 0)
+		FAIL_ERRNO("%s: recv_timeout", log_prefix);
 	if (n == 0)
-		FAIL("%s: incomplete read", log_prefix);
+		FAIL("%s: incomplete recv", log_prefix);
 
 close_peer1:
 	xclose(p1);
@@ -1571,7 +1564,6 @@ static void unix_redir_to_connected(int sotype, int sock_mapfd,
 	const char *log_prefix = redir_mode_str(mode);
 	int c0, c1, p0, p1;
 	unsigned int pass;
-	int retries = 100;
 	int err, n;
 	int sfd[2];
 	u32 key;
@@ -1606,17 +1598,11 @@ static void unix_redir_to_connected(int sotype, int sock_mapfd,
 	if (pass != 1)
 		FAIL("%s: want pass count 1, have %d", log_prefix, pass);
-again:
-	n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
-	if (n < 0) {
-		if (errno == EAGAIN && retries--) {
-			usleep(1000);
-			goto again;
-		}
-		FAIL_ERRNO("%s: read", log_prefix);
-	}
+	n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
+	if (n < 0)
+		FAIL_ERRNO("%s: recv_timeout", log_prefix);
 	if (n == 0)
-		FAIL("%s: incomplete read", log_prefix);
+		FAIL("%s: incomplete recv", log_prefix);
 
 close:
 	xclose(c1);
@@ -1748,7 +1734,6 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd,
 	const char *log_prefix = redir_mode_str(mode);
 	int c0, c1, p0, p1;
 	unsigned int pass;
-	int retries = 100;
 	int err, n;
 	u32 key;
 	char b;
@@ -1781,17 +1766,11 @@ static void udp_redir_to_connected(int family, int sock_mapfd, int verd_mapfd,
 	if (pass != 1)
 		FAIL("%s: want pass count 1, have %d", log_prefix, pass);
-again:
-	n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
-	if (n < 0) {
-		if (errno == EAGAIN && retries--) {
-			usleep(1000);
-			goto again;
-		}
-		FAIL_ERRNO("%s: read", log_prefix);
-	}
+	n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
+	if (n < 0)
+		FAIL_ERRNO("%s: recv_timeout", log_prefix);
 	if (n == 0)
-		FAIL("%s: incomplete read", log_prefix);
+		FAIL("%s: incomplete recv", log_prefix);
 
 close_cli1:
 	xclose(c1);
@@ -1841,7 +1820,6 @@ static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd,
 	const char *log_prefix = redir_mode_str(mode);
 	int c0, c1, p0, p1;
 	unsigned int pass;
-	int retries = 100;
 	int err, n;
 	int sfd[2];
 	u32 key;
@@ -1876,17 +1854,11 @@ static void inet_unix_redir_to_connected(int family, int type, int sock_mapfd,
 	if (pass != 1)
 		FAIL("%s: want pass count 1, have %d", log_prefix, pass);
-again:
-	n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
-	if (n < 0) {
-		if (errno == EAGAIN && retries--) {
-			usleep(1000);
-			goto again;
-		}
-		FAIL_ERRNO("%s: read", log_prefix);
-	}
+	n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
+	if (n < 0)
+		FAIL_ERRNO("%s: recv_timeout", log_prefix);
 	if (n == 0)
-		FAIL("%s: incomplete read", log_prefix);
+		FAIL("%s: incomplete recv", log_prefix);
 
 close_cli1:
 	xclose(c1);
@@ -1932,7 +1904,6 @@ static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd,
 	int sfd[2];
 	u32 key;
 	char b;
-	int retries = 100;
 
 	zero_verdict_count(verd_mapfd);
@@ -1963,17 +1934,11 @@ static void unix_inet_redir_to_connected(int family, int type, int sock_mapfd,
 	if (pass != 1)
 		FAIL("%s: want pass count 1, have %d", log_prefix, pass);
-again:
-	n = read(mode == REDIR_INGRESS ? p0 : c0, &b, 1);
-	if (n < 0) {
-		if (errno == EAGAIN && retries--) {
-			usleep(1000);
-			goto again;
-		}
-		FAIL_ERRNO("%s: read", log_prefix);
-	}
+	n = recv_timeout(mode == REDIR_INGRESS ? p0 : c0, &b, 1, 0, IO_TIMEOUT_SEC);
+	if (n < 0)
+		FAIL_ERRNO("%s: recv_timeout", log_prefix);
 	if (n == 0)
-		FAIL("%s: incomplete read", log_prefix);
+		FAIL("%s: incomplete recv", log_prefix);
 
 close:
 	xclose(c1);
...
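Note: recv_timeout() bounds the blocking receive instead of spinning on EAGAIN with usleep(). A minimal standalone equivalent, assuming SO_RCVTIMEO-based semantics (a sketch, not necessarily the selftests' exact helper):

    #include <stdio.h>
    #include <sys/socket.h>
    #include <sys/time.h>
    #include <unistd.h>

    static ssize_t recv_timeout(int fd, void *buf, size_t len, int flags,
                                unsigned int timeout_sec)
    {
            struct timeval tv = { .tv_sec = timeout_sec };

            /* Bound the blocking recv() instead of polling for EAGAIN. */
            if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
                    return -1;
            return recv(fd, buf, len, flags);
    }

    int main(void)
    {
            int sp[2];
            char b;
            ssize_t n;

            if (socketpair(AF_UNIX, SOCK_STREAM, 0, sp))
                    return 1;
            send(sp[1], "x", 1, 0);
            n = recv_timeout(sp[0], &b, 1, 0, 2);
            printf("received %zd byte(s)\n", n);
            close(sp[0]);
            close(sp[1]);
            return 0;
    }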