Commit 68a8357e authored by Luke Nelson's avatar Luke Nelson Committed by Daniel Borkmann

bpf, x32: Fix bug with ALU64 {LSH, RSH, ARSH} BPF_X shift by 0

The current x32 BPF JIT for shift operations is not correct when the
shift amount in a register is 0. The expected behavior is a no-op, whereas
the current implementation changes bits in the destination register.

The following example demonstrates the bug. The expected result of this
program is 1, but the current JITed code returns 2.

  r0 = 1
  r1 = 1
  r2 = 0
  r1 <<= r2
  if r1 == 1 goto end
  r0 = 2
end:
  exit

The bug is caused by an incorrect assumption by the JIT that a shift by
32 clear the register. On x32 however, shifts use the lower 5 bits of
the source, making a shift by 32 equivalent to a shift by 0.

This patch fixes the bug using double-precision shifts, which also
simplifies the code.

Fixes: 03f5781b ("bpf, x86_32: add eBPF JIT compiler for ia32")
Co-developed-by: default avatarXi Wang <xi.wang@gmail.com>
Signed-off-by: default avatarXi Wang <xi.wang@gmail.com>
Signed-off-by: default avatarLuke Nelson <luke.r.nels@gmail.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parent 0472301a
...@@ -724,9 +724,6 @@ static inline void emit_ia32_lsh_r64(const u8 dst[], const u8 src[], ...@@ -724,9 +724,6 @@ static inline void emit_ia32_lsh_r64(const u8 dst[], const u8 src[],
{ {
u8 *prog = *pprog; u8 *prog = *pprog;
int cnt = 0; int cnt = 0;
static int jmp_label1 = -1;
static int jmp_label2 = -1;
static int jmp_label3 = -1;
u8 dreg_lo = dstk ? IA32_EAX : dst_lo; u8 dreg_lo = dstk ? IA32_EAX : dst_lo;
u8 dreg_hi = dstk ? IA32_EDX : dst_hi; u8 dreg_hi = dstk ? IA32_EDX : dst_hi;
...@@ -745,79 +742,23 @@ static inline void emit_ia32_lsh_r64(const u8 dst[], const u8 src[], ...@@ -745,79 +742,23 @@ static inline void emit_ia32_lsh_r64(const u8 dst[], const u8 src[],
/* mov ecx,src_lo */ /* mov ecx,src_lo */
EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX));
/* cmp ecx,32 */ /* shld dreg_hi,dreg_lo,cl */
EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); EMIT3(0x0F, 0xA5, add_2reg(0xC0, dreg_hi, dreg_lo));
/* Jumps when >= 32 */
if (is_imm8(jmp_label(jmp_label1, 2)))
EMIT2(IA32_JAE, jmp_label(jmp_label1, 2));
else
EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6));
/* < 32 */
/* shl dreg_hi,cl */
EMIT2(0xD3, add_1reg(0xE0, dreg_hi));
/* mov ebx,dreg_lo */
EMIT2(0x8B, add_2reg(0xC0, dreg_lo, IA32_EBX));
/* shl dreg_lo,cl */ /* shl dreg_lo,cl */
EMIT2(0xD3, add_1reg(0xE0, dreg_lo)); EMIT2(0xD3, add_1reg(0xE0, dreg_lo));
/* IA32_ECX = -IA32_ECX + 32 */ /* if ecx >= 32, mov dreg_lo into dreg_hi and clear dreg_lo */
/* neg ecx */
EMIT2(0xF7, add_1reg(0xD8, IA32_ECX));
/* add ecx,32 */
EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32);
/* shr ebx,cl */ /* cmp ecx,32 */
EMIT2(0xD3, add_1reg(0xE8, IA32_EBX)); EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32);
/* or dreg_hi,ebx */ /* skip the next two instructions (4 bytes) when < 32 */
EMIT2(0x09, add_2reg(0xC0, dreg_hi, IA32_EBX)); EMIT2(IA32_JB, 4);
/* goto out; */
if (is_imm8(jmp_label(jmp_label3, 2)))
EMIT2(0xEB, jmp_label(jmp_label3, 2));
else
EMIT1_off32(0xE9, jmp_label(jmp_label3, 5));
/* >= 32 */
if (jmp_label1 == -1)
jmp_label1 = cnt;
/* cmp ecx,64 */
EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64);
/* Jumps when >= 64 */
if (is_imm8(jmp_label(jmp_label2, 2)))
EMIT2(IA32_JAE, jmp_label(jmp_label2, 2));
else
EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6));
/* >= 32 && < 64 */
/* sub ecx,32 */
EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32);
/* shl dreg_lo,cl */
EMIT2(0xD3, add_1reg(0xE0, dreg_lo));
/* mov dreg_hi,dreg_lo */ /* mov dreg_hi,dreg_lo */
EMIT2(0x89, add_2reg(0xC0, dreg_hi, dreg_lo)); EMIT2(0x89, add_2reg(0xC0, dreg_hi, dreg_lo));
/* xor dreg_lo,dreg_lo */ /* xor dreg_lo,dreg_lo */
EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo));
/* goto out; */
if (is_imm8(jmp_label(jmp_label3, 2)))
EMIT2(0xEB, jmp_label(jmp_label3, 2));
else
EMIT1_off32(0xE9, jmp_label(jmp_label3, 5));
/* >= 64 */
if (jmp_label2 == -1)
jmp_label2 = cnt;
/* xor dreg_lo,dreg_lo */
EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo));
/* xor dreg_hi,dreg_hi */
EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi));
if (jmp_label3 == -1)
jmp_label3 = cnt;
if (dstk) { if (dstk) {
/* mov dword ptr [ebp+off],dreg_lo */ /* mov dword ptr [ebp+off],dreg_lo */
EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo), EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo),
...@@ -836,9 +777,6 @@ static inline void emit_ia32_arsh_r64(const u8 dst[], const u8 src[], ...@@ -836,9 +777,6 @@ static inline void emit_ia32_arsh_r64(const u8 dst[], const u8 src[],
{ {
u8 *prog = *pprog; u8 *prog = *pprog;
int cnt = 0; int cnt = 0;
static int jmp_label1 = -1;
static int jmp_label2 = -1;
static int jmp_label3 = -1;
u8 dreg_lo = dstk ? IA32_EAX : dst_lo; u8 dreg_lo = dstk ? IA32_EAX : dst_lo;
u8 dreg_hi = dstk ? IA32_EDX : dst_hi; u8 dreg_hi = dstk ? IA32_EDX : dst_hi;
...@@ -857,79 +795,23 @@ static inline void emit_ia32_arsh_r64(const u8 dst[], const u8 src[], ...@@ -857,79 +795,23 @@ static inline void emit_ia32_arsh_r64(const u8 dst[], const u8 src[],
/* mov ecx,src_lo */ /* mov ecx,src_lo */
EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX));
/* cmp ecx,32 */ /* shrd dreg_lo,dreg_hi,cl */
EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); EMIT3(0x0F, 0xAD, add_2reg(0xC0, dreg_lo, dreg_hi));
/* Jumps when >= 32 */ /* sar dreg_hi,cl */
if (is_imm8(jmp_label(jmp_label1, 2)))
EMIT2(IA32_JAE, jmp_label(jmp_label1, 2));
else
EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6));
/* < 32 */
/* lshr dreg_lo,cl */
EMIT2(0xD3, add_1reg(0xE8, dreg_lo));
/* mov ebx,dreg_hi */
EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX));
/* ashr dreg_hi,cl */
EMIT2(0xD3, add_1reg(0xF8, dreg_hi)); EMIT2(0xD3, add_1reg(0xF8, dreg_hi));
/* IA32_ECX = -IA32_ECX + 32 */ /* if ecx >= 32, mov dreg_hi to dreg_lo and set/clear dreg_hi depending on sign */
/* neg ecx */
EMIT2(0xF7, add_1reg(0xD8, IA32_ECX));
/* add ecx,32 */
EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32);
/* shl ebx,cl */
EMIT2(0xD3, add_1reg(0xE0, IA32_EBX));
/* or dreg_lo,ebx */
EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX));
/* goto out; */
if (is_imm8(jmp_label(jmp_label3, 2)))
EMIT2(0xEB, jmp_label(jmp_label3, 2));
else
EMIT1_off32(0xE9, jmp_label(jmp_label3, 5));
/* >= 32 */ /* cmp ecx,32 */
if (jmp_label1 == -1) EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32);
jmp_label1 = cnt; /* skip the next two instructions (5 bytes) when < 32 */
EMIT2(IA32_JB, 5);
/* cmp ecx,64 */
EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64);
/* Jumps when >= 64 */
if (is_imm8(jmp_label(jmp_label2, 2)))
EMIT2(IA32_JAE, jmp_label(jmp_label2, 2));
else
EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6));
/* >= 32 && < 64 */
/* sub ecx,32 */
EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32);
/* ashr dreg_hi,cl */
EMIT2(0xD3, add_1reg(0xF8, dreg_hi));
/* mov dreg_lo,dreg_hi */ /* mov dreg_lo,dreg_hi */
EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi));
/* sar dreg_hi,31 */
/* ashr dreg_hi,imm8 */
EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31); EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31);
/* goto out; */
if (is_imm8(jmp_label(jmp_label3, 2)))
EMIT2(0xEB, jmp_label(jmp_label3, 2));
else
EMIT1_off32(0xE9, jmp_label(jmp_label3, 5));
/* >= 64 */
if (jmp_label2 == -1)
jmp_label2 = cnt;
/* ashr dreg_hi,imm8 */
EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31);
/* mov dreg_lo,dreg_hi */
EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi));
if (jmp_label3 == -1)
jmp_label3 = cnt;
if (dstk) { if (dstk) {
/* mov dword ptr [ebp+off],dreg_lo */ /* mov dword ptr [ebp+off],dreg_lo */
EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo), EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo),
...@@ -948,9 +830,6 @@ static inline void emit_ia32_rsh_r64(const u8 dst[], const u8 src[], bool dstk, ...@@ -948,9 +830,6 @@ static inline void emit_ia32_rsh_r64(const u8 dst[], const u8 src[], bool dstk,
{ {
u8 *prog = *pprog; u8 *prog = *pprog;
int cnt = 0; int cnt = 0;
static int jmp_label1 = -1;
static int jmp_label2 = -1;
static int jmp_label3 = -1;
u8 dreg_lo = dstk ? IA32_EAX : dst_lo; u8 dreg_lo = dstk ? IA32_EAX : dst_lo;
u8 dreg_hi = dstk ? IA32_EDX : dst_hi; u8 dreg_hi = dstk ? IA32_EDX : dst_hi;
...@@ -969,77 +848,23 @@ static inline void emit_ia32_rsh_r64(const u8 dst[], const u8 src[], bool dstk, ...@@ -969,77 +848,23 @@ static inline void emit_ia32_rsh_r64(const u8 dst[], const u8 src[], bool dstk,
/* mov ecx,src_lo */ /* mov ecx,src_lo */
EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX));
/* cmp ecx,32 */ /* shrd dreg_lo,dreg_hi,cl */
EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); EMIT3(0x0F, 0xAD, add_2reg(0xC0, dreg_lo, dreg_hi));
/* Jumps when >= 32 */
if (is_imm8(jmp_label(jmp_label1, 2)))
EMIT2(IA32_JAE, jmp_label(jmp_label1, 2));
else
EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6));
/* < 32 */
/* lshr dreg_lo,cl */
EMIT2(0xD3, add_1reg(0xE8, dreg_lo));
/* mov ebx,dreg_hi */
EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX));
/* shr dreg_hi,cl */ /* shr dreg_hi,cl */
EMIT2(0xD3, add_1reg(0xE8, dreg_hi)); EMIT2(0xD3, add_1reg(0xE8, dreg_hi));
/* IA32_ECX = -IA32_ECX + 32 */ /* if ecx >= 32, mov dreg_hi to dreg_lo and clear dreg_hi */
/* neg ecx */
EMIT2(0xF7, add_1reg(0xD8, IA32_ECX));
/* add ecx,32 */
EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32);
/* shl ebx,cl */ /* cmp ecx,32 */
EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32);
/* or dreg_lo,ebx */ /* skip the next two instructions (4 bytes) when < 32 */
EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); EMIT2(IA32_JB, 4);
/* goto out; */
if (is_imm8(jmp_label(jmp_label3, 2)))
EMIT2(0xEB, jmp_label(jmp_label3, 2));
else
EMIT1_off32(0xE9, jmp_label(jmp_label3, 5));
/* >= 32 */
if (jmp_label1 == -1)
jmp_label1 = cnt;
/* cmp ecx,64 */
EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64);
/* Jumps when >= 64 */
if (is_imm8(jmp_label(jmp_label2, 2)))
EMIT2(IA32_JAE, jmp_label(jmp_label2, 2));
else
EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6));
/* >= 32 && < 64 */
/* sub ecx,32 */
EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32);
/* shr dreg_hi,cl */
EMIT2(0xD3, add_1reg(0xE8, dreg_hi));
/* mov dreg_lo,dreg_hi */ /* mov dreg_lo,dreg_hi */
EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi));
/* xor dreg_hi,dreg_hi */ /* xor dreg_hi,dreg_hi */
EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi));
/* goto out; */
if (is_imm8(jmp_label(jmp_label3, 2)))
EMIT2(0xEB, jmp_label(jmp_label3, 2));
else
EMIT1_off32(0xE9, jmp_label(jmp_label3, 5));
/* >= 64 */
if (jmp_label2 == -1)
jmp_label2 = cnt;
/* xor dreg_lo,dreg_lo */
EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo));
/* xor dreg_hi,dreg_hi */
EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi));
if (jmp_label3 == -1)
jmp_label3 = cnt;
if (dstk) { if (dstk) {
/* mov dword ptr [ebp+off],dreg_lo */ /* mov dword ptr [ebp+off],dreg_lo */
EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo), EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment