Commit 80679473 authored by 4ast, committed by GitHub

Merge pull request #1097 from vavrusa/master

Updates to Lua/BPF code generator
parents f762df56 67bb1502
......@@ -137,6 +137,7 @@ Below is a list of BPF-specific helpers:
* `comm(var)` - write current process name (uses `bpf_get_current_comm`)
* `perf_submit(map, var)` - submit variable to perf event array BPF map
* `stack_id(map, flags)` - return stack trace identifier from stack trace BPF map
* `load_bytes(off, var)` - helper for direct packet access with `skb_load_bytes()`
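For example, a minimal sketch of `load_bytes` (assuming the usual program scaffolding from this repository; the length is taken from the variable's type):

```lua
local ffi = require('ffi')
local bpf = require('bpf')
-- Copy the first four payload bytes into a stack variable;
-- load_bytes(off, var) maps to skb_load_bytes(skb, off, var, sizeof(var)).
bpf(function (skb)
   local buf = ffi.new('uint8_t [4]') -- destination must live on the BPF stack
   load_bytes(0, buf)
end)
```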
### Current state
......
......@@ -296,53 +296,6 @@ local function bb_end(Vcomp)
end
end
local function LD_ABS(dst, off, w)
local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0
-- assert(w < 8, 'NYI: LD_ABS64 is not supported') -- IMM64 has two IMM32 insns fused together
emit(BPF.LD + BPF.ABS + const_width[w], dst_reg, 0, 0, off)
end
local function LD_IND(dst, src, w, off)
local src_reg = vreg(src) -- Must materialize first in case dst == src
local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0
emit(BPF.LD + BPF.IND + const_width[w], dst_reg, src_reg, 0, off or 0)
end
local function LD_FIELD(a, d, w, imm)
if imm then
LD_ABS(a, imm, w)
else
LD_IND(a, d, w)
end
end
-- @note: This is special-cased for now, as it expects registers to be already reserved
local function LD_IMM_X(dst_reg, src_type, imm, w)
if w == 8 then -- IMM64 must be done in two instructions with imm64 = (lo(imm32), hi(imm32))
emit(BPF.LD + const_width[w], dst_reg, src_type, 0, ffi.cast('uint32_t', imm))
-- Must shift in two steps as bit.rshift supports [0..31]
emit(0, 0, 0, 0, ffi.cast('uint32_t', bit.rshift(bit.rshift(imm, 16), 16)))
else
emit(BPF.LD + const_width[w], dst_reg, src_type, 0, imm)
end
end
local function LOAD(dst, src, off, vtype)
local base = V[src].const
assert(base.__dissector, 'NYI: load() on variable that doesnt have dissector')
-- Cast to different type if requested
vtype = vtype or base.__dissector
local w = ffi.sizeof(vtype)
assert(w <= 4, 'NYI: load() supports 1/2/4 bytes at a time only')
if base.off then -- Absolute address to payload
LD_ABS(dst, off + base.off, w)
else -- Indirect address to payload
LD_IND(dst, src, w, off)
end
V[dst].type = vtype
V[dst].const = nil -- Dissected value is not constant anymore
end
local function CMP_STR(a, b, op)
assert(op == 'JEQ' or op == 'JNE', 'NYI: stack/string comparison supports only == or ~=')
-- I have no better idea how to implement it than an unrolled XOR loop, as we can fix up only one JMP
......@@ -463,7 +416,6 @@ local function ALU_REG(dst, a, b, op)
end
end
local function ALU_IMM_NV(dst, a, b, op)
-- Do DST = IMM(a) op VAR(b) where we can't invert because
-- the registers are u64 but immediates are u32, so complement
......@@ -472,6 +424,71 @@ local function ALU_IMM_NV(dst, a, b, op)
ALU_REG(dst, stackslots+1, b, op)
end
local function LD_ABS(dst, off, w)
if w < 8 then
local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0
emit(BPF.LD + BPF.ABS + const_width[w], dst_reg, 0, 0, off)
elseif w == 8 then
-- LD_ABS|IND prohibits DW, so we do two W loads and combine them
local tmp_reg = vreg(stackslots, 0, true, builtins.width_type(w)) -- Reserve R0
emit(BPF.LD + BPF.ABS + const_width[4], tmp_reg, 0, 0, off + 4)
if ffi.abi('le') then -- LD_ABS has htonl() semantics, reverse
emit(BPF.ALU + BPF.END + BPF.TO_BE, tmp_reg, 0, 0, 32)
end
ALU_IMM(stackslots, stackslots, 32, 'LSH')
local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0, spill tmp variable
emit(BPF.LD + BPF.ABS + const_width[4], dst_reg, 0, 0, off)
if ffi.abi('le') then -- LD_ABS has htonl() semantics, reverse
emit(BPF.ALU + BPF.END + BPF.TO_BE, dst_reg, 0, 0, 32)
end
ALU_REG(dst, dst, stackslots, 'OR')
V[stackslots].reg = nil -- Free temporary registers
else
assert(false, 'NYI: only LD_ABS of 1/2/4/8 is supported')
end
end
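-- On little-endian hosts the 8-byte case above is, in pseudocode:
--   tmp = bswap32(ld_abs_w(off + 4)) -- undo the htonl() semantics of LD_ABS
--   dst = bswap32(ld_abs_w(off))
--   dst = bor(lshift(tmp, 32), dst)  -- fuse the two words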
local function LD_IND(dst, src, w, off)
local src_reg = vreg(src) -- Must materialize first in case dst == src
local dst_reg = vreg(dst, 0, true, builtins.width_type(w)) -- Reserve R0
emit(BPF.LD + BPF.IND + const_width[w], dst_reg, src_reg, 0, off or 0)
end
local function LD_FIELD(a, d, w, imm)
if imm then
LD_ABS(a, imm, w)
else
LD_IND(a, d, w)
end
end
-- @note: This is special-cased for now, as it expects registers to be already reserved
local function LD_IMM_X(dst_reg, src_type, imm, w)
if w == 8 then -- IMM64 must be done in two instructions with imm64 = (lo(imm32), hi(imm32))
emit(BPF.LD + const_width[w], dst_reg, src_type, 0, ffi.cast('uint32_t', imm))
-- Must shift in two steps as bit.rshift supports [0..31]
emit(0, 0, 0, 0, ffi.cast('uint32_t', bit.rshift(bit.rshift(imm, 16), 16)))
else
emit(BPF.LD + const_width[w], dst_reg, src_type, 0, imm)
end
end
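-- Note: BPF_LD + BPF.DW (LDDW) occupies two instruction slots; the first
-- imm carries the low 32 bits and the second imm the high 32 bits.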
local function LOAD(dst, src, off, vtype)
local base = V[src].const
assert(base.__dissector, 'NYI: load() on variable that doesnt have dissector')
-- Cast to different type if requested
vtype = vtype or base.__dissector
local w = ffi.sizeof(vtype)
assert(w <= 4, 'NYI: load() supports 1/2/4 bytes at a time only')
if base.off then -- Absolute address to payload
LD_ABS(dst, off + base.off, w)
else -- Indirect address to payload
LD_IND(dst, src, w, off)
end
V[dst].type = vtype
V[dst].const = nil -- Dissected value is not constant anymore
end
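-- Note: when the dissector base offset is known at compile time (base.off),
-- the load compiles to a single LD_ABS; otherwise it falls back to LD_IND
-- relative to the materialized base register.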
local function BUILTIN(func, ...)
local builtin_export = {
-- Compiler primitives (work with variable slots, emit instructions)
......@@ -598,6 +615,7 @@ local function MAP_SET(map_var, key, key_imm, src)
reg_alloc(stackslots, 4) -- Spill anything in R4 (unnamed tmp variable)
emit(BPF.ALU64 + BPF.MOV + BPF.K, 4, 0, 0, 0) -- BPF_ANY, create new element or update existing
-- Reserve R3 for value pointer
reg_alloc(stackslots, 3) -- Spill anything in R3 (unnamed tmp variable)
local val_size = ffi.sizeof(map.val_type)
local w = const_width[val_size] or BPF.DW
local pod_type = const_width[val_size]
......@@ -610,14 +628,25 @@ local function MAP_SET(map_var, key, key_imm, src)
emit(BPF.MEM + BPF.ST + w, 10, 0, -sp, base)
-- Value is in register, spill it
elseif V[src].reg and pod_type then
-- Value is a pointer, dereference it and spill it
if cdef.isptr(V[src].type) then
vderef(3, V[src].reg, V[src].const.__dissector)
emit(BPF.MEM + BPF.STX + w, 10, 3, -sp, 0)
else
emit(BPF.MEM + BPF.STX + w, 10, V[src].reg, -sp, 0)
end
-- We get a pointer to a spilled register on the stack
elseif V[src].spill then
-- If variable is a pointer, we can load it to R3 directly (save "LEA")
if cdef.isptr(V[src].type) then
reg_fill(src, 3)
-- If variable is a stack pointer, we don't have to check it
if base.__base then
emit(BPF.JMP + BPF.CALL, 0, 0, 0, HELPER.map_update_elem)
return
end
vderef(3, V[src].reg, V[src].const.__dissector)
emit(BPF.MEM + BPF.STX + w, 10, 3, -sp, 0)
else
sp = V[src].spill
end
......@@ -627,9 +656,8 @@ local function MAP_SET(map_var, key, key_imm, src)
sp = base.__base
-- Value is constant, materialize it on stack
else
error('VAR '..key..' is neither const-expr/register/stack/spilled')
error('VAR '..(key or key_imm)..' is neither const-expr/register/stack/spilled')
end
reg_alloc(stackslots, 3) -- Spill anything in R3 (unnamed tmp variable)
emit(BPF.ALU64 + BPF.MOV + BPF.X, 3, 10, 0, 0)
emit(BPF.ALU64 + BPF.ADD + BPF.K, 3, 0, 0, -sp)
emit(BPF.JMP + BPF.CALL, 0, 0, 0, HELPER.map_update_elem)
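-- Register use above follows the bpf_map_update_elem() convention: R3 points
-- at the value spilled on the stack and R4 carries the flags (0 = BPF_ANY);
-- R1 (map) and R2 (key) are prepared earlier in this function.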
......@@ -714,6 +742,12 @@ local BC = {
local base = V[b].const
if base.__map then -- BPF map read (constant)
MAP_GET(a, b, nil, d)
-- Specialise PTR[0] as dereference operator
elseif cdef.isptr(V[b].type) and d == 0 then
vcopy(a, b)
local dst_reg = vreg(a)
vderef(dst_reg, dst_reg, V[a].const.__dissector)
V[a].type = V[a].const.__dissector
else
LOAD(a, b, d, ffi.typeof('uint8_t'))
end
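-- i.e. source-level `x = ptr[0]` becomes a pointer dereference instead of
-- a one-byte LOAD at offset zero.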
......@@ -845,7 +879,7 @@ local BC = {
CALL = function (a, b, _, d) -- A = A(A+1, ..., A+D-1)
CALL(a, b, d)
end,
JMP = function (a, _, c, d) -- JMP
JMP = function (a, _, c, _) -- JMP
-- Discard unused slots after jump
for i, _ in pairs(V) do
if i >= a then V[i] = {} end
......@@ -895,6 +929,10 @@ local BC = {
end,
RET1 = function (a, _, _, _) -- RET1
if V[a].reg ~= 0 then vreg(a, 0) end
-- Dereference pointer variables
if cdef.isptr(V[a].type) then
vderef(0, 0, V[a].const.__dissector)
end
emit(BPF.JMP + BPF.EXIT, 0, 0, 0, 0)
-- Free optimisation: spilled variable will not be filled again
for _,v in pairs(V) do if v.reg == 0 then v.reg = nil end end
......@@ -920,7 +958,7 @@ end
vset(stackslots)
vset(stackslots+1)
return setmetatable(BC, {
__index = function (t, k, v)
__index = function (_, k, _)
if type(k) == 'number' then
local op_str = string.sub(require('jit.vmdef').bcnames, 6*k+1, 6*k+6)
error(string.format("NYI: opcode '0x%02x' (%-04s)", k, op_str))
......@@ -954,7 +992,9 @@ return setmetatable(BC, {
end
-- Emitted code dump
local function dump_mem(cls, ins)
local function dump_mem(cls, ins, _, fuse)
-- This is a very dense MEM instruction decoder without much explanation
-- Refer to https://www.kernel.org/doc/Documentation/networking/filter.txt for instruction format
local mode = bit.band(ins.code, 0xe0)
if mode == BPF.XADD then cls = 5 end -- The only mode
local op_1 = {'LD', 'LDX', 'ST', 'STX', '', 'XADD'}
......@@ -966,7 +1006,7 @@ local function dump_mem(cls, ins)
if cls == BPF.LDX then src = string.format('[R%d%+d]', ins.src_reg, off) end
if mode == BPF.ABS then src = string.format('[%d]', ins.imm) end
if mode == BPF.IND then src = string.format('[R%d%+d]', ins.src_reg, ins.imm) end
return string.format('%s\t%s\t%s', name, dst, src)
return string.format('%s\t%s\t%s', fuse and '' or name, fuse and '' or dst, src)
end
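-- When `fuse` is truthy (the previous instruction was an LDDW), the mnemonic
-- and destination are blanked so the second slot of the IMM64 pair prints as
-- a continuation of the first.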
local function dump_alu(cls, ins, pc)
......@@ -979,6 +1019,8 @@ local function dump_alu(cls, ins, pc)
'skb_get_tunnel_key', 'skb_set_tunnel_key', 'perf_event_read', 'redirect', 'get_route_realm',
'perf_event_output', 'skb_load_bytes'}
local op = 0
-- This is a very dense ALU instruction decoder without much explanation
-- Refer to https://www.kernel.org/doc/Documentation/networking/filter.txt for instruction format
for i = 0,13 do if 0x10 * i == bit.band(ins.code, 0xf0) then op = i + 1 break end end
local name = (cls == 5) and jmp[op] or alu[op]
local src = (bit.band(ins.code, 0x08) == BPF.X) and 'R'..ins.src_reg or '#'..ins.imm
......@@ -994,16 +1036,19 @@ local function dump(code)
[0] = dump_mem, [1] = dump_mem, [2] = dump_mem, [3] = dump_mem,
[4] = dump_alu, [5] = dump_alu, [7] = dump_alu,
}
local fused = false
for i = 0, code.pc - 1 do
local ins = code.insn[i]
local cls = bit.band(ins.code, 0x07)
print(string.format('%04u\t%s', i, cls_map[cls](cls, ins, i)))
local line = cls_map[cls](cls, ins, i, fused)
print(string.format('%04u\t%s', i, line))
fused = string.find(line, 'LDDW', 1)
end
end
local function compile(prog, params)
-- Create code emitter sandbox, include caller locals
local env = { pkt=proto.pkt, BPF=BPF }
local env = { pkt=proto.pkt, BPF=BPF, ffi=ffi }
-- Include upvalues up to 4 nested scopes back
-- the narrower scope overrides the broader scope
for k = 5, 2, -1 do
......@@ -1064,7 +1109,7 @@ local bpf_map_mt = {
cur_key = ffi.new(ffi.typeof(t.key))
end
local ok, err = S.bpf_map_op(S.c.BPF_CMD.MAP_GET_NEXT_KEY, map.fd, cur_key, next_key)
if not ok then return nil end
if not ok then return nil, err end
-- Get next value
assert(S.bpf_map_op(S.c.BPF_CMD.MAP_LOOKUP_ELEM, map.fd, next_key, map.val))
return next_key[0], map.val[0]
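-- Returning the error as a second value lets direct callers distinguish
-- end-of-iteration from a failed MAP_GET_NEXT_KEY; generic-for iteration
-- still stops on the leading nil.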
......@@ -1278,5 +1323,5 @@ return setmetatable({
end,
ntoh = builtins.ntoh, hton = builtins.hton,
}, {
__call = function (t, prog) return compile(prog) end,
__call = function (_, prog) return compile(prog) end,
})
......@@ -80,7 +80,7 @@ if ffi.abi('be') then
builtins[hton] = function(a, b, w) return end
end
-- Other built-ins
local function xadd(a, b) error('NYI') end
local function xadd() error('NYI') end
builtins.xadd = xadd
builtins[xadd] = function (e, dst, a, b, off)
assert(e.V[a].const.__dissector, 'xadd(a, b) called on non-pointer')
......@@ -120,7 +120,7 @@ builtins[probe_read] = function (e, ret, dst, src, vtype, ofs)
e.reg_alloc(e.tmpvar, 3) -- Copy from original register
e.emit(BPF.ALU64 + BPF.MOV + BPF.X, 3, e.V[src].reg, 0, 0)
else
local src_reg = e.vreg(src, 3)
e.vreg(src, 3)
e.reg_spill(src) -- Spill to avoid overwriting
end
if ofs and ofs > 0 then
......@@ -133,6 +133,7 @@ builtins[probe_read] = function (e, ret, dst, src, vtype, ofs)
e.emit(BPF.JMP + BPF.CALL, 0, 0, 0, HELPER.probe_read)
e.V[e.tmpvar].reg = nil -- Free temporary registers
end
builtins[ffi.cast] = function (e, dst, ct, x)
assert(e.V[ct].const, 'ffi.cast(ctype, x) called with bad ctype')
e.vcopy(dst, x)
......@@ -151,6 +152,7 @@ builtins[ffi.cast] = function (e, dst, ct, x)
e.V[dst].source = 'probe'
end
end
builtins[ffi.new] = function (e, dst, ct, x)
if type(ct) == 'number' then
ct = ffi.typeof(e.V[ct].const) -- Get ctype from variable
......@@ -160,7 +162,8 @@ builtins[ffi.new] = function (e, dst, ct, x)
e.vset(dst, nil, ct)
e.V[dst].const = {__base = e.valloc(ffi.sizeof(ct), true), __dissector = ct}
end
builtins[ffi.copy] = function (e,ret, dst, src)
builtins[ffi.copy] = function (e, ret, dst, src)
assert(cdef.isptr(e.V[dst].type), 'ffi.copy(dst, src) - dst MUST be a pointer type')
assert(cdef.isptr(e.V[src].type), 'ffi.copy(dst, src) - src MUST be a pointer type')
-- Specific types also encode source of the data
......@@ -186,7 +189,7 @@ builtins[ffi.copy] = function (e,ret, dst, src)
e.reg_alloc(e.tmpvar, 3) -- Copy from original register
e.emit(BPF.ALU64 + BPF.MOV + BPF.X, 3, e.V[src].reg, 0, 0)
else
local src_reg = e.vreg(src, 3)
e.vreg(src, 3)
e.reg_spill(src) -- Spill to avoid overwriting
end
-- Call probe read helper
......@@ -265,6 +268,28 @@ local function perf_submit(e, dst, map_var, src)
e.V[e.tmpvar].reg = nil -- Free temporary registers
end
-- Implements bpf_skb_load_bytes(ctx, off, var, vlen) on skb->data
local function load_bytes(e, dst, off, var)
-- Set R2 = offset
e.vset(e.tmpvar, nil, off)
e.vreg(e.tmpvar, 2, false, ffi.typeof('uint64_t'))
-- Set R1 = ctx
e.reg_alloc(e.tmpvar, 1) -- Spill anything in R1 (unnamed tmp variable)
e.emit(BPF.ALU64 + BPF.MOV + BPF.X, 1, 6, 0, 0) -- CTX is always in R6, copy
-- Set R3 = pointer to var on stack
assert(e.V[var].const.__base, 'NYI: load_bytes(off, var, len) - variable is not on stack')
e.emit(BPF.ALU64 + BPF.MOV + BPF.X, 3, 10, 0, 0)
e.emit(BPF.ALU64 + BPF.ADD + BPF.K, 3, 0, 0, -e.V[var].const.__base)
-- Set R4 = var length
e.emit(BPF.ALU64 + BPF.MOV + BPF.K, 4, 0, 0, ffi.sizeof(e.V[var].type))
-- Set R0 = ret and call
e.vset(dst)
e.vreg(dst, 0, true, ffi.typeof('int32_t')) -- Return is integer
e.emit(BPF.JMP + BPF.CALL, 0, 0, 0, HELPER.skb_load_bytes)
e.V[e.tmpvar].reg = nil -- Free temporary registers
end
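-- The register setup above follows the BPF helper calling convention:
-- R1 = ctx (skb), R2 = offset, R3 = destination buffer, R4 = length,
-- with the helper's return value delivered in R0.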
-- Implements bpf_get_stack_id()
local function stack_id(e, ret, map_var, key)
-- Set R2 = map fd (indirect load)
......@@ -313,7 +338,7 @@ builtins[comm] = function (e, ret, dst)
end
-- Math library built-ins
math.log2 = function (x) error('NYI') end
math.log2 = function () error('NYI') end
builtins[math.log2] = function (e, dst, x)
-- Classic integer bits subdivision algorithm to find the position
-- of the highest bit set, adapted for BPF bytecode-friendly operations.
......@@ -363,7 +388,7 @@ end
-- Call-type helpers
local function call_helper(e, dst, h)
e.vset(dst)
local dst_reg = e.vreg(dst, 0, true)
e.vreg(dst, 0, true)
e.emit(BPF.JMP + BPF.CALL, 0, 0, 0, h)
e.V[dst].const = nil -- Target is not a function anymore
end
......@@ -381,6 +406,7 @@ builtins.uid_gid = uid_gid
builtins.comm = comm
builtins.perf_submit = perf_submit
builtins.stack_id = stack_id
builtins.load_bytes = load_bytes
builtins[cpu] = function (e, dst) return call_helper(e, dst, HELPER.get_smp_processor_id) end
builtins[rand] = function (e, dst) return call_helper(e, dst, HELPER.get_prandom_u32) end
builtins[time] = function (e, dst) return call_helper(e, dst, HELPER.ktime_get_ns) end
......@@ -388,5 +414,6 @@ builtins[pid_tgid] = function (e, dst) return call_helper(e, dst, HELPER.get_cur
builtins[uid_gid] = function (e, dst) return call_helper(e, dst, HELPER.get_current_uid_gid) end
builtins[perf_submit] = function (e, dst, map, value) return perf_submit(e, dst, map, value) end
builtins[stack_id] = function (e, dst, map, key) return stack_id(e, dst, map, key) end
builtins[load_bytes] = function (e, dst, off, var, len) return load_bytes(e, dst, off, var, len) end
return builtins
......@@ -78,6 +78,9 @@ struct bpf {
static const int F_USER_STACK = 1 << 8;
static const int F_FAST_STACK_CMP = 1 << 9;
static const int F_REUSE_STACKID = 1 << 10;
/* special offsets for ancillary data */
static const int NET_OFF = -0x100000;
static const int LL_OFF = -0x200000;
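/* These match the kernel's SKF_NET_OFF and SKF_LL_OFF ancillary offsets:
 * an LD_ABS/LD_IND at these negative offsets reads relative to the network
 * or link-layer header instead of the start of the packet. */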
};
/* eBPF commands */
struct bpf_cmd {
......
......@@ -35,6 +35,10 @@ struct sk_buff {
uint32_t tc_classid;
};
struct net_off_t {
uint8_t ver:4;
} __attribute__((packed));
struct eth_t {
uint8_t dst[6];
uint8_t src[6];
......@@ -275,12 +279,23 @@ M.udp = function (...) return dissector(ffi.typeof('struct udp_t'), ...) end
M.tcp = function (...) return dissector(ffi.typeof('struct tcp_t'), ...) end
M.vxlan = function (...) return dissector(ffi.typeof('struct vxlan_t'), ...) end
M.data = function (...) return dissector(ffi.typeof('uint8_t'), ...) end
M.net_off = function (...) return dissector(ffi.typeof('struct net_off_t'), ...) end
-- Metatables
ffi.metatype(ffi.typeof('struct eth_t'), {
__index = {
ip = skip_eth,
ip6 = skip_eth,
net_off = function (e, dst)
next_skip(e, dst, BPF.NET_OFF)
end,
}
})
ffi.metatype(ffi.typeof('struct net_off_t'), {
__index = {
ip = function () end,
ip6 = function () end,
}
})
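-- Sketch (using the dissectors defined above): the ancillary NET_OFF read
-- exposes the IP version nibble regardless of the link layer, e.g.
--   if pkt.net_off.ver == 4 then --[[ IPv4 path ]] end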
......