bpf: Fix incorrect verifier pruning due to missing register precision taints
Juan Jose et al reported an issue found via fuzzing where the verifier's pruning logic prematurely marks a program path as safe. Consider the following program: 0: (b7) r6 = 1024 1: (b7) r7 = 0 2: (b7) r8 = 0 3: (b7) r9 = -2147483648 4: (97) r6 %= 1025 5: (05) goto pc+0 6: (bd) if r6 <= r9 goto pc+2 7: (97) r6 %= 1 8: (b7) r9 = 0 9: (bd) if r6 <= r9 goto pc+1 10: (b7) r6 = 0 11: (b7) r0 = 0 12: (63) *(u32 *)(r10 -4) = r0 13: (18) r4 = 0xffff888103693400 // map_ptr(ks=4,vs=48) 15: (bf) r1 = r4 16: (bf) r2 = r10 17: (07) r2 += -4 18: (85) call bpf_map_lookup_elem#1 19: (55) if r0 != 0x0 goto pc+1 20: (95) exit 21: (77) r6 >>= 10 22: (27) r6 *= 8192 23: (bf) r1 = r0 24: (0f) r0 += r6 25: (79) r3 = *(u64 *)(r0 +0) 26: (7b) *(u64 *)(r1 +0) = r3 27: (95) exit The verifier treats this as safe, leading to oob read/write access due to an incorrect verifier conclusion: func#0 @0 0: R1=ctx(off=0,imm=0) R10=fp0 0: (b7) r6 = 1024 ; R6_w=1024 1: (b7) r7 = 0 ; R7_w=0 2: (b7) r8 = 0 ; R8_w=0 3: (b7) r9 = -2147483648 ; R9_w=-2147483648 4: (97) r6 %= 1025 ; R6_w=scalar() 5: (05) goto pc+0 6: (bd) if r6 <= r9 goto pc+2 ; R6_w=scalar(umin=18446744071562067969,var_off=(0xffffffff00000000; 0xffffffff)) R9_w=-2147483648 7: (97) r6 %= 1 ; R6_w=scalar() 8: (b7) r9 = 0 ; R9=0 9: (bd) if r6 <= r9 goto pc+1 ; R6=scalar(umin=1) R9=0 10: (b7) r6 = 0 ; R6_w=0 11: (b7) r0 = 0 ; R0_w=0 12: (63) *(u32 *)(r10 -4) = r0 last_idx 12 first_idx 9 regs=1 stack=0 before 11: (b7) r0 = 0 13: R0_w=0 R10=fp0 fp-8=0000???? 13: (18) r4 = 0xffff8ad3886c2a00 ; R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 15: (bf) r1 = r4 ; R1_w=map_ptr(off=0,ks=4,vs=48,imm=0) R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 16: (bf) r2 = r10 ; R2_w=fp0 R10=fp0 17: (07) r2 += -4 ; R2_w=fp-4 18: (85) call bpf_map_lookup_elem#1 ; R0=map_value_or_null(id=1,off=0,ks=4,vs=48,imm=0) 19: (55) if r0 != 0x0 goto pc+1 ; R0=0 20: (95) exit from 19 to 21: R0=map_value(off=0,ks=4,vs=48,imm=0) R6=0 R7=0 R8=0 R9=0 R10=fp0 fp-8=mmmm???? 21: (77) r6 >>= 10 ; R6_w=0 22: (27) r6 *= 8192 ; R6_w=0 23: (bf) r1 = r0 ; R0=map_value(off=0,ks=4,vs=48,imm=0) R1_w=map_value(off=0,ks=4,vs=48,imm=0) 24: (0f) r0 += r6 last_idx 24 first_idx 19 regs=40 stack=0 before 23: (bf) r1 = r0 regs=40 stack=0 before 22: (27) r6 *= 8192 regs=40 stack=0 before 21: (77) r6 >>= 10 regs=40 stack=0 before 19: (55) if r0 != 0x0 goto pc+1 parent didn't have regs=40 stack=0 marks: R0_rw=map_value_or_null(id=1,off=0,ks=4,vs=48,imm=0) R6_rw=P0 R7=0 R8=0 R9=0 R10=fp0 fp-8=mmmm???? last_idx 18 first_idx 9 regs=40 stack=0 before 18: (85) call bpf_map_lookup_elem#1 regs=40 stack=0 before 17: (07) r2 += -4 regs=40 stack=0 before 16: (bf) r2 = r10 regs=40 stack=0 before 15: (bf) r1 = r4 regs=40 stack=0 before 13: (18) r4 = 0xffff8ad3886c2a00 regs=40 stack=0 before 12: (63) *(u32 *)(r10 -4) = r0 regs=40 stack=0 before 11: (b7) r0 = 0 regs=40 stack=0 before 10: (b7) r6 = 0 25: (79) r3 = *(u64 *)(r0 +0) ; R0_w=map_value(off=0,ks=4,vs=48,imm=0) R3_w=scalar() 26: (7b) *(u64 *)(r1 +0) = r3 ; R1_w=map_value(off=0,ks=4,vs=48,imm=0) R3_w=scalar() 27: (95) exit from 9 to 11: R1=ctx(off=0,imm=0) R6=0 R7=0 R8=0 R9=0 R10=fp0 11: (b7) r0 = 0 ; R0_w=0 12: (63) *(u32 *)(r10 -4) = r0 last_idx 12 first_idx 11 regs=1 stack=0 before 11: (b7) r0 = 0 13: R0_w=0 R10=fp0 fp-8=0000???? 13: (18) r4 = 0xffff8ad3886c2a00 ; R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 15: (bf) r1 = r4 ; R1_w=map_ptr(off=0,ks=4,vs=48,imm=0) R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 16: (bf) r2 = r10 ; R2_w=fp0 R10=fp0 17: (07) r2 += -4 ; R2_w=fp-4 18: (85) call bpf_map_lookup_elem#1 frame 0: propagating r6 last_idx 19 first_idx 11 regs=40 stack=0 before 18: (85) call bpf_map_lookup_elem#1 regs=40 stack=0 before 17: (07) r2 += -4 regs=40 stack=0 before 16: (bf) r2 = r10 regs=40 stack=0 before 15: (bf) r1 = r4 regs=40 stack=0 before 13: (18) r4 = 0xffff8ad3886c2a00 regs=40 stack=0 before 12: (63) *(u32 *)(r10 -4) = r0 regs=40 stack=0 before 11: (b7) r0 = 0 parent didn't have regs=40 stack=0 marks: R1=ctx(off=0,imm=0) R6_r=P0 R7=0 R8=0 R9=0 R10=fp0 last_idx 9 first_idx 9 regs=40 stack=0 before 9: (bd) if r6 <= r9 goto pc+1 parent didn't have regs=40 stack=0 marks: R1=ctx(off=0,imm=0) R6_rw=Pscalar() R7_w=0 R8_w=0 R9_rw=0 R10=fp0 last_idx 8 first_idx 0 regs=40 stack=0 before 8: (b7) r9 = 0 regs=40 stack=0 before 7: (97) r6 %= 1 regs=40 stack=0 before 6: (bd) if r6 <= r9 goto pc+2 regs=40 stack=0 before 5: (05) goto pc+0 regs=40 stack=0 before 4: (97) r6 %= 1025 regs=40 stack=0 before 3: (b7) r9 = -2147483648 regs=40 stack=0 before 2: (b7) r8 = 0 regs=40 stack=0 before 1: (b7) r7 = 0 regs=40 stack=0 before 0: (b7) r6 = 1024 19: safe frame 0: propagating r6 last_idx 9 first_idx 0 regs=40 stack=0 before 6: (bd) if r6 <= r9 goto pc+2 regs=40 stack=0 before 5: (05) goto pc+0 regs=40 stack=0 before 4: (97) r6 %= 1025 regs=40 stack=0 before 3: (b7) r9 = -2147483648 regs=40 stack=0 before 2: (b7) r8 = 0 regs=40 stack=0 before 1: (b7) r7 = 0 regs=40 stack=0 before 0: (b7) r6 = 1024 from 6 to 9: safe verification time 110 usec stack depth 4 processed 36 insns (limit 1000000) max_states_per_insn 0 total_states 3 peak_states 3 mark_read 2 The verifier considers this program as safe by mistakenly pruning unsafe code paths. In the above func#0, code lines 0-10 are of interest. In line 0-3 registers r6 to r9 are initialized with known scalar values. In line 4 the register r6 is reset to an unknown scalar given the verifier does not track modulo operations. Due to this, the verifier can also not determine precisely which branches in line 6 and 9 are taken, therefore it needs to explore them both. As can be seen, the verifier starts with exploring the false/fall-through paths first. The 'from 19 to 21' path has both r6=0 and r9=0 and the pointer arithmetic on r0 += r6 is therefore considered safe. Given the arithmetic, r6 is correctly marked for precision tracking where backtracking kicks in where it walks back the current path all the way where r6 was set to 0 in the fall-through branch. Next, the pruning logics pops the path 'from 9 to 11' from the stack. Also here, the state of the registers is the same, that is, r6=0 and r9=0, so that at line 19 the path can be pruned as it is considered safe. It is interesting to note that the conditional in line 9 turned r6 into a more precise state, that is, in the fall-through path at the beginning of line 10, it is R6=scalar(umin=1), and in the branch-taken path (which is analyzed here) at the beginning of line 11, r6 turned into a known const r6=0 as r9=0 prior to that and therefore (unsigned) r6 <= 0 concludes that r6 must be 0 (**): [...] ; R6_w=scalar() 9: (bd) if r6 <= r9 goto pc+1 ; R6=scalar(umin=1) R9=0 [...] from 9 to 11: R1=ctx(off=0,imm=0) R6=0 R7=0 R8=0 R9=0 R10=fp0 [...] The next path is 'from 6 to 9'. The verifier considers the old and current state equivalent, and therefore prunes the search incorrectly. Looking into the two states which are being compared by the pruning logic at line 9, the old state consists of R6_rwD=Pscalar() R9_rwD=0 R10=fp0 and the new state consists of R1=ctx(off=0,imm=0) R6_w=scalar(umax=18446744071562067968) R7_w=0 R8_w=0 R9_w=-2147483648 R10=fp0. While r6 had the reg->precise flag correctly set in the old state, r9 did not. Both r6'es are considered as equivalent given the old one is a superset of the current, more precise one, however, r9's actual values (0 vs 0x80000000) mismatch. Given the old r9 did not have reg->precise flag set, the verifier does not consider the register as contributing to the precision state of r6, and therefore it considered both r9 states as equivalent. However, for this specific pruned path (which is also the actual path taken at runtime), register r6 will be 0x400 and r9 0x80000000 when reaching line 21, thus oob-accessing the map. The purpose of precision tracking is to initially mark registers (including spilled ones) as imprecise to help verifier's pruning logic finding equivalent states it can then prune if they don't contribute to the program's safety aspects. For example, if registers are used for pointer arithmetic or to pass constant length to a helper, then the verifier sets reg->precise flag and backtracks the BPF program instruction sequence and chain of verifier states to ensure that the given register or stack slot including their dependencies are marked as precisely tracked scalar. This also includes any other registers and slots that contribute to a tracked state of given registers/stack slot. This backtracking relies on recorded jmp_history and is able to traverse entire chain of parent states. This process ends only when all the necessary registers/slots and their transitive dependencies are marked as precise. The backtrack_insn() is called from the current instruction up to the first instruction, and its purpose is to compute a bitmask of registers and stack slots that need precision tracking in the parent's verifier state. For example, if a current instruction is r6 = r7, then r6 needs precision after this instruction and r7 needs precision before this instruction, that is, in the parent state. Hence for the latter r7 is marked and r6 unmarked. For the class of jmp/jmp32 instructions, backtrack_insn() today only looks at call and exit instructions and for all other conditionals the masks remain as-is. However, in the given situation register r6 has a dependency on r9 (as described above in **), so also that one needs to be marked for precision tracking. In other words, if an imprecise register influences a precise one, then the imprecise register should also be marked precise. Meaning, in the parent state both dest and src register need to be tracked for precision and therefore the marking must be more conservative by setting reg->precise flag for both. The precision propagation needs to cover both for the conditional: if the src reg was marked but not the dst reg and vice versa. After the fix the program is correctly rejected: func#0 @0 0: R1=ctx(off=0,imm=0) R10=fp0 0: (b7) r6 = 1024 ; R6_w=1024 1: (b7) r7 = 0 ; R7_w=0 2: (b7) r8 = 0 ; R8_w=0 3: (b7) r9 = -2147483648 ; R9_w=-2147483648 4: (97) r6 %= 1025 ; R6_w=scalar() 5: (05) goto pc+0 6: (bd) if r6 <= r9 goto pc+2 ; R6_w=scalar(umin=18446744071562067969,var_off=(0xffffffff80000000; 0x7fffffff),u32_min=-2147483648) R9_w=-2147483648 7: (97) r6 %= 1 ; R6_w=scalar() 8: (b7) r9 = 0 ; R9=0 9: (bd) if r6 <= r9 goto pc+1 ; R6=scalar(umin=1) R9=0 10: (b7) r6 = 0 ; R6_w=0 11: (b7) r0 = 0 ; R0_w=0 12: (63) *(u32 *)(r10 -4) = r0 last_idx 12 first_idx 9 regs=1 stack=0 before 11: (b7) r0 = 0 13: R0_w=0 R10=fp0 fp-8=0000???? 13: (18) r4 = 0xffff9290dc5bfe00 ; R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 15: (bf) r1 = r4 ; R1_w=map_ptr(off=0,ks=4,vs=48,imm=0) R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 16: (bf) r2 = r10 ; R2_w=fp0 R10=fp0 17: (07) r2 += -4 ; R2_w=fp-4 18: (85) call bpf_map_lookup_elem#1 ; R0=map_value_or_null(id=1,off=0,ks=4,vs=48,imm=0) 19: (55) if r0 != 0x0 goto pc+1 ; R0=0 20: (95) exit from 19 to 21: R0=map_value(off=0,ks=4,vs=48,imm=0) R6=0 R7=0 R8=0 R9=0 R10=fp0 fp-8=mmmm???? 21: (77) r6 >>= 10 ; R6_w=0 22: (27) r6 *= 8192 ; R6_w=0 23: (bf) r1 = r0 ; R0=map_value(off=0,ks=4,vs=48,imm=0) R1_w=map_value(off=0,ks=4,vs=48,imm=0) 24: (0f) r0 += r6 last_idx 24 first_idx 19 regs=40 stack=0 before 23: (bf) r1 = r0 regs=40 stack=0 before 22: (27) r6 *= 8192 regs=40 stack=0 before 21: (77) r6 >>= 10 regs=40 stack=0 before 19: (55) if r0 != 0x0 goto pc+1 parent didn't have regs=40 stack=0 marks: R0_rw=map_value_or_null(id=1,off=0,ks=4,vs=48,imm=0) R6_rw=P0 R7=0 R8=0 R9=0 R10=fp0 fp-8=mmmm???? last_idx 18 first_idx 9 regs=40 stack=0 before 18: (85) call bpf_map_lookup_elem#1 regs=40 stack=0 before 17: (07) r2 += -4 regs=40 stack=0 before 16: (bf) r2 = r10 regs=40 stack=0 before 15: (bf) r1 = r4 regs=40 stack=0 before 13: (18) r4 = 0xffff9290dc5bfe00 regs=40 stack=0 before 12: (63) *(u32 *)(r10 -4) = r0 regs=40 stack=0 before 11: (b7) r0 = 0 regs=40 stack=0 before 10: (b7) r6 = 0 25: (79) r3 = *(u64 *)(r0 +0) ; R0_w=map_value(off=0,ks=4,vs=48,imm=0) R3_w=scalar() 26: (7b) *(u64 *)(r1 +0) = r3 ; R1_w=map_value(off=0,ks=4,vs=48,imm=0) R3_w=scalar() 27: (95) exit from 9 to 11: R1=ctx(off=0,imm=0) R6=0 R7=0 R8=0 R9=0 R10=fp0 11: (b7) r0 = 0 ; R0_w=0 12: (63) *(u32 *)(r10 -4) = r0 last_idx 12 first_idx 11 regs=1 stack=0 before 11: (b7) r0 = 0 13: R0_w=0 R10=fp0 fp-8=0000???? 13: (18) r4 = 0xffff9290dc5bfe00 ; R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 15: (bf) r1 = r4 ; R1_w=map_ptr(off=0,ks=4,vs=48,imm=0) R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 16: (bf) r2 = r10 ; R2_w=fp0 R10=fp0 17: (07) r2 += -4 ; R2_w=fp-4 18: (85) call bpf_map_lookup_elem#1 frame 0: propagating r6 last_idx 19 first_idx 11 regs=40 stack=0 before 18: (85) call bpf_map_lookup_elem#1 regs=40 stack=0 before 17: (07) r2 += -4 regs=40 stack=0 before 16: (bf) r2 = r10 regs=40 stack=0 before 15: (bf) r1 = r4 regs=40 stack=0 before 13: (18) r4 = 0xffff9290dc5bfe00 regs=40 stack=0 before 12: (63) *(u32 *)(r10 -4) = r0 regs=40 stack=0 before 11: (b7) r0 = 0 parent didn't have regs=40 stack=0 marks: R1=ctx(off=0,imm=0) R6_r=P0 R7=0 R8=0 R9=0 R10=fp0 last_idx 9 first_idx 9 regs=40 stack=0 before 9: (bd) if r6 <= r9 goto pc+1 parent didn't have regs=240 stack=0 marks: R1=ctx(off=0,imm=0) R6_rw=Pscalar() R7_w=0 R8_w=0 R9_rw=P0 R10=fp0 last_idx 8 first_idx 0 regs=240 stack=0 before 8: (b7) r9 = 0 regs=40 stack=0 before 7: (97) r6 %= 1 regs=40 stack=0 before 6: (bd) if r6 <= r9 goto pc+2 regs=240 stack=0 before 5: (05) goto pc+0 regs=240 stack=0 before 4: (97) r6 %= 1025 regs=240 stack=0 before 3: (b7) r9 = -2147483648 regs=40 stack=0 before 2: (b7) r8 = 0 regs=40 stack=0 before 1: (b7) r7 = 0 regs=40 stack=0 before 0: (b7) r6 = 1024 19: safe from 6 to 9: R1=ctx(off=0,imm=0) R6_w=scalar(umax=18446744071562067968) R7_w=0 R8_w=0 R9_w=-2147483648 R10=fp0 9: (bd) if r6 <= r9 goto pc+1 last_idx 9 first_idx 0 regs=40 stack=0 before 6: (bd) if r6 <= r9 goto pc+2 regs=240 stack=0 before 5: (05) goto pc+0 regs=240 stack=0 before 4: (97) r6 %= 1025 regs=240 stack=0 before 3: (b7) r9 = -2147483648 regs=40 stack=0 before 2: (b7) r8 = 0 regs=40 stack=0 before 1: (b7) r7 = 0 regs=40 stack=0 before 0: (b7) r6 = 1024 last_idx 9 first_idx 0 regs=200 stack=0 before 6: (bd) if r6 <= r9 goto pc+2 regs=240 stack=0 before 5: (05) goto pc+0 regs=240 stack=0 before 4: (97) r6 %= 1025 regs=240 stack=0 before 3: (b7) r9 = -2147483648 regs=40 stack=0 before 2: (b7) r8 = 0 regs=40 stack=0 before 1: (b7) r7 = 0 regs=40 stack=0 before 0: (b7) r6 = 1024 11: R6=scalar(umax=18446744071562067968) R9=-2147483648 11: (b7) r0 = 0 ; R0_w=0 12: (63) *(u32 *)(r10 -4) = r0 last_idx 12 first_idx 11 regs=1 stack=0 before 11: (b7) r0 = 0 13: R0_w=0 R10=fp0 fp-8=0000???? 13: (18) r4 = 0xffff9290dc5bfe00 ; R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 15: (bf) r1 = r4 ; R1_w=map_ptr(off=0,ks=4,vs=48,imm=0) R4_w=map_ptr(off=0,ks=4,vs=48,imm=0) 16: (bf) r2 = r10 ; R2_w=fp0 R10=fp0 17: (07) r2 += -4 ; R2_w=fp-4 18: (85) call bpf_map_lookup_elem#1 ; R0_w=map_value_or_null(id=3,off=0,ks=4,vs=48,imm=0) 19: (55) if r0 != 0x0 goto pc+1 ; R0_w=0 20: (95) exit from 19 to 21: R0=map_value(off=0,ks=4,vs=48,imm=0) R6=scalar(umax=18446744071562067968) R7=0 R8=0 R9=-2147483648 R10=fp0 fp-8=mmmm???? 21: (77) r6 >>= 10 ; R6_w=scalar(umax=18014398507384832,var_off=(0x0; 0x3fffffffffffff)) 22: (27) r6 *= 8192 ; R6_w=scalar(smax=9223372036854767616,umax=18446744073709543424,var_off=(0x0; 0xffffffffffffe000),s32_max=2147475456,u32_max=-8192) 23: (bf) r1 = r0 ; R0=map_value(off=0,ks=4,vs=48,imm=0) R1_w=map_value(off=0,ks=4,vs=48,imm=0) 24: (0f) r0 += r6 last_idx 24 first_idx 21 regs=40 stack=0 before 23: (bf) r1 = r0 regs=40 stack=0 before 22: (27) r6 *= 8192 regs=40 stack=0 before 21: (77) r6 >>= 10 parent didn't have regs=40 stack=0 marks: R0_rw=map_value(off=0,ks=4,vs=48,imm=0) R6_r=Pscalar(umax=18446744071562067968) R7=0 R8=0 R9=-2147483648 R10=fp0 fp-8=mmmm???? last_idx 19 first_idx 11 regs=40 stack=0 before 19: (55) if r0 != 0x0 goto pc+1 regs=40 stack=0 before 18: (85) call bpf_map_lookup_elem#1 regs=40 stack=0 before 17: (07) r2 += -4 regs=40 stack=0 before 16: (bf) r2 = r10 regs=40 stack=0 before 15: (bf) r1 = r4 regs=40 stack=0 before 13: (18) r4 = 0xffff9290dc5bfe00 regs=40 stack=0 before 12: (63) *(u32 *)(r10 -4) = r0 regs=40 stack=0 before 11: (b7) r0 = 0 parent didn't have regs=40 stack=0 marks: R1=ctx(off=0,imm=0) R6_rw=Pscalar(umax=18446744071562067968) R7_w=0 R8_w=0 R9_w=-2147483648 R10=fp0 last_idx 9 first_idx 0 regs=40 stack=0 before 9: (bd) if r6 <= r9 goto pc+1 regs=240 stack=0 before 6: (bd) if r6 <= r9 goto pc+2 regs=240 stack=0 before 5: (05) goto pc+0 regs=240 stack=0 before 4: (97) r6 %= 1025 regs=240 stack=0 before 3: (b7) r9 = -2147483648 regs=40 stack=0 before 2: (b7) r8 = 0 regs=40 stack=0 before 1: (b7) r7 = 0 regs=40 stack=0 before 0: (b7) r6 = 1024 math between map_value pointer and register with unbounded min value is not allowed verification time 886 usec stack depth 4 processed 49 insns (limit 1000000) max_states_per_insn 1 total_states 5 peak_states 5 mark_read 2 Fixes: b5dc0163 ("bpf: precise scalar_value tracking") Reported-by: Juan Jose Lopez Jaimez <jjlopezjaimez@google.com> Reported-by: Meador Inge <meadori@google.com> Reported-by: Simon Scannell <simonscannell@google.com> Reported-by: Nenad Stojanovski <thenenadx@google.com> Signed-off-by: Daniel Borkmann <daniel@iogearbox.net> Co-developed-by: Andrii Nakryiko <andrii@kernel.org> Signed-off-by: Andrii Nakryiko <andrii@kernel.org> Reviewed-by: John Fastabend <john.fastabend@gmail.com> Reviewed-by: Juan Jose Lopez Jaimez <jjlopezjaimez@google.com> Reviewed-by: Meador Inge <meadori@google.com> Reviewed-by: Simon Scannell <simonscannell@google.com>
Showing
Please register or sign in to comment