Commit 904709f6 authored by Alexei Starovoitov's avatar Alexei Starovoitov

Merge branch 'bpf: Enable bpf_sk_storage for FENTRY/FEXIT/RAW_TP'

Martin KaFai says:

====================

This set is to allow the FENTRY/FEXIT/RAW_TP tracing program to use
bpf_sk_storage.  The first two patches are a cleanup.  The last patch is
tests.  Patch 3 has the required kernel changes to
enable bpf_sk_storage for FENTRY/FEXIT/RAW_TP.

Please see individual patch for details.

v2:
- Rename some of the function prefix from sk_storage to bpf_sk_storage
- Use prefix check instead of substr check
====================
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents 0a58a65c 53632e11
......@@ -20,6 +20,8 @@ void bpf_sk_storage_free(struct sock *sk);
extern const struct bpf_func_proto bpf_sk_storage_get_proto;
extern const struct bpf_func_proto bpf_sk_storage_delete_proto;
extern const struct bpf_func_proto bpf_sk_storage_get_tracing_proto;
extern const struct bpf_func_proto bpf_sk_storage_delete_tracing_proto;
struct bpf_local_storage_elem;
struct bpf_sk_storage_diag;
......
......@@ -16,6 +16,7 @@
#include <linux/syscalls.h>
#include <linux/error-injection.h>
#include <linux/btf_ids.h>
#include <net/bpf_sk_storage.h>
#include <uapi/linux/bpf.h>
#include <uapi/linux/btf.h>
......@@ -1735,6 +1736,10 @@ tracing_prog_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
return &bpf_skc_to_tcp_request_sock_proto;
case BPF_FUNC_skc_to_udp6_sock:
return &bpf_skc_to_udp6_sock_proto;
case BPF_FUNC_sk_storage_get:
return &bpf_sk_storage_get_tracing_proto;
case BPF_FUNC_sk_storage_delete:
return &bpf_sk_storage_delete_tracing_proto;
#endif
case BPF_FUNC_seq_printf:
return prog->expected_attach_type == BPF_TRACE_ITER ?
......
......@@ -6,6 +6,7 @@
#include <linux/types.h>
#include <linux/spinlock.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/bpf_local_storage.h>
#include <net/bpf_sk_storage.h>
......@@ -15,20 +16,8 @@
DEFINE_BPF_STORAGE_CACHE(sk_cache);
static int omem_charge(struct sock *sk, unsigned int size)
{
/* same check as in sock_kmalloc() */
if (size <= sysctl_optmem_max &&
atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
atomic_add(size, &sk->sk_omem_alloc);
return 0;
}
return -ENOMEM;
}
static struct bpf_local_storage_data *
sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
bpf_sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
{
struct bpf_local_storage *sk_storage;
struct bpf_local_storage_map *smap;
......@@ -41,11 +30,11 @@ sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
return bpf_local_storage_lookup(sk_storage, smap, cacheit_lockit);
}
static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
static int bpf_sk_storage_del(struct sock *sk, struct bpf_map *map)
{
struct bpf_local_storage_data *sdata;
sdata = sk_storage_lookup(sk, map, false);
sdata = bpf_sk_storage_lookup(sk, map, false);
if (!sdata)
return -ENOENT;
......@@ -94,7 +83,7 @@ void bpf_sk_storage_free(struct sock *sk)
kfree_rcu(sk_storage, rcu);
}
static void sk_storage_map_free(struct bpf_map *map)
static void bpf_sk_storage_map_free(struct bpf_map *map)
{
struct bpf_local_storage_map *smap;
......@@ -103,7 +92,7 @@ static void sk_storage_map_free(struct bpf_map *map)
bpf_local_storage_map_free(smap);
}
static struct bpf_map *sk_storage_map_alloc(union bpf_attr *attr)
static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
{
struct bpf_local_storage_map *smap;
......@@ -130,7 +119,7 @@ static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
fd = *(int *)key;
sock = sockfd_lookup(fd, &err);
if (sock) {
sdata = sk_storage_lookup(sock->sk, map, true);
sdata = bpf_sk_storage_lookup(sock->sk, map, true);
sockfd_put(sock);
return sdata ? sdata->data : NULL;
}
......@@ -166,7 +155,7 @@ static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
fd = *(int *)key;
sock = sockfd_lookup(fd, &err);
if (sock) {
err = sk_storage_delete(sock->sk, map);
err = bpf_sk_storage_del(sock->sk, map);
sockfd_put(sock);
return err;
}
......@@ -272,7 +261,7 @@ BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
if (!sk || !sk_fullsock(sk) || flags > BPF_SK_STORAGE_GET_F_CREATE)
return (unsigned long)NULL;
sdata = sk_storage_lookup(sk, map, true);
sdata = bpf_sk_storage_lookup(sk, map, true);
if (sdata)
return (unsigned long)sdata->data;
......@@ -305,7 +294,7 @@ BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
if (refcount_inc_not_zero(&sk->sk_refcnt)) {
int err;
err = sk_storage_delete(sk, map);
err = bpf_sk_storage_del(sk, map);
sock_put(sk);
return err;
}
......@@ -313,13 +302,22 @@ BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
return -ENOENT;
}
static int sk_storage_charge(struct bpf_local_storage_map *smap,
static int bpf_sk_storage_charge(struct bpf_local_storage_map *smap,
void *owner, u32 size)
{
return omem_charge(owner, size);
struct sock *sk = (struct sock *)owner;
/* same check as in sock_kmalloc() */
if (size <= sysctl_optmem_max &&
atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
atomic_add(size, &sk->sk_omem_alloc);
return 0;
}
return -ENOMEM;
}
static void sk_storage_uncharge(struct bpf_local_storage_map *smap,
static void bpf_sk_storage_uncharge(struct bpf_local_storage_map *smap,
void *owner, u32 size)
{
struct sock *sk = owner;
......@@ -328,7 +326,7 @@ static void sk_storage_uncharge(struct bpf_local_storage_map *smap,
}
static struct bpf_local_storage __rcu **
sk_storage_ptr(void *owner)
bpf_sk_storage_ptr(void *owner)
{
struct sock *sk = owner;
......@@ -339,8 +337,8 @@ static int sk_storage_map_btf_id;
const struct bpf_map_ops sk_storage_map_ops = {
.map_meta_equal = bpf_map_meta_equal,
.map_alloc_check = bpf_local_storage_map_alloc_check,
.map_alloc = sk_storage_map_alloc,
.map_free = sk_storage_map_free,
.map_alloc = bpf_sk_storage_map_alloc,
.map_free = bpf_sk_storage_map_free,
.map_get_next_key = notsupp_get_next_key,
.map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
.map_update_elem = bpf_fd_sk_storage_update_elem,
......@@ -348,9 +346,9 @@ const struct bpf_map_ops sk_storage_map_ops = {
.map_check_btf = bpf_local_storage_map_check_btf,
.map_btf_name = "bpf_local_storage_map",
.map_btf_id = &sk_storage_map_btf_id,
.map_local_storage_charge = sk_storage_charge,
.map_local_storage_uncharge = sk_storage_uncharge,
.map_owner_storage_ptr = sk_storage_ptr,
.map_local_storage_charge = bpf_sk_storage_charge,
.map_local_storage_uncharge = bpf_sk_storage_uncharge,
.map_owner_storage_ptr = bpf_sk_storage_ptr,
};
const struct bpf_func_proto bpf_sk_storage_get_proto = {
......@@ -381,6 +379,79 @@ const struct bpf_func_proto bpf_sk_storage_delete_proto = {
.arg2_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
};
static bool bpf_sk_storage_tracing_allowed(const struct bpf_prog *prog)
{
const struct btf *btf_vmlinux;
const struct btf_type *t;
const char *tname;
u32 btf_id;
if (prog->aux->dst_prog)
return false;
/* Ensure the tracing program is not tracing
* any bpf_sk_storage*() function and also
* use the bpf_sk_storage_(get|delete) helper.
*/
switch (prog->expected_attach_type) {
case BPF_TRACE_RAW_TP:
/* bpf_sk_storage has no trace point */
return true;
case BPF_TRACE_FENTRY:
case BPF_TRACE_FEXIT:
btf_vmlinux = bpf_get_btf_vmlinux();
btf_id = prog->aux->attach_btf_id;
t = btf_type_by_id(btf_vmlinux, btf_id);
tname = btf_name_by_offset(btf_vmlinux, t->name_off);
return !!strncmp(tname, "bpf_sk_storage",
strlen("bpf_sk_storage"));
default:
return false;
}
return false;
}
BPF_CALL_4(bpf_sk_storage_get_tracing, struct bpf_map *, map, struct sock *, sk,
void *, value, u64, flags)
{
if (!in_serving_softirq() && !in_task())
return (unsigned long)NULL;
return (unsigned long)____bpf_sk_storage_get(map, sk, value, flags);
}
BPF_CALL_2(bpf_sk_storage_delete_tracing, struct bpf_map *, map,
struct sock *, sk)
{
if (!in_serving_softirq() && !in_task())
return -EPERM;
return ____bpf_sk_storage_delete(map, sk);
}
const struct bpf_func_proto bpf_sk_storage_get_tracing_proto = {
.func = bpf_sk_storage_get_tracing,
.gpl_only = false,
.ret_type = RET_PTR_TO_MAP_VALUE_OR_NULL,
.arg1_type = ARG_CONST_MAP_PTR,
.arg2_type = ARG_PTR_TO_BTF_ID,
.arg2_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
.arg3_type = ARG_PTR_TO_MAP_VALUE_OR_NULL,
.arg4_type = ARG_ANYTHING,
.allowed = bpf_sk_storage_tracing_allowed,
};
const struct bpf_func_proto bpf_sk_storage_delete_tracing_proto = {
.func = bpf_sk_storage_delete_tracing,
.gpl_only = false,
.ret_type = RET_INTEGER,
.arg1_type = ARG_CONST_MAP_PTR,
.arg2_type = ARG_PTR_TO_BTF_ID,
.arg2_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
.allowed = bpf_sk_storage_tracing_allowed,
};
struct bpf_sk_storage_diag {
u32 nr_maps;
struct bpf_map *maps[];
......
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Facebook */
#include <sys/types.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>
#include "test_progs.h"
#include "network_helpers.h"
#include "test_sk_storage_trace_itself.skel.h"
#include "test_sk_storage_tracing.skel.h"
#define LO_ADDR6 "::1"
#define TEST_COMM "test_progs"
struct sk_stg {
__u32 pid;
__u32 last_notclose_state;
char comm[16];
};
static struct test_sk_storage_tracing *skel;
static __u32 duration;
static pid_t my_pid;
static int check_sk_stg(int sk_fd, __u32 expected_state)
{
struct sk_stg sk_stg;
int err;
err = bpf_map_lookup_elem(bpf_map__fd(skel->maps.sk_stg_map), &sk_fd,
&sk_stg);
if (!ASSERT_OK(err, "map_lookup(sk_stg_map)"))
return -1;
if (!ASSERT_EQ(sk_stg.last_notclose_state, expected_state,
"last_notclose_state"))
return -1;
if (!ASSERT_EQ(sk_stg.pid, my_pid, "pid"))
return -1;
if (!ASSERT_STREQ(sk_stg.comm, skel->bss->task_comm, "task_comm"))
return -1;
return 0;
}
static void do_test(void)
{
int listen_fd = -1, passive_fd = -1, active_fd = -1, value = 1, err;
char abyte;
listen_fd = start_server(AF_INET6, SOCK_STREAM, LO_ADDR6, 0, 0);
if (CHECK(listen_fd == -1, "start_server",
"listen_fd:%d errno:%d\n", listen_fd, errno))
return;
active_fd = connect_to_fd(listen_fd, 0);
if (CHECK(active_fd == -1, "connect_to_fd", "active_fd:%d errno:%d\n",
active_fd, errno))
goto out;
err = bpf_map_update_elem(bpf_map__fd(skel->maps.del_sk_stg_map),
&active_fd, &value, 0);
if (!ASSERT_OK(err, "map_update(del_sk_stg_map)"))
goto out;
passive_fd = accept(listen_fd, NULL, 0);
if (CHECK(passive_fd == -1, "accept", "passive_fd:%d errno:%d\n",
passive_fd, errno))
goto out;
shutdown(active_fd, SHUT_WR);
err = read(passive_fd, &abyte, 1);
if (!ASSERT_OK(err, "read(passive_fd)"))
goto out;
shutdown(passive_fd, SHUT_WR);
err = read(active_fd, &abyte, 1);
if (!ASSERT_OK(err, "read(active_fd)"))
goto out;
err = bpf_map_lookup_elem(bpf_map__fd(skel->maps.del_sk_stg_map),
&active_fd, &value);
if (!ASSERT_ERR(err, "map_lookup(del_sk_stg_map)"))
goto out;
err = check_sk_stg(listen_fd, BPF_TCP_LISTEN);
if (!ASSERT_OK(err, "listen_fd sk_stg"))
goto out;
err = check_sk_stg(active_fd, BPF_TCP_FIN_WAIT2);
if (!ASSERT_OK(err, "active_fd sk_stg"))
goto out;
err = check_sk_stg(passive_fd, BPF_TCP_LAST_ACK);
ASSERT_OK(err, "passive_fd sk_stg");
out:
if (active_fd != -1)
close(active_fd);
if (passive_fd != -1)
close(passive_fd);
if (listen_fd != -1)
close(listen_fd);
}
void test_sk_storage_tracing(void)
{
struct test_sk_storage_trace_itself *skel_itself;
int err;
my_pid = getpid();
skel_itself = test_sk_storage_trace_itself__open_and_load();
if (!ASSERT_NULL(skel_itself, "test_sk_storage_trace_itself")) {
test_sk_storage_trace_itself__destroy(skel_itself);
return;
}
skel = test_sk_storage_tracing__open_and_load();
if (!ASSERT_OK_PTR(skel, "test_sk_storage_tracing"))
return;
err = test_sk_storage_tracing__attach(skel);
if (!ASSERT_OK(err, "test_sk_storage_tracing__attach")) {
test_sk_storage_tracing__destroy(skel);
return;
}
do_test();
test_sk_storage_tracing__destroy(skel);
}
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Facebook */
#include <vmlinux.h>
#include <bpf/bpf_tracing.h>
#include <bpf/bpf_helpers.h>
struct {
__uint(type, BPF_MAP_TYPE_SK_STORAGE);
__uint(map_flags, BPF_F_NO_PREALLOC);
__type(key, int);
__type(value, int);
} sk_stg_map SEC(".maps");
SEC("fentry/bpf_sk_storage_free")
int BPF_PROG(trace_bpf_sk_storage_free, struct sock *sk)
{
int *value;
value = bpf_sk_storage_get(&sk_stg_map, sk, 0,
BPF_SK_STORAGE_GET_F_CREATE);
if (value)
*value = 1;
return 0;
}
char _license[] SEC("license") = "GPL";
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Facebook */
#include <vmlinux.h>
#include <bpf/bpf_tracing.h>
#include <bpf/bpf_core_read.h>
#include <bpf/bpf_helpers.h>
struct sk_stg {
__u32 pid;
__u32 last_notclose_state;
char comm[16];
};
struct {
__uint(type, BPF_MAP_TYPE_SK_STORAGE);
__uint(map_flags, BPF_F_NO_PREALLOC);
__type(key, int);
__type(value, struct sk_stg);
} sk_stg_map SEC(".maps");
/* Testing delete */
struct {
__uint(type, BPF_MAP_TYPE_SK_STORAGE);
__uint(map_flags, BPF_F_NO_PREALLOC);
__type(key, int);
__type(value, int);
} del_sk_stg_map SEC(".maps");
char task_comm[16] = "";
SEC("tp_btf/inet_sock_set_state")
int BPF_PROG(trace_inet_sock_set_state, struct sock *sk, int oldstate,
int newstate)
{
struct sk_stg *stg;
if (newstate == BPF_TCP_CLOSE)
return 0;
stg = bpf_sk_storage_get(&sk_stg_map, sk, 0,
BPF_SK_STORAGE_GET_F_CREATE);
if (!stg)
return 0;
stg->last_notclose_state = newstate;
bpf_sk_storage_delete(&del_sk_stg_map, sk);
return 0;
}
static void set_task_info(struct sock *sk)
{
struct task_struct *task;
struct sk_stg *stg;
stg = bpf_sk_storage_get(&sk_stg_map, sk, 0,
BPF_SK_STORAGE_GET_F_CREATE);
if (!stg)
return;
stg->pid = bpf_get_current_pid_tgid();
task = (struct task_struct *)bpf_get_current_task();
bpf_core_read_str(&stg->comm, sizeof(stg->comm), &task->comm);
bpf_core_read_str(&task_comm, sizeof(task_comm), &task->comm);
}
SEC("fentry/inet_csk_listen_start")
int BPF_PROG(trace_inet_csk_listen_start, struct sock *sk, int backlog)
{
set_task_info(sk);
return 0;
}
SEC("fentry/tcp_connect")
int BPF_PROG(trace_tcp_connect, struct sock *sk)
{
set_task_info(sk);
return 0;
}
SEC("fexit/inet_csk_accept")
int BPF_PROG(inet_csk_accept, struct sock *sk, int flags, int *err, bool kern,
struct sock *accepted_sk)
{
set_task_info(accepted_sk);
return 0;
}
char _license[] SEC("license") = "GPL";
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment