Commit 4961d8f4 authored by Alexei Starovoitov

Merge branch 'bpf-arm64-simplify-jited-prologue-epilogue'

Xu Kuohai says:

====================
bpf, arm64: Simplify jited prologue/epilogue

From: Xu Kuohai <xukuohai@huawei.com>

The arm64 jit blindly saves/restores all callee-saved registers, making
the jited result look a bit too complicated. For example, for an empty
prog, the jited result is:

   0:   bti jc
   4:   mov     x9, lr
   8:   nop
   c:   paciasp
  10:   stp     fp, lr, [sp, #-16]!
  14:   mov     fp, sp
  18:   stp     x19, x20, [sp, #-16]!
  1c:   stp     x21, x22, [sp, #-16]!
  20:   stp     x26, x25, [sp, #-16]!
  24:   mov     x26, #0
  28:   stp     x26, x25, [sp, #-16]!
  2c:   mov     x26, sp
  30:   stp     x27, x28, [sp, #-16]!
  34:   mov     x25, sp
  38:   bti j 		// tailcall target
  3c:   sub     sp, sp, #0
  40:   mov     x7, #0
  44:   add     sp, sp, #0
  48:   ldp     x27, x28, [sp], #16
  4c:   ldp     x26, x25, [sp], #16
  50:   ldp     x26, x25, [sp], #16
  54:   ldp     x21, x22, [sp], #16
  58:   ldp     x19, x20, [sp], #16
  5c:   ldp     fp, lr, [sp], #16
  60:   mov     x0, x7
  64:   autiasp
  68:   ret

Clearly, there is no need to save/restore unused callee-saved registers.
This patch makes that change, so that the jited image only saves/restores
the callee-saved registers it uses.

Now the jited result of empty prog is:

   0:   bti jc
   4:   mov     x9, lr
   8:   nop
   c:   paciasp
  10:   stp     fp, lr, [sp, #-16]!
  14:   mov     fp, sp
  18:   stp     xzr, x26, [sp, #-16]!
  1c:   mov     x26, sp
  20:   bti j		// tailcall target
  24:   mov     x7, #0
  28:   ldp     xzr, x26, [sp], #16
  2c:   ldp     fp, lr, [sp], #16
  30:   mov     x0, x7
  34:   autiasp
  38:   ret
====================
Acked-by: Puranjay Mohan <puranjay@kernel.org>
Link: https://lore.kernel.org/r/20240826071624.350108-1-xukuohai@huaweicloud.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parents d205d4af 5d4fa9ec
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#define TMP_REG_2 (MAX_BPF_JIT_REG + 1) #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
#define TCCNT_PTR (MAX_BPF_JIT_REG + 2) #define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
#define TMP_REG_3 (MAX_BPF_JIT_REG + 3) #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
#define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
#define ARENA_VM_START (MAX_BPF_JIT_REG + 5) #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
#define check_imm(bits, imm) do { \ #define check_imm(bits, imm) do { \
...@@ -67,7 +66,6 @@ static const int bpf2a64[] = { ...@@ -67,7 +66,6 @@ static const int bpf2a64[] = {
[TCCNT_PTR] = A64_R(26), [TCCNT_PTR] = A64_R(26),
/* temporary register for blinding constants */ /* temporary register for blinding constants */
[BPF_REG_AX] = A64_R(9), [BPF_REG_AX] = A64_R(9),
[FP_BOTTOM] = A64_R(27),
/* callee saved register for kern_vm_start address */ /* callee saved register for kern_vm_start address */
[ARENA_VM_START] = A64_R(28), [ARENA_VM_START] = A64_R(28),
}; };
...@@ -78,11 +76,14 @@ struct jit_ctx { ...@@ -78,11 +76,14 @@ struct jit_ctx {
int epilogue_offset; int epilogue_offset;
int *offset; int *offset;
int exentry_idx; int exentry_idx;
int nr_used_callee_reg;
u8 used_callee_reg[8]; /* r6~r9, fp, arena_vm_start */
__le32 *image; __le32 *image;
__le32 *ro_image; __le32 *ro_image;
u32 stack_size; u32 stack_size;
int fpb_offset;
u64 user_vm_start; u64 user_vm_start;
u64 arena_vm_start;
bool fp_used;
}; };
struct bpf_plt { struct bpf_plt {
...@@ -273,41 +274,141 @@ static bool is_lsi_offset(int offset, int scale) ...@@ -273,41 +274,141 @@ static bool is_lsi_offset(int offset, int scale)
return true; return true;
} }
/* generated prologue: /* generated main prog prologue:
* bti c // if CONFIG_ARM64_BTI_KERNEL * bti c // if CONFIG_ARM64_BTI_KERNEL
* mov x9, lr * mov x9, lr
* nop // POKE_OFFSET * nop // POKE_OFFSET
* paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL * paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
* stp x29, lr, [sp, #-16]! * stp x29, lr, [sp, #-16]!
* mov x29, sp * mov x29, sp
* stp x19, x20, [sp, #-16]! * stp xzr, x26, [sp, #-16]!
* stp x21, x22, [sp, #-16]! * mov x26, sp
* stp x26, x25, [sp, #-16]!
* stp x26, x25, [sp, #-16]!
* stp x27, x28, [sp, #-16]!
* mov x25, sp
* mov tcc, #0
* // PROLOGUE_OFFSET * // PROLOGUE_OFFSET
* // save callee-saved registers
*/ */
static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx) static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
{ {
const struct bpf_prog *prog = ctx->prog; const bool is_main_prog = !bpf_is_subprog(ctx->prog);
const bool is_main_prog = !bpf_is_subprog(prog);
const u8 ptr = bpf2a64[TCCNT_PTR]; const u8 ptr = bpf2a64[TCCNT_PTR];
const u8 fp = bpf2a64[BPF_REG_FP];
const u8 tcc = ptr;
emit(A64_PUSH(ptr, fp, A64_SP), ctx);
if (is_main_prog) { if (is_main_prog) {
/* Initialize tail_call_cnt. */ /* Initialize tail_call_cnt. */
emit(A64_MOVZ(1, tcc, 0, 0), ctx); emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
emit(A64_PUSH(tcc, fp, A64_SP), ctx);
emit(A64_MOV(1, ptr, A64_SP), ctx); emit(A64_MOV(1, ptr, A64_SP), ctx);
} else
emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
}
/*
 * Work out which callee-saved registers the program needs.
 *
 * Every instruction of ctx->prog is inspected, and a BPF register counts as
 * "used" when it appears in either the dst_reg or the src_reg field. The
 * A64 registers that bpf2a64[] maps the used BPF registers to are written to
 * ctx->used_callee_reg[] in the fixed order r6, r7, r8, r9, fp, followed by
 * the ARENA_VM_START register when ctx->arena_vm_start is non-zero. The
 * number of entries written (at most six) is stored in
 * ctx->nr_used_callee_reg.
 *
 * Side effect: ctx->fp_used is set to true when BPF_REG_FP is referenced.
 */
static void find_used_callee_regs(struct jit_ctx *ctx)
{
int i;
const struct bpf_prog *prog = ctx->prog;
const struct bpf_insn *insn = &prog->insnsi[0];
/* Bitmask of referenced registers: 1 = r6, 2 = r7, 4 = r8, 8 = r9, 16 = fp */
int reg_used = 0;
/* Pass 1: collect the set of referenced registers. */
for (i = 0; i < prog->len; i++, insn++) {
if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
reg_used |= 1;
if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
reg_used |= 2;
if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
reg_used |= 4;
if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
reg_used |= 8;
if (insn->dst_reg == BPF_REG_FP || insn->src_reg == BPF_REG_FP) {
ctx->fp_used = true;
reg_used |= 16;
}
}
/* Pass 2: i is reused as the write index into ctx->used_callee_reg[]. */
i = 0;
if (reg_used & 1)
ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_6];
if (reg_used & 2)
ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_7];
if (reg_used & 4)
ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_8];
if (reg_used & 8)
ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_9];
if (reg_used & 16)
ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_FP];
/* The arena base register is tracked whenever an arena address is set. */
if (ctx->arena_vm_start)
ctx->used_callee_reg[i++] = bpf2a64[ARENA_VM_START];
ctx->nr_used_callee_reg = i;
}
/* Save callee-saved registers */
static void push_callee_regs(struct jit_ctx *ctx)
{
int reg1, reg2, i;
/*
* Program acting as exception boundary should save all ARM64
* Callee-saved registers as the exception callback needs to recover
* all ARM64 Callee-saved registers in its epilogue.
*/
if (ctx->prog->aux->exception_boundary) {
emit(A64_PUSH(A64_R(19), A64_R(20), A64_SP), ctx);
emit(A64_PUSH(A64_R(21), A64_R(22), A64_SP), ctx);
emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
emit(A64_PUSH(A64_R(25), A64_R(26), A64_SP), ctx);
emit(A64_PUSH(A64_R(27), A64_R(28), A64_SP), ctx);
} else { } else {
emit(A64_PUSH(ptr, fp, A64_SP), ctx); find_used_callee_regs(ctx);
emit(A64_NOP, ctx); for (i = 0; i + 1 < ctx->nr_used_callee_reg; i += 2) {
emit(A64_NOP, ctx); reg1 = ctx->used_callee_reg[i];
reg2 = ctx->used_callee_reg[i + 1];
emit(A64_PUSH(reg1, reg2, A64_SP), ctx);
}
if (i < ctx->nr_used_callee_reg) {
reg1 = ctx->used_callee_reg[i];
/* keep SP 16-byte aligned */
emit(A64_PUSH(reg1, A64_ZR, A64_SP), ctx);
}
}
}
/* Restore callee-saved registers */
static void pop_callee_regs(struct jit_ctx *ctx)
{
struct bpf_prog_aux *aux = ctx->prog->aux;
int reg1, reg2, i;
/*
* Program acting as exception boundary pushes R23 and R24 in addition
* to BPF callee-saved registers. Exception callback uses the boundary
* program's stack frame, so recover these extra registers in the above
* two cases.
*/
if (aux->exception_boundary || aux->exception_cb) {
emit(A64_POP(A64_R(27), A64_R(28), A64_SP), ctx);
emit(A64_POP(A64_R(25), A64_R(26), A64_SP), ctx);
emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);
emit(A64_POP(A64_R(21), A64_R(22), A64_SP), ctx);
emit(A64_POP(A64_R(19), A64_R(20), A64_SP), ctx);
} else {
i = ctx->nr_used_callee_reg - 1;
if (ctx->nr_used_callee_reg % 2 != 0) {
reg1 = ctx->used_callee_reg[i];
emit(A64_POP(reg1, A64_ZR, A64_SP), ctx);
i--;
}
while (i > 0) {
reg1 = ctx->used_callee_reg[i - 1];
reg2 = ctx->used_callee_reg[i];
emit(A64_POP(reg1, reg2, A64_SP), ctx);
i -= 2;
}
} }
} }
...@@ -318,19 +419,13 @@ static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx) ...@@ -318,19 +419,13 @@ static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
#define POKE_OFFSET (BTI_INSNS + 1) #define POKE_OFFSET (BTI_INSNS + 1)
/* Tail call offset to jump into */ /* Tail call offset to jump into */
#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 10) #define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf, static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
bool is_exception_cb, u64 arena_vm_start)
{ {
const struct bpf_prog *prog = ctx->prog; const struct bpf_prog *prog = ctx->prog;
const bool is_main_prog = !bpf_is_subprog(prog); const bool is_main_prog = !bpf_is_subprog(prog);
const u8 r6 = bpf2a64[BPF_REG_6];
const u8 r7 = bpf2a64[BPF_REG_7];
const u8 r8 = bpf2a64[BPF_REG_8];
const u8 r9 = bpf2a64[BPF_REG_9];
const u8 fp = bpf2a64[BPF_REG_FP]; const u8 fp = bpf2a64[BPF_REG_FP];
const u8 fpb = bpf2a64[FP_BOTTOM];
const u8 arena_vm_base = bpf2a64[ARENA_VM_START]; const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
const int idx0 = ctx->idx; const int idx0 = ctx->idx;
int cur_offset; int cur_offset;
...@@ -369,19 +464,28 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf, ...@@ -369,19 +464,28 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
emit(A64_MOV(1, A64_R(9), A64_LR), ctx); emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
emit(A64_NOP, ctx); emit(A64_NOP, ctx);
if (!is_exception_cb) { if (!prog->aux->exception_cb) {
/* Sign lr */ /* Sign lr */
if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL)) if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
emit(A64_PACIASP, ctx); emit(A64_PACIASP, ctx);
/* Save FP and LR registers to stay align with ARM64 AAPCS */ /* Save FP and LR registers to stay align with ARM64 AAPCS */
emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx); emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
emit(A64_MOV(1, A64_FP, A64_SP), ctx); emit(A64_MOV(1, A64_FP, A64_SP), ctx);
/* Save callee-saved registers */
emit(A64_PUSH(r6, r7, A64_SP), ctx);
emit(A64_PUSH(r8, r9, A64_SP), ctx);
prepare_bpf_tail_call_cnt(ctx); prepare_bpf_tail_call_cnt(ctx);
emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
if (!ebpf_from_cbpf && is_main_prog) {
cur_offset = ctx->idx - idx0;
if (cur_offset != PROLOGUE_OFFSET) {
pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
cur_offset, PROLOGUE_OFFSET);
return -1;
}
/* BTI landing pad for the tail call, done with a BR */
emit_bti(A64_BTI_J, ctx);
}
push_callee_regs(ctx);
} else { } else {
/* /*
* Exception callback receives FP of Main Program as third * Exception callback receives FP of Main Program as third
...@@ -398,50 +502,23 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf, ...@@ -398,50 +502,23 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx); emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
} }
/* Set up BPF prog stack base register */ if (ctx->fp_used)
emit(A64_MOV(1, fp, A64_SP), ctx); /* Set up BPF prog stack base register */
emit(A64_MOV(1, fp, A64_SP), ctx);
if (!ebpf_from_cbpf && is_main_prog) {
cur_offset = ctx->idx - idx0;
if (cur_offset != PROLOGUE_OFFSET) {
pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
cur_offset, PROLOGUE_OFFSET);
return -1;
}
/* BTI landing pad for the tail call, done with a BR */
emit_bti(A64_BTI_J, ctx);
}
/*
* Program acting as exception boundary should save all ARM64
* Callee-saved registers as the exception callback needs to recover
* all ARM64 Callee-saved registers in its epilogue.
*/
if (prog->aux->exception_boundary) {
/*
* As we are pushing two more registers, BPF_FP should be moved
* 16 bytes
*/
emit(A64_SUB_I(1, fp, fp, 16), ctx);
emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
}
emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);
/* Stack must be multiples of 16B */ /* Stack must be multiples of 16B */
ctx->stack_size = round_up(prog->aux->stack_depth, 16); ctx->stack_size = round_up(prog->aux->stack_depth, 16);
/* Set up function call stack */ /* Set up function call stack */
emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx); if (ctx->stack_size)
emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
if (arena_vm_start) if (ctx->arena_vm_start)
emit_a64_mov_i64(arena_vm_base, arena_vm_start, ctx); emit_a64_mov_i64(arena_vm_base, ctx->arena_vm_start, ctx);
return 0; return 0;
} }
static int out_offset = -1; /* initialized on the first pass of build_body() */
static int emit_bpf_tail_call(struct jit_ctx *ctx) static int emit_bpf_tail_call(struct jit_ctx *ctx)
{ {
/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */ /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
...@@ -452,10 +529,10 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx) ...@@ -452,10 +529,10 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
const u8 prg = bpf2a64[TMP_REG_2]; const u8 prg = bpf2a64[TMP_REG_2];
const u8 tcc = bpf2a64[TMP_REG_3]; const u8 tcc = bpf2a64[TMP_REG_3];
const u8 ptr = bpf2a64[TCCNT_PTR]; const u8 ptr = bpf2a64[TCCNT_PTR];
const int idx0 = ctx->idx;
#define cur_offset (ctx->idx - idx0)
#define jmp_offset (out_offset - (cur_offset))
size_t off; size_t off;
__le32 *branch1 = NULL;
__le32 *branch2 = NULL;
__le32 *branch3 = NULL;
/* if (index >= array->map.max_entries) /* if (index >= array->map.max_entries)
* goto out; * goto out;
...@@ -465,17 +542,20 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx) ...@@ -465,17 +542,20 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
emit(A64_LDR32(tmp, r2, tmp), ctx); emit(A64_LDR32(tmp, r2, tmp), ctx);
emit(A64_MOV(0, r3, r3), ctx); emit(A64_MOV(0, r3, r3), ctx);
emit(A64_CMP(0, r3, tmp), ctx); emit(A64_CMP(0, r3, tmp), ctx);
emit(A64_B_(A64_COND_CS, jmp_offset), ctx); branch1 = ctx->image + ctx->idx;
emit(A64_NOP, ctx);
/* /*
* if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT) * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
* goto out; * goto out;
* (*tail_call_cnt_ptr)++;
*/ */
emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx); emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
emit(A64_LDR64I(tcc, ptr, 0), ctx); emit(A64_LDR64I(tcc, ptr, 0), ctx);
emit(A64_CMP(1, tcc, tmp), ctx); emit(A64_CMP(1, tcc, tmp), ctx);
emit(A64_B_(A64_COND_CS, jmp_offset), ctx); branch2 = ctx->image + ctx->idx;
emit(A64_NOP, ctx);
/* (*tail_call_cnt_ptr)++; */
emit(A64_ADD_I(1, tcc, tcc, 1), ctx); emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
/* prog = array->ptrs[index]; /* prog = array->ptrs[index];
...@@ -487,30 +567,37 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx) ...@@ -487,30 +567,37 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
emit(A64_ADD(1, tmp, r2, tmp), ctx); emit(A64_ADD(1, tmp, r2, tmp), ctx);
emit(A64_LSL(1, prg, r3, 3), ctx); emit(A64_LSL(1, prg, r3, 3), ctx);
emit(A64_LDR64(prg, tmp, prg), ctx); emit(A64_LDR64(prg, tmp, prg), ctx);
emit(A64_CBZ(1, prg, jmp_offset), ctx); branch3 = ctx->image + ctx->idx;
emit(A64_NOP, ctx);
/* Update tail_call_cnt if the slot is populated. */ /* Update tail_call_cnt if the slot is populated. */
emit(A64_STR64I(tcc, ptr, 0), ctx); emit(A64_STR64I(tcc, ptr, 0), ctx);
/* restore SP */
if (ctx->stack_size)
emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
pop_callee_regs(ctx);
/* goto *(prog->bpf_func + prologue_offset); */ /* goto *(prog->bpf_func + prologue_offset); */
off = offsetof(struct bpf_prog, bpf_func); off = offsetof(struct bpf_prog, bpf_func);
emit_a64_mov_i64(tmp, off, ctx); emit_a64_mov_i64(tmp, off, ctx);
emit(A64_LDR64(tmp, prg, tmp), ctx); emit(A64_LDR64(tmp, prg, tmp), ctx);
emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx); emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
emit(A64_BR(tmp), ctx); emit(A64_BR(tmp), ctx);
/* out: */ if (ctx->image) {
if (out_offset == -1) off = &ctx->image[ctx->idx] - branch1;
out_offset = cur_offset; *branch1 = cpu_to_le32(A64_B_(A64_COND_CS, off));
if (cur_offset != out_offset) {
pr_err_once("tail_call out_offset = %d, expected %d!\n", off = &ctx->image[ctx->idx] - branch2;
cur_offset, out_offset); *branch2 = cpu_to_le32(A64_B_(A64_COND_CS, off));
return -1;
off = &ctx->image[ctx->idx] - branch3;
*branch3 = cpu_to_le32(A64_CBZ(1, prg, off));
} }
return 0; return 0;
#undef cur_offset
#undef jmp_offset
} }
#ifdef CONFIG_ARM64_LSE_ATOMICS #ifdef CONFIG_ARM64_LSE_ATOMICS
...@@ -736,38 +823,18 @@ static void build_plt(struct jit_ctx *ctx) ...@@ -736,38 +823,18 @@ static void build_plt(struct jit_ctx *ctx)
plt->target = (u64)&dummy_tramp; plt->target = (u64)&dummy_tramp;
} }
static void build_epilogue(struct jit_ctx *ctx, bool is_exception_cb) static void build_epilogue(struct jit_ctx *ctx)
{ {
const u8 r0 = bpf2a64[BPF_REG_0]; const u8 r0 = bpf2a64[BPF_REG_0];
const u8 r6 = bpf2a64[BPF_REG_6];
const u8 r7 = bpf2a64[BPF_REG_7];
const u8 r8 = bpf2a64[BPF_REG_8];
const u8 r9 = bpf2a64[BPF_REG_9];
const u8 fp = bpf2a64[BPF_REG_FP];
const u8 ptr = bpf2a64[TCCNT_PTR]; const u8 ptr = bpf2a64[TCCNT_PTR];
const u8 fpb = bpf2a64[FP_BOTTOM];
/* We're done with BPF stack */ /* We're done with BPF stack */
emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx); if (ctx->stack_size)
emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
/* pop_callee_regs(ctx);
* Program acting as exception boundary pushes R23 and R24 in addition
* to BPF callee-saved registers. Exception callback uses the boundary
* program's stack frame, so recover these extra registers in the above
* two cases.
*/
if (ctx->prog->aux->exception_boundary || is_exception_cb)
emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);
/* Restore x27 and x28 */
emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
/* Restore fs (x25) and x26 */
emit(A64_POP(ptr, fp, A64_SP), ctx);
emit(A64_POP(ptr, fp, A64_SP), ctx);
/* Restore callee-saved register */ emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
emit(A64_POP(r8, r9, A64_SP), ctx);
emit(A64_POP(r6, r7, A64_SP), ctx);
/* Restore FP/LR registers */ /* Restore FP/LR registers */
emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx); emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
...@@ -887,7 +954,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, ...@@ -887,7 +954,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
const u8 tmp = bpf2a64[TMP_REG_1]; const u8 tmp = bpf2a64[TMP_REG_1];
const u8 tmp2 = bpf2a64[TMP_REG_2]; const u8 tmp2 = bpf2a64[TMP_REG_2];
const u8 fp = bpf2a64[BPF_REG_FP]; const u8 fp = bpf2a64[BPF_REG_FP];
const u8 fpb = bpf2a64[FP_BOTTOM];
const u8 arena_vm_base = bpf2a64[ARENA_VM_START]; const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
const s16 off = insn->off; const s16 off = insn->off;
const s32 imm = insn->imm; const s32 imm = insn->imm;
...@@ -1339,9 +1405,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, ...@@ -1339,9 +1405,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
emit(A64_ADD(1, tmp2, src, arena_vm_base), ctx); emit(A64_ADD(1, tmp2, src, arena_vm_base), ctx);
src = tmp2; src = tmp2;
} }
if (ctx->fpb_offset > 0 && src == fp && BPF_MODE(insn->code) != BPF_PROBE_MEM32) { if (src == fp) {
src_adj = fpb; src_adj = A64_SP;
off_adj = off + ctx->fpb_offset; off_adj = off + ctx->stack_size;
} else { } else {
src_adj = src; src_adj = src;
off_adj = off; off_adj = off;
...@@ -1432,9 +1498,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, ...@@ -1432,9 +1498,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx); emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
dst = tmp2; dst = tmp2;
} }
if (ctx->fpb_offset > 0 && dst == fp && BPF_MODE(insn->code) != BPF_PROBE_MEM32) { if (dst == fp) {
dst_adj = fpb; dst_adj = A64_SP;
off_adj = off + ctx->fpb_offset; off_adj = off + ctx->stack_size;
} else { } else {
dst_adj = dst; dst_adj = dst;
off_adj = off; off_adj = off;
...@@ -1494,9 +1560,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, ...@@ -1494,9 +1560,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx); emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
dst = tmp2; dst = tmp2;
} }
if (ctx->fpb_offset > 0 && dst == fp && BPF_MODE(insn->code) != BPF_PROBE_MEM32) { if (dst == fp) {
dst_adj = fpb; dst_adj = A64_SP;
off_adj = off + ctx->fpb_offset; off_adj = off + ctx->stack_size;
} else { } else {
dst_adj = dst; dst_adj = dst;
off_adj = off; off_adj = off;
...@@ -1565,79 +1631,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, ...@@ -1565,79 +1631,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
return 0; return 0;
} }
/*
 * Scan all instructions of @prog and compute an FP-relative base offset.
 *
 * Returns 0 as soon as an instruction is found that may change FP at
 * runtime: FP is the destination of a load or ALU-class instruction, or FP
 * is the source register of a fetching atomic (BPF_XCHG or BPF_FETCH
 * combined with ADD/AND/XOR/OR).
 *
 * Otherwise, find the minimum negative offset used with FP in BPF_MEM loads
 * (BPF_LDX) and stores (BPF_STX/BPF_ST), convert it to a positive number and
 * align it down to 8 bytes. The result is 0 when no such negative offset
 * exists.
 */
