bpf, x64: Fix tailcall hierarchy
    This patch fixes a tailcall issue caused by abusing the
    tailcall-in-bpf2bpf feature.
    
    As we know, tail_call_cnt propagates via rax from caller to callee
    when calling a subprog in tailcall context. But, as the following
    example shows, MAX_TAIL_CALL_CNT fails to limit the number of
    tailcalls because tail_call_cnt is never propagated back from callee
    to caller.
    
    #include <linux/bpf.h>
    #include <bpf/bpf_helpers.h>
    \#include "bpf_legacy.h"
    
    struct {
    	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
    	__uint(max_entries, 1);
    	__uint(key_size, sizeof(__u32));
    	__uint(value_size, sizeof(__u32));
    } jmp_table SEC(".maps");
    
    int count = 0;
    
    static __noinline
    int subprog_tail1(struct __sk_buff *skb)
    {
    	bpf_tail_call_static(skb, &jmp_table, 0);
    	return 0;
    }
    
    static __noinline
    int subprog_tail2(struct __sk_buff *skb)
    {
    	bpf_tail_call_static(skb, &jmp_table, 0);
    	return 0;
    }
    
    SEC("tc")
    int entry(struct __sk_buff *skb)
    {
    	volatile int ret = 1;
    
    	count++;
    	subprog_tail1(skb);
    	subprog_tail2(skb);
    
    	return ret;
    }
    
    char __license[] SEC("license") = "GPL";
    
    At run time, the tail_call_cnt of entry() is propagated to
    subprog_tail1() and subprog_tail2(). But when bpf_tail_call_static()
    updates the tail_call_cnt in subprog_tail1(), the tail_call_cnt in
    entry() is not updated along with it. As a result, even after
    subprog_tail1() returns because it has hit the MAX_TAIL_CALL_CNT
    limit, the tail_call_cnt in entry() is still less than
    MAX_TAIL_CALL_CNT, so bpf_tail_call_static() in subprog_tail2() is
    able to run: the tail_call_cnt it inherits from entry() is less than
    MAX_TAIL_CALL_CNT.
    
    So, how many tailcalls are there in this case if no error happens?

    Viewed top-down, the calls form a hierarchy, layer by layer: every
    run of entry() performs two tailcalls, and each tailcall re-enters
    entry() one layer deeper, down to a depth of MAX_TAIL_CALL_CNT.

    With this view, there will be 2+4+8+...+2^33 = 2^34 - 2 =
    17,179,869,182 tailcalls for this case.

    And with N subprog_tail() calls in entry()? There will be almost
    N^34 tailcalls, as the model below illustrates.
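
    Here is a small user-space C model (an illustration only, not kernel
    code; MAX_TAIL_CALL_CNT mirrors the kernel constant) that counts the
    tailcalls of the two-subprog example above:

    #include <stdio.h>

    #define MAX_TAIL_CALL_CNT 33

    int main(void)
    {
    	/* Work from the deepest layer up: one entry() running with
    	 * tail_call_cnt == d performs two tailcalls, one per subprog,
    	 * and each tailcall repeats the whole tree one layer deeper.
    	 */
    	unsigned long long tailcalls = 0;
    	int d;

    	for (d = MAX_TAIL_CALL_CNT - 1; d >= 0; d--)
    		tailcalls = 2 * (1 + tailcalls);

    	printf("%llu\n", tailcalls);	/* 17179869182 == 2^34 - 2 */
    	return 0;
    }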
    
    This patch resolves this case on x86_64.
    
    Instead of propagating tail_call_cnt from caller to callee, it
    propagates a pointer to it: tail_call_cnt_ptr, tcc_ptr for short.
    
    But where does it store tail_call_cnt itself?

    It stores tail_call_cnt on the stack of the main prog. When a tail
    call happens in a subprog, tail_call_cnt is incremented through
    tcc_ptr.

    Meanwhile, it stores tail_call_cnt_ptr on the stack of the main
    prog, too.

    And, before jumping to the tail callee, it has to pop tail_call_cnt
    and tail_call_cnt_ptr off the stack.

    Then, the prologue of a subprog must not turn rax into a fresh
    tail_call_cnt_ptr again; it has to reuse the tail_call_cnt_ptr
    received from its caller.
    
    As a result, at run time, the prologue has to recognize whether rax
    holds tail_call_cnt or tail_call_cnt_ptr (see the sketch after this
    list):

    1. rax is tail_call_cnt if rax is <= MAX_TAIL_CALL_CNT.
    2. rax is tail_call_cnt_ptr if rax is > MAX_TAIL_CALL_CNT, because a
       pointer is never <= MAX_TAIL_CALL_CNT.
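
    Here is a minimal C model of that decision in the main prog's
    prologue (an illustration only; the real thing is the emitted asm
    shown in the dump below, and stack_model is a made-up stand-in for
    the two stack slots):

    #include <stdio.h>

    typedef unsigned long long u64;

    #define MAX_TAIL_CALL_CNT 33

    struct stack_model {
    	u64 slot8;	/* [rbp - 8]  */
    	u64 slot16;	/* [rbp - 16] */
    };

    static u64 prologue(u64 rax, struct stack_model *stk)
    {
    	if (rax <= MAX_TAIL_CALL_CNT) {
    		/* Entered directly: rax is tail_call_cnt (0). Keep the
    		 * counter itself at [rbp - 8], then turn rax into
    		 * tcc_ptr.
    		 */
    		stk->slot8 = rax;
    		rax = (u64)&stk->slot8;
    	} else {
    		/* Entered via tailcall: rax is already tcc_ptr. Push
    		 * it anyway so both paths leave the same stack layout.
    		 */
    		stk->slot8 = rax;
    	}
    	stk->slot16 = rax;	/* [rbp - 16] = tcc_ptr; reloaded into
    				 * rax before each subprog call and
    				 * each tailcall */
    	return rax;
    }

    int main(void)
    {
    	struct stack_model stk;
    	u64 tcc_ptr = prologue(0, &stk);	/* direct entry */

    	prologue(tcc_ptr, &stk);	/* tailcall-style re-entry */
    	printf("tail_call_cnt lives at 0x%llx\n", tcc_ptr);
    	return 0;
    }

    A subprog's prologue skips the check entirely (note the nop in the
    dump below): it is always called with rax == tcc_ptr, so it only
    performs the two pushes.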
    
    Here's an example used to dump the JITed code.
    
    struct {
    	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
    	__uint(max_entries, 1);
    	__uint(key_size, sizeof(__u32));
    	__uint(value_size, sizeof(__u32));
    } jmp_table SEC(".maps");
    
    int count = 0;
    
    static __noinline
    int subprog_tail(struct __sk_buff *skb)
    {
    	bpf_tail_call_static(skb, &jmp_table, 0);
    	return 0;
    }
    
    SEC("tc")
    int entry(struct __sk_buff *skb)
    {
    	int ret = 1;
    
    	count++;
    	subprog_tail(skb);
    	subprog_tail(skb);
    
    	return ret;
    }
    
    Dumping the JITed code with "bpftool p d j id 42" gives:
    
    int entry(struct __sk_buff * skb):
    bpf_prog_0c0f4c2413ef19b1_entry:
    ; int entry(struct __sk_buff *skb)
       0:	endbr64
       4:	nopl	(%rax,%rax)
       9:	xorq	%rax, %rax		;; rax = 0 (tail_call_cnt)
       c:	pushq	%rbp
       d:	movq	%rsp, %rbp
      10:	endbr64
      14:	cmpq	$33, %rax		;; if rax > 33, rax = tcc_ptr
      18:	ja	0x20			;; if rax > 33 goto 0x20 ---+
      1a:	pushq	%rax			;; [rbp - 8] = rax = 0      |
      1b:	movq	%rsp, %rax		;; rax = rbp - 8            |
      1e:	jmp	0x21			;; ---------+               |
      20:	pushq	%rax			;; <--------|---------------+
      21:	pushq	%rax			;; <--------+ [rbp - 16] = rax
      22:	pushq	%rbx			;; callee saved
      23:	movq	%rdi, %rbx		;; rbx = skb (callee saved)
    ; count++;
      26:	movabsq	$-82417199407104, %rdi
      30:	movl	(%rdi), %esi
      33:	addl	$1, %esi
      36:	movl	%esi, (%rdi)
    ; subprog_tail(skb);
      39:	movq	%rbx, %rdi		;; rdi = skb
      3c:	movq	-16(%rbp), %rax		;; rax = tcc_ptr
      43:	callq	0x80			;; call subprog_tail()
    ; subprog_tail(skb);
      48:	movq	%rbx, %rdi		;; rdi = skb
      4b:	movq	-16(%rbp), %rax		;; rax = tcc_ptr
      52:	callq	0x80			;; call subprog_tail()
    ; return ret;
      57:	movl	$1, %eax
      5c:	popq	%rbx
      5d:	leave
      5e:	retq
    
    int subprog_tail(struct __sk_buff * skb):
    bpf_prog_3a140cef239a4b4f_subprog_tail:
    ; int subprog_tail(struct __sk_buff *skb)
       0:	endbr64
       4:	nopl	(%rax,%rax)
       9:	nopl	(%rax)			;; do not touch tail_call_cnt
       c:	pushq	%rbp
       d:	movq	%rsp, %rbp
      10:	endbr64
      14:	pushq	%rax			;; [rbp - 8]  = rax (tcc_ptr)
      15:	pushq	%rax			;; [rbp - 16] = rax (tcc_ptr)
      16:	pushq	%rbx			;; callee saved
      17:	pushq	%r13			;; callee saved
      19:	movq	%rdi, %rbx		;; rbx = skb
    ; asm volatile("r1 = %[ctx]\n\t"
      1c:	movabsq	$-105487587488768, %r13	;; r13 = jmp_table
      26:	movq	%rbx, %rdi		;; 1st arg, skb
      29:	movq	%r13, %rsi		;; 2nd arg, jmp_table
      2c:	xorl	%edx, %edx		;; 3rd arg, index = 0
      2e:	movq	-16(%rbp), %rax		;; rax = [rbp - 16] (tcc_ptr)
      35:	cmpq	$33, (%rax)
      39:	jae	0x4e			;; if *tcc_ptr >= 33 goto 0x4e --------+
      3b:	jmp	0x4e			;; jmp bypass, toggled by poking       |
      40:	addq	$1, (%rax)		;; (*tcc_ptr)++                        |
      44:	popq	%r13			;; callee saved                        |
      46:	popq	%rbx			;; callee saved                        |
      47:	popq	%rax			;; undo rbp-16 push                    |
      48:	popq	%rax			;; undo rbp-8  push                    |
      49:	nopl	(%rax,%rax)		;; tail call target, toggled by poking |
    ; return 0;				;;                                     |
      4e:	popq	%r13			;; restore callee saved <--------------+
      50:	popq	%rbx			;; restore callee saved
      51:	leave
      52:	retq
    
    Furthermore, when the trampoline is the caller of a bpf prog that is
    tail_call_reachable, it is required to propagate rax through the
    trampoline, as sketched below.
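
    A rough user-space sketch of that requirement (an illustration only:
    trampoline, run_fentry_progs and traced_prog are hypothetical
    stand-ins, since the real trampoline is emitted asm):

    typedef unsigned long long u64;

    static void run_fentry_progs(void) { }		/* may clobber rax */
    static u64 traced_prog(u64 rax) { return rax; }	/* expects tcc_ptr */

    static u64 trampoline(u64 rax)
    {
    	u64 tcc_ptr = rax;	/* spill rax to the trampoline stack */

    	run_fentry_progs();
    	return traced_prog(tcc_ptr);	/* reload into rax right before
    					 * invoking the prog */
    }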
    
    Fixes: ebf7d1f5 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT")
    Fixes: e411901c ("bpf: allow for tailcalls in BPF subprograms for x64 JIT")
    Reviewed-by: Eduard Zingerman <eddyz87@gmail.com>
    Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
    Link: https://lore.kernel.org/r/20240714123902.32305-2-hffilwlqm@gmail.com
    Signed-off-by: Alexei Starovoitov <ast@kernel.org>
    Signed-off-by: Andrii Nakryiko <andrii@kernel.org>