Commit 612d087d authored by Kui-Feng Lee's avatar Kui-Feng Lee Committed by Martin KaFai Lau

bpf: validate value_type

A value_type should consist of three components: refcnt, state, and data.
refcnt and state have been moved to struct bpf_struct_ops_common_value to
make it easier to check the value type.
Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
Link: https://lore.kernel.org/r/20240119225005.668602-11-thinker.li@gmail.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
parent e3f87fdf
...@@ -1688,6 +1688,18 @@ struct bpf_struct_ops_desc { ...@@ -1688,6 +1688,18 @@ struct bpf_struct_ops_desc {
u32 value_id; u32 value_id;
}; };
/* Lifecycle states of a struct_ops map's kernel value.
 *
 * From the code visible in this patch: a map starts in INIT; a
 * successful map_update publishes READY (when BPF_F_LINK is set) or
 * INUSE via smp_store_release() on ->state, paired with
 * smp_load_acquire() in lookup; delete_elem moves INUSE to TOBEFREE
 * with cmpxchg(); bpf_struct_ops_valid_to_reg() requires READY before
 * a BPF_F_LINK map may be registered.
 */
enum bpf_struct_ops_state {
BPF_STRUCT_OPS_STATE_INIT,
BPF_STRUCT_OPS_STATE_INUSE,
BPF_STRUCT_OPS_STATE_TOBEFREE,
BPF_STRUCT_OPS_STATE_READY,
};
/* Common header shared by every struct_ops value type
 * (struct bpf_struct_ops_value and the generated
 * struct bpf_struct_ops_<name> wrappers): a refcount and the
 * lifecycle state.  Factoring these into one named struct lets
 * is_valid_value_type() verify a value type by checking that its
 * first member is exactly this struct.
 */