static int find_fpb_offset(struct bpf_prog *prog)
{
int i;
/* Most negative FP-relative offset seen so far; 0 means none yet. */
int offset = 0;
for (i = 0; i < prog->len; i++) {
const struct bpf_insn *insn = &prog->insnsi[i];
const u8 class = BPF_CLASS(insn->code);
const u8 mode = BPF_MODE(insn->code);
const u8 src = insn->src_reg;
const u8 dst = insn->dst_reg;
const s32 imm = insn->imm;
const s16 off = insn->off;
switch (class) {
case BPF_STX:
case BPF_ST:
/* fp holds atomic operation result */
if (class == BPF_STX && mode == BPF_ATOMIC &&
((imm == BPF_XCHG ||
imm == (BPF_FETCH | BPF_ADD) ||
imm == (BPF_FETCH | BPF_AND) ||
imm == (BPF_FETCH | BPF_XOR) ||
imm == (BPF_FETCH | BPF_OR)) &&
src == BPF_REG_FP))
return 0;
/* plain store through fp: track the lowest offset */
if (mode == BPF_MEM && dst == BPF_REG_FP &&
off < offset)
offset = insn->off;
break;
case BPF_JMP32:
case BPF_JMP:
/* jump-class instructions are ignored by this scan */
break;
case BPF_LDX:
case BPF_LD:
/* fp holds load result */
if (dst == BPF_REG_FP)
return 0;
/* plain load through fp: track the lowest offset */
if (class == BPF_LDX && mode == BPF_MEM &&
src == BPF_REG_FP && off < offset)
offset = off;
break;
case BPF_ALU:
case BPF_ALU64:
default:
/* fp holds ALU result */
if (dst == BPF_REG_FP)
return 0;
}
}
if (offset < 0) {
/*
 * safely be converted to a positive 'int', since insn->off
 * is 's16'
 */
offset = -offset;
/* align down to 8 bytes */
offset = ALIGN_DOWN(offset, 8);
}
return offset;
}
static int build_body(struct jit_ctx *ctx, bool extra_pass) static int build_body(struct jit_ctx *ctx, bool extra_pass)
{ {
const struct bpf_prog *prog = ctx->prog; const struct bpf_prog *prog = ctx->prog;
...@@ -1726,7 +1719,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -1726,7 +1719,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
bool tmp_blinded = false; bool tmp_blinded = false;
bool extra_pass = false; bool extra_pass = false;
struct jit_ctx ctx; struct jit_ctx ctx;
u64 arena_vm_start;
u8 *image_ptr; u8 *image_ptr;
u8 *ro_image_ptr; u8 *ro_image_ptr;
...@@ -1744,7 +1736,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -1744,7 +1736,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
prog = tmp; prog = tmp;
} }
arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
jit_data = prog->aux->jit_data; jit_data = prog->aux->jit_data;
if (!jit_data) { if (!jit_data) {
jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL); jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
...@@ -1774,8 +1765,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -1774,8 +1765,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
goto out_off; goto out_off;
} }
ctx.fpb_offset = find_fpb_offset(prog);
ctx.user_vm_start = bpf_arena_get_user_vm_start(prog->aux->arena); ctx.user_vm_start = bpf_arena_get_user_vm_start(prog->aux->arena);
ctx.arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
/* /*
* 1. Initial fake pass to compute ctx->idx and ctx->offset. * 1. Initial fake pass to compute ctx->idx and ctx->offset.
...@@ -1783,8 +1774,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -1783,8 +1774,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
* BPF line info needs ctx->offset[i] to be the offset of * BPF line info needs ctx->offset[i] to be the offset of
* instruction[i] in jited image, so build prologue first. * instruction[i] in jited image, so build prologue first.
*/ */
if (build_prologue(&ctx, was_classic, prog->aux->exception_cb, if (build_prologue(&ctx, was_classic)) {
arena_vm_start)) {
prog = orig_prog; prog = orig_prog;
goto out_off; goto out_off;
} }
...@@ -1795,7 +1785,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -1795,7 +1785,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
} }
ctx.epilogue_offset = ctx.idx; ctx.epilogue_offset = ctx.idx;
build_epilogue(&ctx, prog->aux->exception_cb); build_epilogue(&ctx);
build_plt(&ctx); build_plt(&ctx);
extable_align = __alignof__(struct exception_table_entry); extable_align = __alignof__(struct exception_table_entry);
...@@ -1832,14 +1822,14 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) ...@@ -1832,14 +1822,14 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
ctx.idx = 0; ctx.idx = 0;
ctx.exentry_idx = 0; ctx.exentry_idx = 0;
build_prologue(&ctx, was_classic, prog->aux->exception_cb, arena_vm_start); build_prologue(&ctx, was_classic);
if (build_body(&ctx, extra_pass)) { if (build_body(&ctx, extra_pass)) {
prog = orig_prog; prog = orig_prog;
goto out_free_hdr; goto out_free_hdr;
} }
build_epilogue(&ctx, prog->aux->exception_cb); build_epilogue(&ctx);
build_plt(&ctx); build_plt(&ctx);
/* 3. Extra pass to validate JITed code. */ /* 3. Extra pass to validate JITed code. */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment