Commit 81a0b954 authored by Alexei Starovoitov, committed by Andrii Nakryiko

Merge branch 'bpf-fix-tailcall-hierarchy'

Leon Hwang says:

====================
bpf: Fix tailcall hierarchy

This patchset fixes a tailcall hierarchy issue.

The issue was confirmed in the discussion of
"bpf, x64: Fix tailcall infinite loop" [0].

The issue has been resolved on both x86_64 and arm64 [1].

I provide a long commit message in the "bpf, x64: Fix tailcall hierarchy"
patch to describe how the issue happens and how this patchset resolves it
in detail.

How does this patchset resolve the issue?

In short, it stores tail_call_cnt on the stack of main prog, and propagates
tail_call_cnt_ptr to its subprogs.

First, at the prologue of the main prog, it initializes tail_call_cnt and
prepares tail_call_cnt_ptr. At the prologue of a subprog, it reuses the
tail_call_cnt_ptr from its caller.

Then, when a tailcall happens, it increments tail_call_cnt through its pointer.
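
For illustration, here is a minimal C sketch of that scheme (hypothetical
names; the real logic lives in the JIT-emitted prologue and tailcall code
below, not in C):

#include <stdio.h>

#define MAX_TAIL_CALL_CNT 33

/* A tailcall bumps the counter through the pointer it was handed. */
static int do_tail_call(int *tail_call_cnt_ptr)
{
	if (*tail_call_cnt_ptr >= MAX_TAIL_CALL_CNT)
		return -1;		/* out: limit reached */
	(*tail_call_cnt_ptr)++;
	return 0;			/* jump into the target prog */
}

/* Subprog prologue: reuse the caller's tail_call_cnt_ptr as-is. */
static int subprog(int *tail_call_cnt_ptr)
{
	return do_tail_call(tail_call_cnt_ptr);
}

/* Main prog prologue: the counter lives on this frame; pass its address. */
static int main_prog(void)
{
	int tail_call_cnt = 0;
	int *tail_call_cnt_ptr = &tail_call_cnt;

	subprog(tail_call_cnt_ptr);
	subprog(tail_call_cnt_ptr);	/* both callees share one counter */
	return tail_call_cnt;
}

int main(void)
{
	printf("%d\n", main_prog());	/* prints 2 */
	return 0;
}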

v5 -> v6:
  * Address comments from Eduard:
    * Add JITed dumping along with annotating comments.
    * Rewrite two selftests with RUN_TESTS macro.

v4 -> v5:
  * Solution changes from tailcall run ctx to tail_call_cnt and its pointer.
    This is because the v4 solution is unable to handle the case where there
    is no tailcall in a subprog but there is a tailcall in an EXT prog that
    attaches to that subprog.

v3 -> v4:
  * Solution changes from per-task tail_call_cnt to tailcall run ctx.
    There is a case that the per-cpu/per-task solution is unable to handle [2].

v2 -> v3:
  * Solution changes from percpu tail_call_cnt to tail_call_cnt at task_struct.

v1 -> v2:
  * Solution changes from extra run-time call insn to percpu tail_call_cnt.
  * Address comments from Alexei:
    * Use percpu tail_call_cnt.
    * Use asm to make sure no callee saved registers are touched.

RFC v2 -> v1:
  * Solution changes from propagating tail_call_cnt with its pointer to extra
    run-time call insn.
  * Address comments from Maciej:
    * Replace all memcpy(prog, x86_nops[5], X86_PATCH_SIZE) with
        emit_nops(&prog, X86_PATCH_SIZE)

RFC v1 -> RFC v2:
  * Address comments from Stanislav:
    * Separate moving emit_nops() as first patch.

Links:
[0] https://lore.kernel.org/bpf/6203dd01-789d-f02c-5293-def4c1b18aef@gmail.com/
[1] https://github.com/kernel-patches/bpf/pull/7350/checks
[2] https://lore.kernel.org/bpf/CAADnVQK1qF+uBjwom2s2W-yEmgd_3rGi5Nr+KiV3cW0T+UPPfA@mail.gmail.com/
====================

Link: https://lore.kernel.org/r/20240714123902.32305-1-hffilwlqm@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
parents bde0c5a7 b83b936f
@@ -26,7 +26,7 @@
 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
-#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
+#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
 #define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
 #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
@@ -63,8 +63,8 @@ static const int bpf2a64[] = {
 	[TMP_REG_1] = A64_R(10),
 	[TMP_REG_2] = A64_R(11),
 	[TMP_REG_3] = A64_R(12),
-	/* tail_call_cnt */
-	[TCALL_CNT] = A64_R(26),
+	/* tail_call_cnt_ptr */
+	[TCCNT_PTR] = A64_R(26),
 	/* temporary register for blinding constants */
 	[BPF_REG_AX] = A64_R(9),
 	[FP_BOTTOM] = A64_R(27),
@@ -282,13 +282,35 @@ static bool is_lsi_offset(int offset, int scale)
  *      mov x29, sp
  *      stp x19, x20, [sp, #-16]!
  *      stp x21, x22, [sp, #-16]!
- *      stp x25, x26, [sp, #-16]!
+ *      stp x26, x25, [sp, #-16]!
+ *      stp x26, x25, [sp, #-16]!
  *      stp x27, x28, [sp, #-16]!
  *      mov x25, sp
  *      mov tcc, #0
  *      // PROLOGUE_OFFSET
  */
+
+static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
+{
+	const struct bpf_prog *prog = ctx->prog;
+	const bool is_main_prog = !bpf_is_subprog(prog);
+	const u8 ptr = bpf2a64[TCCNT_PTR];
+	const u8 fp = bpf2a64[BPF_REG_FP];
+	const u8 tcc = ptr;
+
+	emit(A64_PUSH(ptr, fp, A64_SP), ctx);
+	if (is_main_prog) {
+		/* Initialize tail_call_cnt. */
+		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
+		emit(A64_PUSH(tcc, fp, A64_SP), ctx);
+		emit(A64_MOV(1, ptr, A64_SP), ctx);
+	} else {
+		emit(A64_PUSH(ptr, fp, A64_SP), ctx);
+		emit(A64_NOP, ctx);
+		emit(A64_NOP, ctx);
+	}
+}
+
 #define BTI_INSNS (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) ? 1 : 0)
 #define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)
@@ -296,7 +318,7 @@ static bool is_lsi_offset(int offset, int scale)
 #define POKE_OFFSET (BTI_INSNS + 1)
 /* Tail call offset to jump into */
-#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 8)
+#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 10)
 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 			  bool is_exception_cb, u64 arena_vm_start)
@@ -308,7 +330,6 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 	const u8 r8 = bpf2a64[BPF_REG_8];
 	const u8 r9 = bpf2a64[BPF_REG_9];
 	const u8 fp = bpf2a64[BPF_REG_FP];
-	const u8 tcc = bpf2a64[TCALL_CNT];
 	const u8 fpb = bpf2a64[FP_BOTTOM];
 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
 	const int idx0 = ctx->idx;
@@ -359,7 +380,7 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 		/* Save callee-saved registers */
 		emit(A64_PUSH(r6, r7, A64_SP), ctx);
 		emit(A64_PUSH(r8, r9, A64_SP), ctx);
-		emit(A64_PUSH(fp, tcc, A64_SP), ctx);
+		prepare_bpf_tail_call_cnt(ctx);
 		emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
 	} else {
 		/*
@@ -372,18 +393,15 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 		 * callee-saved registers. The exception callback will not push
 		 * anything and re-use the main program's stack.
 		 *
-		 * 10 registers are on the stack
+		 * 12 registers are on the stack
 		 */
-		emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
+		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
 	}
 	/* Set up BPF prog stack base register */
 	emit(A64_MOV(1, fp, A64_SP), ctx);
 	if (!ebpf_from_cbpf && is_main_prog) {
-		/* Initialize tail_call_cnt */
-		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
-
 		cur_offset = ctx->idx - idx0;
 		if (cur_offset != PROLOGUE_OFFSET) {
 			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
@@ -432,7 +450,8 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
 	const u8 tmp = bpf2a64[TMP_REG_1];
 	const u8 prg = bpf2a64[TMP_REG_2];
-	const u8 tcc = bpf2a64[TCALL_CNT];
+	const u8 tcc = bpf2a64[TMP_REG_3];
+	const u8 ptr = bpf2a64[TCCNT_PTR];
 	const int idx0 = ctx->idx;
 #define cur_offset (ctx->idx - idx0)
 #define jmp_offset (out_offset - (cur_offset))
@@ -449,11 +468,12 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
 	/*
-	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
+	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
 	 *     goto out;
-	 * tail_call_cnt++;
+	 * (*tail_call_cnt_ptr)++;
 	 */
 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
+	emit(A64_LDR64I(tcc, ptr, 0), ctx);
 	emit(A64_CMP(1, tcc, tmp), ctx);
 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
@@ -469,6 +489,9 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
 	emit(A64_LDR64(prg, tmp, prg), ctx);
 	emit(A64_CBZ(1, prg, jmp_offset), ctx);
+	/* Update tail_call_cnt if the slot is populated. */
+	emit(A64_STR64I(tcc, ptr, 0), ctx);
+
 	/* goto *(prog->bpf_func + prologue_offset); */
 	off = offsetof(struct bpf_prog, bpf_func);
 	emit_a64_mov_i64(tmp, off, ctx);
@@ -721,6 +744,7 @@ static void build_epilogue(struct jit_ctx *ctx, bool is_exception_cb)
 	const u8 r8 = bpf2a64[BPF_REG_8];
 	const u8 r9 = bpf2a64[BPF_REG_9];
 	const u8 fp = bpf2a64[BPF_REG_FP];
+	const u8 ptr = bpf2a64[TCCNT_PTR];
 	const u8 fpb = bpf2a64[FP_BOTTOM];
 	/* We're done with BPF stack */
@@ -738,7 +762,8 @@ static void build_epilogue(struct jit_ctx *ctx, bool is_exception_cb)
 	/* Restore x27 and x28 */
 	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
 	/* Restore fs (x25) and x26 */
-	emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
+	emit(A64_POP(ptr, fp, A64_SP), ctx);
+	emit(A64_POP(ptr, fp, A64_SP), ctx);
 	/* Restore callee-saved register */
 	emit(A64_POP(r8, r9, A64_SP), ctx);
......
@@ -273,7 +273,7 @@ struct jit_context {
 /* Number of bytes emit_patch() needs to generate instructions */
 #define X86_PATCH_SIZE		5
 /* Number of bytes that will be skipped on tailcall */
-#define X86_TAIL_CALL_OFFSET	(11 + ENDBR_INSN_SIZE)
+#define X86_TAIL_CALL_OFFSET	(12 + ENDBR_INSN_SIZE)
 static void push_r12(u8 **pprog)
 {
@@ -403,6 +403,37 @@ static void emit_cfi(u8 **pprog, u32 hash)
 	*pprog = prog;
 }
+static void emit_prologue_tail_call(u8 **pprog, bool is_subprog)
+{
+	u8 *prog = *pprog;
+
+	if (!is_subprog) {
+		/* cmp rax, MAX_TAIL_CALL_CNT */
+		EMIT4(0x48, 0x83, 0xF8, MAX_TAIL_CALL_CNT);
+		EMIT2(X86_JA, 6);        /* ja 6 */
+		/* rax is tail_call_cnt if <= MAX_TAIL_CALL_CNT.
+		 * case1: entry of main prog.
+		 * case2: tail callee of main prog.
+		 */
+		EMIT1(0x50);             /* push rax */
+		/* Make rax as tail_call_cnt_ptr. */
+		EMIT3(0x48, 0x89, 0xE0); /* mov rax, rsp */
+		EMIT2(0xEB, 1);          /* jmp 1 */
+		/* rax is tail_call_cnt_ptr if > MAX_TAIL_CALL_CNT.
+		 * case: tail callee of subprog.
+		 */
+		EMIT1(0x50);             /* push rax */
+		/* push tail_call_cnt_ptr */
+		EMIT1(0x50);             /* push rax */
+	} else { /* is_subprog */
+		/* rax is tail_call_cnt_ptr. */
+		EMIT1(0x50);             /* push rax */
+		EMIT1(0x50);             /* push rax */
+	}
+
+	*pprog = prog;
+}
+
 /*
  * Emit x86-64 prologue code for BPF program.
  * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
@@ -424,10 +455,10 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
 			/* When it's the entry of the whole tailcall context,
 			 * zeroing rax means initialising tail_call_cnt.
 			 */
-			EMIT2(0x31, 0xC0); /* xor eax, eax */
+			EMIT3(0x48, 0x31, 0xC0); /* xor rax, rax */
 		else
 			/* Keep the same instruction layout. */
-			EMIT2(0x66, 0x90); /* nop2 */
+			emit_nops(&prog, 3);     /* nop3 */
 	}
 	/* Exception callback receives FP as third parameter */
 	if (is_exception_cb) {
@@ -453,7 +484,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
 	if (stack_depth)
 		EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
 	if (tail_call_reachable)
-		EMIT1(0x50);         /* push rax */
+		emit_prologue_tail_call(&prog, is_subprog);
 	*pprog = prog;
 }
@@ -589,13 +620,15 @@ static void emit_return(u8 **pprog, u8 *ip)
 	*pprog = prog;
 }
+#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack)	(-16 - round_up(stack, 8))
+
 /*
  * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
- *   if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
+ *   if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
@@ -608,7 +641,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 					u32 stack_depth, u8 *ip,
 					struct jit_context *ctx)
 {
-	int tcc_off = -4 - round_up(stack_depth, 8);
+	int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack_depth);
 	u8 *prog = *pprog, *start = *pprog;
 	int offset;
@@ -630,16 +663,14 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 	EMIT2(X86_JBE, offset);                   /* jbe out */
 	/*
-	 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
+	 * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
	 *	goto out;
	 */
-	EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
-	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
+	EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off);   /* mov rax, qword ptr [rbp - tcc_ptr_off] */
+	EMIT4(0x48, 0x83, 0x38, MAX_TAIL_CALL_CNT);   /* cmp qword ptr [rax], MAX_TAIL_CALL_CNT */
 	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
 	EMIT2(X86_JAE, offset);                   /* jae out */
-	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
-	EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
 	/* prog = array->ptrs[index]; */
 	EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
@@ -654,6 +685,9 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 	offset = ctx->tail_call_indirect_label - (prog + 2 - start);
 	EMIT2(X86_JE, offset);                    /* je out */
+	/* Inc tail_call_cnt if the slot is populated. */
+	EMIT4(0x48, 0x83, 0x00, 0x01);            /* add qword ptr [rax], 1 */
+
 	if (bpf_prog->aux->exception_boundary) {
 		pop_callee_regs(&prog, all_callee_regs_used);
 		pop_r12(&prog);
@@ -663,6 +697,11 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
 		pop_r12(&prog);
 	}
+	/* Pop tail_call_cnt_ptr. */
+	EMIT1(0x58);                              /* pop rax */
+	/* Pop tail_call_cnt, if it's main prog.
+	 * Pop tail_call_cnt_ptr, if it's subprog.
+	 */
 	EMIT1(0x58);                              /* pop rax */
 	if (stack_depth)
 		EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
@@ -691,21 +730,19 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
 				      bool *callee_regs_used, u32 stack_depth,
 				      struct jit_context *ctx)
 {
-	int tcc_off = -4 - round_up(stack_depth, 8);
+	int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack_depth);
 	u8 *prog = *pprog, *start = *pprog;
 	int offset;
 	/*
-	 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
+	 * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
	 *	goto out;
	 */
-	EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
-	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
+	EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off);   /* mov rax, qword ptr [rbp - tcc_ptr_off] */
+	EMIT4(0x48, 0x83, 0x38, MAX_TAIL_CALL_CNT);   /* cmp qword ptr [rax], MAX_TAIL_CALL_CNT */
 	offset = ctx->tail_call_direct_label - (prog + 2 - start);
 	EMIT2(X86_JAE, offset);                   /* jae out */
-	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
-	EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
 	poke->tailcall_bypass = ip + (prog - start);
 	poke->adj_off = X86_TAIL_CALL_OFFSET;
@@ -715,6 +752,9 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
 	emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
 		  poke->tailcall_bypass);
+	/* Inc tail_call_cnt if the slot is populated. */
+	EMIT4(0x48, 0x83, 0x00, 0x01);            /* add qword ptr [rax], 1 */
+
 	if (bpf_prog->aux->exception_boundary) {
 		pop_callee_regs(&prog, all_callee_regs_used);
 		pop_r12(&prog);
@@ -724,6 +764,11 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
 		pop_r12(&prog);
 	}
+	/* Pop tail_call_cnt_ptr. */
+	EMIT1(0x58);                              /* pop rax */
+	/* Pop tail_call_cnt, if it's main prog.
+	 * Pop tail_call_cnt_ptr, if it's subprog.
+	 */
 	EMIT1(0x58);                              /* pop rax */
 	if (stack_depth)
 		EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
@@ -1311,9 +1356,11 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
 #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
-/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
-#define RESTORE_TAIL_CALL_CNT(stack)			\
-	EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
+#define __LOAD_TCC_PTR(off)			\
+	EMIT3_off32(0x48, 0x8B, 0x85, off)
+
+/* mov rax, qword ptr [rbp - rounded_stack_depth - 16] */
+#define LOAD_TAIL_CALL_CNT_PTR(stack)		\
+	__LOAD_TCC_PTR(BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack))
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
 		  int oldproglen, struct jit_context *ctx, bool jmp_padding)
@@ -2031,7 +2078,7 @@ st:			if (is_imm8(insn->off))
 			func = (u8 *) __bpf_call_base + imm32;
 			if (tail_call_reachable) {
-				RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
+				LOAD_TAIL_CALL_CNT_PTR(bpf_prog->aux->stack_depth);
 				ip += 7;
 			}
 			if (!imm32)
@@ -2706,6 +2753,10 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
 	return 0;
 }
+/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
+#define LOAD_TRAMP_TAIL_CALL_CNT_PTR(stack)	\
+	__LOAD_TCC_PTR(-round_up(stack, 8) - 8)
+
 /* Example:
 * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
 * its 'struct btf_func_model' will be nr_args=2
@@ -2826,7 +2877,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im
 *                     [ ...        ]
 *                     [ stack_arg2 ]
 *                     RBP - arg_stack_off [ stack_arg1 ]
- *                     RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
+ *                     RSP                 [ tail_call_cnt_ptr ] BPF_TRAMP_F_TAIL_CALL_CTX
 */
 	/* room for return value of orig_call or fentry prog */
@@ -2955,10 +3006,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im
 		save_args(m, &prog, arg_stack_off, true);
 		if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
-			/* Before calling the original function, restore the
-			 * tail_call_cnt from stack to rax.
+			/* Before calling the original function, load the
+			 * tail_call_cnt_ptr from stack to rax.
 			 */
-			RESTORE_TAIL_CALL_CNT(stack_size);
+			LOAD_TRAMP_TAIL_CALL_CNT_PTR(stack_size);
 		}
 		if (flags & BPF_TRAMP_F_ORIG_STACK) {
@@ -3017,10 +3068,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im
 			goto cleanup;
 		}
 	} else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
-		/* Before running the original function, restore the
-		 * tail_call_cnt from stack to rax.
+		/* Before running the original function, load the
+		 * tail_call_cnt_ptr from stack to rax.
 		 */
-		RESTORE_TAIL_CALL_CNT(stack_size);
+		LOAD_TRAMP_TAIL_CALL_CNT_PTR(stack_size);
 	}
 	/* restore return value of orig_call or fentry prog back into RAX */
......
// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_legacy.h"

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__uint(value_size, sizeof(__u32));
} jmp_table SEC(".maps");

int count = 0;

static __noinline
int subprog_tail(struct __sk_buff *skb)
{
	bpf_tail_call_static(skb, &jmp_table, 0);
	return 0;
}

SEC("tc")
int entry(struct __sk_buff *skb)
{
	int ret = 1;

	count++;
	subprog_tail(skb);
	subprog_tail(skb);

	return ret;
}

char __license[] SEC("license") = "GPL";
// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"

int classifier_0(struct __sk_buff *skb);
int classifier_1(struct __sk_buff *skb);

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 2);
	__uint(key_size, sizeof(__u32));
	__array(values, void (void));
} jmp_table SEC(".maps") = {
	.values = {
		[0] = (void *) &classifier_0,
		[1] = (void *) &classifier_1,
	},
};

int count0 = 0;
int count1 = 0;

static __noinline
int subprog_tail0(struct __sk_buff *skb)
{
	bpf_tail_call_static(skb, &jmp_table, 0);
	return 0;
}

__auxiliary
SEC("tc")
int classifier_0(struct __sk_buff *skb)
{
	count0++;
	subprog_tail0(skb);
	return 0;
}

static __noinline
int subprog_tail1(struct __sk_buff *skb)
{
	bpf_tail_call_static(skb, &jmp_table, 1);
	return 0;
}

__auxiliary
SEC("tc")
int classifier_1(struct __sk_buff *skb)
{
	count1++;
	subprog_tail1(skb);
	return 0;
}

__success
__retval(33)
SEC("tc")
int tailcall_bpf2bpf_hierarchy_2(struct __sk_buff *skb)
{
	volatile int ret = 0;

	subprog_tail0(skb);
	subprog_tail1(skb);

	asm volatile (""::"r+"(ret));
	return (count1 << 16) | count0;
}

char __license[] SEC("license") = "GPL";
// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"

int classifier_0(struct __sk_buff *skb);

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__array(values, void (void));
} jmp_table0 SEC(".maps") = {
	.values = {
		[0] = (void *) &classifier_0,
	},
};

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__array(values, void (void));
} jmp_table1 SEC(".maps") = {
	.values = {
		[0] = (void *) &classifier_0,
	},
};

int count = 0;

static __noinline
int subprog_tail(struct __sk_buff *skb, void *jmp_table)
{
	bpf_tail_call_static(skb, jmp_table, 0);
	return 0;
}

__auxiliary
SEC("tc")
int classifier_0(struct __sk_buff *skb)
{
	count++;
	subprog_tail(skb, &jmp_table0);
	subprog_tail(skb, &jmp_table1);
	return count;
}

__success
__retval(33)
SEC("tc")
int tailcall_bpf2bpf_hierarchy_3(struct __sk_buff *skb)
{
	volatile int ret = 0;

	bpf_tail_call_static(skb, &jmp_table0, 0);

	asm volatile (""::"r+"(ret));
	return ret;
}

char __license[] SEC("license") = "GPL";
// SPDX-License-Identifier: GPL-2.0
/* Copyright Leon Hwang */

#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__uint(value_size, sizeof(__u32));
} jmp_table SEC(".maps");

int count = 0;

static __noinline
int subprog_tail(void *ctx)
{
	bpf_tail_call_static(ctx, &jmp_table, 0);
	return 0;
}

SEC("fentry/dummy")
int BPF_PROG(fentry, struct sk_buff *skb)
{
	count++;
	subprog_tail(ctx);
	subprog_tail(ctx);
	return 0;
}

char _license[] SEC("license") = "GPL";
// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_legacy.h"

SEC("tc")
int entry(struct __sk_buff *skb)
{
	return 1;
}

char __license[] SEC("license") = "GPL";