struct bpf_struct_ops_common_value {
refcount_t refcnt;
enum bpf_struct_ops_state state;
};
#if defined(CONFIG_BPF_JIT) && defined(CONFIG_BPF_SYSCALL) #if defined(CONFIG_BPF_JIT) && defined(CONFIG_BPF_SYSCALL)
#define BPF_MODULE_OWNER ((void *)((0xeB9FUL << 2) + POISON_POINTER_DELTA)) #define BPF_MODULE_OWNER ((void *)((0xeB9FUL << 2) + POISON_POINTER_DELTA))
const struct bpf_struct_ops_desc *bpf_struct_ops_find(struct btf *btf, u32 type_id); const struct bpf_struct_ops_desc *bpf_struct_ops_find(struct btf *btf, u32 type_id);
......
...@@ -13,19 +13,8 @@ ...@@ -13,19 +13,8 @@
#include <linux/btf_ids.h> #include <linux/btf_ids.h>
#include <linux/rcupdate_wait.h> #include <linux/rcupdate_wait.h>
enum bpf_struct_ops_state {
BPF_STRUCT_OPS_STATE_INIT,
BPF_STRUCT_OPS_STATE_INUSE,
BPF_STRUCT_OPS_STATE_TOBEFREE,
BPF_STRUCT_OPS_STATE_READY,
};
#define BPF_STRUCT_OPS_COMMON_VALUE \
refcount_t refcnt; \
enum bpf_struct_ops_state state
struct bpf_struct_ops_value { struct bpf_struct_ops_value {
BPF_STRUCT_OPS_COMMON_VALUE; struct bpf_struct_ops_common_value common;
char data[] ____cacheline_aligned_in_smp; char data[] ____cacheline_aligned_in_smp;
}; };
...@@ -82,7 +71,7 @@ static DEFINE_MUTEX(update_mutex); ...@@ -82,7 +71,7 @@ static DEFINE_MUTEX(update_mutex);
extern struct bpf_struct_ops bpf_##_name; \ extern struct bpf_struct_ops bpf_##_name; \
\ \
struct bpf_struct_ops_##_name { \ struct bpf_struct_ops_##_name { \
BPF_STRUCT_OPS_COMMON_VALUE; \ struct bpf_struct_ops_common_value common; \
struct _name data ____cacheline_aligned_in_smp; \ struct _name data ____cacheline_aligned_in_smp; \
}; };
#include "bpf_struct_ops_types.h" #include "bpf_struct_ops_types.h"
...@@ -113,11 +102,49 @@ const struct bpf_prog_ops bpf_struct_ops_prog_ops = { ...@@ -113,11 +102,49 @@ const struct bpf_prog_ops bpf_struct_ops_prog_ops = {
BTF_ID_LIST(st_ops_ids) BTF_ID_LIST(st_ops_ids)
BTF_ID(struct, module) BTF_ID(struct, module)
BTF_ID(struct, bpf_struct_ops_common_value)
enum { enum {
IDX_MODULE_ID, IDX_MODULE_ID,
IDX_ST_OPS_COMMON_VALUE_ID,
}; };
extern struct btf *btf_vmlinux;
/* Validate the BTF layout of a struct_ops value type.
 *
 * @btf:        BTF container holding @value_id and @type
 * @value_id:   BTF id of the candidate value struct (bpf_struct_ops_<name>)
 * @type:       BTF type of the struct_ops type itself (struct <name>)
 * @value_name: name of the value struct, used only for warnings
 *
 * A valid value type must have exactly two members:
 *   [0] struct bpf_struct_ops_common_value (refcnt + state)
 *   [1] the struct_ops type @type itself
 *
 * Returns true if the layout matches; otherwise emits a pr_warn()
 * describing the first mismatch and returns false.
 */
static bool is_valid_value_type(struct btf *btf, s32 value_id,
const struct btf_type *type,
const char *value_name)
{
const struct btf_type *common_value_type;
const struct btf_member *member;
const struct btf_type *vt, *mt;
/* Exactly two members: the common header and the data. */
vt = btf_type_by_id(btf, value_id);
if (btf_vlen(vt) != 2) {
pr_warn("The number of %s's members should be 2, but we get %d\n",
value_name, btf_vlen(vt));
return false;
}
/* Member 0 must be struct bpf_struct_ops_common_value, looked up in
 * vmlinux BTF via the pre-resolved id in st_ops_ids[].
 * NOTE(review): the check compares btf_type pointers resolved from
 * @btf and btf_vmlinux directly — this presumably relies on @btf
 * sharing (or being) the vmlinux base BTF for kernel types; confirm
 * against how module BTF resolves base-BTF ids.
 */
member = btf_type_member(vt);
mt = btf_type_by_id(btf, member->type);
common_value_type = btf_type_by_id(btf_vmlinux,
st_ops_ids[IDX_ST_OPS_COMMON_VALUE_ID]);
if (mt != common_value_type) {
pr_warn("The first member of %s should be bpf_struct_ops_common_value\n",
value_name);
return false;
}
/* Member 1 must be the struct_ops type itself. */
member++;
mt = btf_type_by_id(btf, member->type);
if (mt != type) {
pr_warn("The second member of %s should be %s\n",
value_name, btf_name_by_offset(btf, type->name_off));
return false;
}
return true;
}
static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc, static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
struct btf *btf, struct btf *btf,
struct bpf_verifier_log *log) struct bpf_verifier_log *log)
...@@ -138,14 +165,6 @@ static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc, ...@@ -138,14 +165,6 @@ static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
} }
sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name); sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name);
value_id = btf_find_by_name_kind(btf, value_name,
BTF_KIND_STRUCT);
if (value_id < 0) {
pr_warn("Cannot find struct %s in %s\n",
value_name, btf_get_name(btf));
return;
}
type_id = btf_find_by_name_kind(btf, st_ops->name, type_id = btf_find_by_name_kind(btf, st_ops->name,
BTF_KIND_STRUCT); BTF_KIND_STRUCT);
if (type_id < 0) { if (type_id < 0) {
...@@ -160,6 +179,16 @@ static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc, ...@@ -160,6 +179,16 @@ static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
return; return;
} }
value_id = btf_find_by_name_kind(btf, value_name,
BTF_KIND_STRUCT);
if (value_id < 0) {
pr_warn("Cannot find struct %s in %s\n",
value_name, btf_get_name(btf));
return;
}
if (!is_valid_value_type(btf, value_id, t, value_name))
return;
for_each_member(i, t, member) { for_each_member(i, t, member) {
const struct btf_type *func_proto; const struct btf_type *func_proto;
...@@ -219,8 +248,6 @@ void bpf_struct_ops_init(struct btf *btf, struct bpf_verifier_log *log) ...@@ -219,8 +248,6 @@ void bpf_struct_ops_init(struct btf *btf, struct bpf_verifier_log *log)
} }
} }
extern struct btf *btf_vmlinux;
static const struct bpf_struct_ops_desc * static const struct bpf_struct_ops_desc *
bpf_struct_ops_find_value(struct btf *btf, u32 value_id) bpf_struct_ops_find_value(struct btf *btf, u32 value_id)
{ {
...@@ -276,7 +303,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key, ...@@ -276,7 +303,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
kvalue = &st_map->kvalue; kvalue = &st_map->kvalue;
/* Pair with smp_store_release() during map_update */ /* Pair with smp_store_release() during map_update */
state = smp_load_acquire(&kvalue->state); state = smp_load_acquire(&kvalue->common.state);
if (state == BPF_STRUCT_OPS_STATE_INIT) { if (state == BPF_STRUCT_OPS_STATE_INIT) {
memset(value, 0, map->value_size); memset(value, 0, map->value_size);
return 0; return 0;
...@@ -287,7 +314,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key, ...@@ -287,7 +314,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
*/ */
uvalue = value; uvalue = value;
memcpy(uvalue, st_map->uvalue, map->value_size); memcpy(uvalue, st_map->uvalue, map->value_size);
uvalue->state = state; uvalue->common.state = state;
/* This value offers the user space a general estimate of how /* This value offers the user space a general estimate of how
* many sockets are still utilizing this struct_ops for TCP * many sockets are still utilizing this struct_ops for TCP
...@@ -295,7 +322,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key, ...@@ -295,7 +322,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
* should sufficiently meet our present goals. * should sufficiently meet our present goals.
*/ */
refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt); refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt);
refcount_set(&uvalue->refcnt, max_t(s64, refcnt, 0)); refcount_set(&uvalue->common.refcnt, max_t(s64, refcnt, 0));
return 0; return 0;
} }
...@@ -413,7 +440,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key, ...@@ -413,7 +440,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
if (err) if (err)
return err; return err;
if (uvalue->state || refcount_read(&uvalue->refcnt)) if (uvalue->common.state || refcount_read(&uvalue->common.refcnt))
return -EINVAL; return -EINVAL;
tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL); tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL);
...@@ -425,7 +452,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key, ...@@ -425,7 +452,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
mutex_lock(&st_map->lock); mutex_lock(&st_map->lock);
if (kvalue->state != BPF_STRUCT_OPS_STATE_INIT) { if (kvalue->common.state != BPF_STRUCT_OPS_STATE_INIT) {
err = -EBUSY; err = -EBUSY;
goto unlock; goto unlock;
} }
...@@ -540,7 +567,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key, ...@@ -540,7 +567,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
* *
* Pair with smp_load_acquire() during lookup_elem(). * Pair with smp_load_acquire() during lookup_elem().
*/ */
smp_store_release(&kvalue->state, BPF_STRUCT_OPS_STATE_READY); smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_READY);
goto unlock; goto unlock;
} }
...@@ -558,7 +585,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key, ...@@ -558,7 +585,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
* It ensures the above udata updates (e.g. prog->aux->id) * It ensures the above udata updates (e.g. prog->aux->id)
* can be seen once BPF_STRUCT_OPS_STATE_INUSE is set. * can be seen once BPF_STRUCT_OPS_STATE_INUSE is set.
*/ */
smp_store_release(&kvalue->state, BPF_STRUCT_OPS_STATE_INUSE); smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_INUSE);
goto unlock; goto unlock;
} }
...@@ -588,7 +615,7 @@ static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key) ...@@ -588,7 +615,7 @@ static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key)
if (st_map->map.map_flags & BPF_F_LINK) if (st_map->map.map_flags & BPF_F_LINK)
return -EOPNOTSUPP; return -EOPNOTSUPP;
prev_state = cmpxchg(&st_map->kvalue.state, prev_state = cmpxchg(&st_map->kvalue.common.state,
BPF_STRUCT_OPS_STATE_INUSE, BPF_STRUCT_OPS_STATE_INUSE,
BPF_STRUCT_OPS_STATE_TOBEFREE); BPF_STRUCT_OPS_STATE_TOBEFREE);
switch (prev_state) { switch (prev_state) {
...@@ -848,7 +875,7 @@ static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map) ...@@ -848,7 +875,7 @@ static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map)
return map->map_type == BPF_MAP_TYPE_STRUCT_OPS && return map->map_type == BPF_MAP_TYPE_STRUCT_OPS &&
map->map_flags & BPF_F_LINK && map->map_flags & BPF_F_LINK &&
/* Pair with smp_store_release() during map_update */ /* Pair with smp_store_release() during map_update */
smp_load_acquire(&st_map->kvalue.state) == BPF_STRUCT_OPS_STATE_READY; smp_load_acquire(&st_map->kvalue.common.state) == BPF_STRUCT_OPS_STATE_READY;
} }
static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link) static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment