Commit 633a6e01 authored by Puranjay Mohan's avatar Puranjay Mohan Committed by Daniel Borkmann

bpf, riscv: Implement PROBE_MEM32 pseudo instructions

Add support for [LDX | STX | ST], PROBE_MEM32, [B | H | W | DW]
instructions. They are similar to PROBE_MEM instructions with the
following differences:

- PROBE_MEM32 supports store.
- PROBE_MEM32 relies on the verifier to clear upper 32-bit of the
  src/dst register
- PROBE_MEM32 adds 64-bit kern_vm_start address (which is stored in S7
  in the prologue). Due to bpf_arena constructions such S7 + reg +
  off16 access is guaranteed to be within arena virtual range, so no
  address check at run-time.
- S11 is a free callee-saved register, so it is used to store kern_vm_start
- PROBE_MEM32 allows STX and ST. If they fault the store is a nop. When
  LDX faults the destination register is zeroed.

To support these on riscv, we do tmp = S7 + src/dst reg and then use
tmp2 as the new src/dst register. This allows us to reuse most of the
code for normal [LDX | STX | ST].
Signed-off-by: default avatarPuranjay Mohan <puranjay12@gmail.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Tested-by: default avatarBjörn Töpel <bjorn@rivosinc.com>
Tested-by: default avatarPu Lehui <pulehui@huawei.com>
Reviewed-by: default avatarPu Lehui <pulehui@huawei.com>
Acked-by: default avatarBjörn Töpel <bjorn@kernel.org>
Link: https://lore.kernel.org/bpf/20240404114203.105970-2-puranjay12@gmail.com
parent af682b76
...@@ -81,6 +81,7 @@ struct rv_jit_context { ...@@ -81,6 +81,7 @@ struct rv_jit_context {
int nexentries; int nexentries;
unsigned long flags; unsigned long flags;
int stack_size; int stack_size;
u64 arena_vm_start;
}; };
/* Convert from ninsns to bytes. */ /* Convert from ninsns to bytes. */
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#define RV_REG_TCC RV_REG_A6 #define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */ #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
#define RV_REG_ARENA RV_REG_S7 /* For storing arena_vm_start */
static const int regmap[] = { static const int regmap[] = {
[BPF_REG_0] = RV_REG_A5, [BPF_REG_0] = RV_REG_A5,
...@@ -255,6 +256,10 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx) ...@@ -255,6 +256,10 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx); emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
store_offset -= 8; store_offset -= 8;
} }
if (ctx->arena_vm_start) {
emit_ld(RV_REG_ARENA, store_offset, RV_REG_SP, ctx);
store_offset -= 8;
}
emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx); emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
/* Set return value. */ /* Set return value. */
...@@ -548,6 +553,7 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64, ...@@ -548,6 +553,7 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
#define BPF_FIXUP_OFFSET_MASK GENMASK(26, 0) #define BPF_FIXUP_OFFSET_MASK GENMASK(26, 0)
#define BPF_FIXUP_REG_MASK GENMASK(31, 27) #define BPF_FIXUP_REG_MASK GENMASK(31, 27)
#define REG_DONT_CLEAR_MARKER 0 /* RV_REG_ZERO unused in pt_regmap */
bool ex_handler_bpf(const struct exception_table_entry *ex, bool ex_handler_bpf(const struct exception_table_entry *ex,
struct pt_regs *regs) struct pt_regs *regs)
...@@ -555,7 +561,8 @@ bool ex_handler_bpf(const struct exception_table_entry *ex, ...@@ -555,7 +561,8 @@ bool ex_handler_bpf(const struct exception_table_entry *ex,
off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup); off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup); int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0; if (regs_offset != REG_DONT_CLEAR_MARKER)
*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
regs->epc = (unsigned long)&ex->fixup - offset; regs->epc = (unsigned long)&ex->fixup - offset;
return true; return true;
...@@ -572,7 +579,8 @@ static int add_exception_handler(const struct bpf_insn *insn, ...@@ -572,7 +579,8 @@ static int add_exception_handler(const struct bpf_insn *insn,
off_t fixup_offset; off_t fixup_offset;
if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable || if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
(BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX)) (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
BPF_MODE(insn->code) != BPF_PROBE_MEM32))
return 0; return 0;
if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries)) if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
...@@ -1539,6 +1547,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, ...@@ -1539,6 +1547,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
case BPF_LDX | BPF_PROBE_MEMSX | BPF_B: case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
case BPF_LDX | BPF_PROBE_MEMSX | BPF_H: case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
case BPF_LDX | BPF_PROBE_MEMSX | BPF_W: case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
/* LDX | PROBE_MEM32: dst = *(unsigned size *)(src + RV_REG_ARENA + off) */
case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
{ {
int insn_len, insns_start; int insn_len, insns_start;
bool sign_ext; bool sign_ext;
...@@ -1546,6 +1559,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, ...@@ -1546,6 +1559,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
sign_ext = BPF_MODE(insn->code) == BPF_MEMSX || sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
BPF_MODE(insn->code) == BPF_PROBE_MEMSX; BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx);
rs = RV_REG_T2;
}
switch (BPF_SIZE(code)) { switch (BPF_SIZE(code)) {
case BPF_B: case BPF_B:
if (is_12b_int(off)) { if (is_12b_int(off)) {
...@@ -1682,6 +1700,86 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, ...@@ -1682,6 +1700,86 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx); emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
break; break;
case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
{
int insn_len, insns_start;
emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx);
rd = RV_REG_T3;
/* Load imm to a register then store it */
emit_imm(RV_REG_T1, imm, ctx);
switch (BPF_SIZE(code)) {
case BPF_B:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit(rv_sb(rd, off, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T2, off, ctx);
emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
insns_start = ctx->ninsns;
emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
break;
case BPF_H:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit(rv_sh(rd, off, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T2, off, ctx);
emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
insns_start = ctx->ninsns;
emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
break;
case BPF_W:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit_sw(rd, off, RV_REG_T1, ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T2, off, ctx);
emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
insns_start = ctx->ninsns;
emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
insn_len = ctx->ninsns - insns_start;
break;
case BPF_DW:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit_sd(rd, off, RV_REG_T1, ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T2, off, ctx);
emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
insns_start = ctx->ninsns;
emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
insn_len);
if (ret)
return ret;
break;
}
/* STX: *(size *)(dst + off) = src */ /* STX: *(size *)(dst + off) = src */
case BPF_STX | BPF_MEM | BPF_B: case BPF_STX | BPF_MEM | BPF_B:
if (is_12b_int(off)) { if (is_12b_int(off)) {
...@@ -1728,6 +1826,84 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx, ...@@ -1728,6 +1826,84 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
emit_atomic(rd, rs, off, imm, emit_atomic(rd, rs, off, imm,
BPF_SIZE(code) == BPF_DW, ctx); BPF_SIZE(code) == BPF_DW, ctx);
break; break;
case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
{
int insn_len, insns_start;
emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx);
rd = RV_REG_T2;
switch (BPF_SIZE(code)) {
case BPF_B:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit(rv_sb(rd, off, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
insns_start = ctx->ninsns;
emit(rv_sb(RV_REG_T1, 0, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
case BPF_H:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit(rv_sh(rd, off, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
insns_start = ctx->ninsns;
emit(rv_sh(RV_REG_T1, 0, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
case BPF_W:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit_sw(rd, off, rs, ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
insns_start = ctx->ninsns;
emit_sw(RV_REG_T1, 0, rs, ctx);
insn_len = ctx->ninsns - insns_start;
break;
case BPF_DW:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
emit_sd(rd, off, rs, ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
insns_start = ctx->ninsns;
emit_sd(RV_REG_T1, 0, rs, ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
insn_len);
if (ret)
return ret;
break;
}
default: default:
pr_err("bpf-jit: unknown opcode %02x\n", code); pr_err("bpf-jit: unknown opcode %02x\n", code);
return -EINVAL; return -EINVAL;
...@@ -1759,6 +1935,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog) ...@@ -1759,6 +1935,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
stack_adjust += 8; stack_adjust += 8;
if (seen_reg(RV_REG_S6, ctx)) if (seen_reg(RV_REG_S6, ctx))
stack_adjust += 8; stack_adjust += 8;
if (ctx->arena_vm_start)
stack_adjust += 8;
stack_adjust = round_up(stack_adjust, 16); stack_adjust = round_up(stack_adjust, 16);
stack_adjust += bpf_stack_adjust; stack_adjust += bpf_stack_adjust;
...@@ -1810,6 +1988,10 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog) ...@@ -1810,6 +1988,10 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx); emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
store_offset -= 8; store_offset -= 8;
} }
if (ctx->arena_vm_start) {
emit_sd(RV_REG_SP, store_offset, RV_REG_ARENA, ctx);
store_offset -= 8;
}
emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx); emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
...@@ -1823,6 +2005,9 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog) ...@@ -1823,6 +2005,9 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx); emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
ctx->stack_size = stack_adjust; ctx->stack_size = stack_adjust;
if (ctx->arena_vm_start)
emit_imm(RV_REG_ARENA, ctx->arena_vm_start, ctx);
} }
void bpf_jit_build_epilogue(struct rv_jit_context *ctx) void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
......
...@@ -80,6 +80,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -80,6 +80,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
goto skip_init_ctx; goto skip_init_ctx;
} }
ctx->arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
ctx->prog = prog; ctx->prog = prog;
ctx->offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL); ctx->offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
if (!ctx->offset) { if (!ctx->offset) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment