Commit a4803040 authored by 4ast's avatar 4ast

Merge pull request #457 from vmg/vmg/lua

Lua Tools for BCC
parents 39cc0ba6 ab324814
#include <uapi/linux/ptrace.h>
struct str_t {
u64 pid;
char str[80];
};
BPF_PERF_OUTPUT(events);
int printret(struct pt_regs *ctx)
{
struct str_t data = {};
u32 pid;
if (!ctx->ax)
return 0;
pid = bpf_get_current_pid_tgid();
data.pid = pid;
bpf_probe_read(&data.str, sizeof(data.str), (void *)ctx->ax);
events.perf_submit(ctx,&data,sizeof(data));
return 0;
};
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ffi = require("ffi")
return function(BPF)
local b = BPF:new{src_file="bashreadline.c", debug=0}
b:attach_uprobe{name="/bin/bash", sym="readline", fn_name="printret", retprobe=true}
local function print_readline(cpu, event)
print("%-9s %-6d %s" % {os.date("%H:%M:%S"), tonumber(event.pid), ffi.string(event.str)})
end
b:get_table("events"):open_perf_buffer(print_readline, "struct { uint64_t pid; char str[80]; }")
print("%-9s %-6s %s" % {"TIME", "PID", "COMMAND"})
b:kprobe_poll_loop()
end
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local bpf_source = [[
#include <uapi/linux/ptrace.h>
struct alloc_info_t {
u64 size;
u64 timestamp_ns;
int stack_id;
};
BPF_HASH(sizes, u64);
BPF_HASH(allocs, u64, struct alloc_info_t);
BPF_STACK_TRACE(stack_traces, 10240)
int alloc_enter(struct pt_regs *ctx, size_t size)
{
SIZE_FILTER
if (SAMPLE_EVERY_N > 1) {
u64 ts = bpf_ktime_get_ns();
if (ts % SAMPLE_EVERY_N != 0)
return 0;
}
u64 pid = bpf_get_current_pid_tgid();
u64 size64 = size;
sizes.update(&pid, &size64);
if (SHOULD_PRINT)
bpf_trace_printk("alloc entered, size = %u\n", size);
return 0;
}
int alloc_exit(struct pt_regs *ctx)
{
u64 address = ctx->ax;
u64 pid = bpf_get_current_pid_tgid();
u64* size64 = sizes.lookup(&pid);
struct alloc_info_t info = {0};
if (size64 == 0)
return 0; // missed alloc entry
info.size = *size64;
sizes.delete(&pid);
info.timestamp_ns = bpf_ktime_get_ns();
info.stack_id = stack_traces.get_stackid(ctx, STACK_FLAGS);
allocs.update(&address, &info);
if (SHOULD_PRINT) {
bpf_trace_printk("alloc exited, size = %lu, result = %lx\n",
info.size, address);
}
return 0;
}
int free_enter(struct pt_regs *ctx, void *address)
{
u64 addr = (u64)address;
struct alloc_info_t *info = allocs.lookup(&addr);
if (info == 0)
return 0;
allocs.delete(&addr);
if (SHOULD_PRINT) {
bpf_trace_printk("free entered, address = %lx, size = %lu\n",
address, info->size);
}
return 0;
}
]]
return function(BPF, utils)
local parser = utils.argparse("memleak", "Catch memory leaks")
parser:flag("-t --trace")
parser:flag("-a --show-allocs")
parser:option("-p --pid"):convert(tonumber)
parser:option("-i --interval", "", 5):convert(tonumber)
parser:option("-o --older", "", 500):convert(tonumber)
parser:option("-s --sample-rate", "", 1):convert(tonumber)
parser:option("-z --min-size", ""):convert(tonumber)
parser:option("-Z --max-size", ""):convert(tonumber)
parser:option("-T --top", "", 10):convert(tonumber)
local args = parser:parse()
local size_filter = ""
if args.min_size and args.max_size then
size_filter = "if (size < %d || size > %d) return 0;" % {args.min_size, args.max_size}
elseif args.min_size then
size_filter = "if (size < %d) return 0;" % args.min_size
elseif args.max_size then
size_filter = "if (size > %d) return 0;" % args.max_size
end
local stack_flags = "BPF_F_REUSE_STACKID"
if args.pid then
stack_flags = stack_flags .. "|BPF_F_USER_STACK"
end
local text = bpf_source
text = text:gsub("SIZE_FILTER", size_filter)
text = text:gsub("STACK_FLAGS", stack_flags)
text = text:gsub("SHOULD_PRINT", args.trace and "1" or "0")
text = text:gsub("SAMPLE_EVERY_N", tostring(args.sample_rate))
local bpf = BPF:new{text=text, debug=0}
local syms = nil
local min_age_ns = args.older * 1e6
if args.pid then
print("Attaching to malloc and free in pid %d, Ctrl+C to quit." % args.pid)
bpf:attach_uprobe{name="c", sym="malloc", fn_name="alloc_enter", pid=args.pid}
bpf:attach_uprobe{name="c", sym="malloc", fn_name="alloc_exit", pid=args.pid, retprobe=true}
bpf:attach_uprobe{name="c", sym="free", fn_name="free_enter", pid=args.pid}
else
print("Attaching to kmalloc and kfree, Ctrl+C to quit.")
bpf:attach_kprobe{event="__kmalloc", fn_name="alloc_enter"}
bpf:attach_kprobe{event="__kmalloc", fn_name="alloc_exit", retprobe=true} -- TODO
bpf:attach_kprobe{event="kfree", fn_name="free_enter"}
end
local syms = args.pid and utils.sym.ProcSymbols:new(args.pid) or utils.sym.KSymbols:new()
local allocs = bpf:get_table("allocs")
local stack_traces = bpf:get_table("stack_traces")
local function resolve(addr)
local sym = syms:lookup(addr)
if args.pid == nil then
sym = sym .. " [kernel]"
end
return string.format("%s (%x)", sym, addr)
end
local function print_outstanding()
local alloc_info = {}
local now = utils.posix.time_ns()
print("[%s] Top %d stacks with outstanding allocations:" %
{os.date("%H:%M:%S"), args.top})
for address, info in allocs:items() do
if now - min_age_ns >= tonumber(info.timestamp_ns) then
local stack_id = tonumber(info.stack_id)
if stack_id >= 0 then
if alloc_info[stack_id] then
local s = alloc_info[stack_id]
s.count = s.count + 1
s.size = s.size + tonumber(info.size)
else
local stack = stack_traces:get(stack_id, resolve)
alloc_info[stack_id] = { stack=stack, count=1, size=tonumber(info.size) }
end
end
if args.show_allocs then
print("\taddr = %x size = %s" % {tonumber(address), tonumber(info.size)})
end
end
end
local top = table.values(alloc_info)
table.sort(top, function(a, b) return a.size > b.size end)
for n, alloc in ipairs(top) do
print("\t%d bytes in %d allocations from stack\n\t\t%s" %
{alloc.size, alloc.count, table.concat(alloc.stack, "\n\t\t")})
if n == args.top then break end
end
end
if args.trace then
local pipe = bpf:pipe()
while true do
print(pipe:trace_fields())
end
else
while true do
utils.posix.sleep(args.interval)
syms:refresh()
print_outstanding()
end
end
end
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local program = [[
#include <uapi/linux/ptrace.h>
#include <linux/sched.h>
#define MINBLOCK_US 1
struct key_t {
char name[TASK_COMM_LEN];
int stack_id;
};
BPF_HASH(counts, struct key_t);
BPF_HASH(start, u32);
BPF_STACK_TRACE(stack_traces, 10240)
int oncpu(struct pt_regs *ctx, struct task_struct *prev) {
u32 pid;
u64 ts, *tsp;
// record previous thread sleep time
if (FILTER) {
pid = prev->pid;
ts = bpf_ktime_get_ns();
start.update(&pid, &ts);
}
// calculate current thread's delta time
pid = bpf_get_current_pid_tgid();
tsp = start.lookup(&pid);
if (tsp == 0)
return 0; // missed start or filtered
u64 delta = bpf_ktime_get_ns() - *tsp;
start.delete(&pid);
delta = delta / 1000;
if (delta < MINBLOCK_US)
return 0;
// create map key
u64 zero = 0, *val;
struct key_t key = {};
int stack_flags = BPF_F_REUSE_STACKID;
/*
if (!(prev->flags & PF_KTHREAD))
stack_flags |= BPF_F_USER_STACK;
*/
bpf_get_current_comm(&key.name, sizeof(key.name));
key.stack_id = stack_traces.get_stackid(ctx, stack_flags);
val = counts.lookup_or_init(&key, &zero);
(*val) += delta;
return 0;
}
]]
return function(BPF, utils)
local ffi = require("ffi")
local parser = utils.argparse("offcputime", "Summarize off-cpu time")
parser:flag("-u --user-only")
parser:option("-p --pid"):convert(tonumber)
parser:flag("-f --folded")
parser:option("-d --duration", "duration to trace for", 9999999):convert(tonumber)
local args = parser:parse()
local ksym = utils.sym.KSymbols:new()
local filter = "1"
local MAXDEPTH = 20
if args.pid then
filter = "pid == %d" % args.pid
elseif args.user_only then
filter = "!(prev->flags & PF_KTHREAD)"
end
local text = program:gsub("FILTER", filter)
local b = BPF:new{text=text}
b:attach_kprobe{event="finish_task_switch", fn_name="oncpu"}
if BPF.num_open_kprobes() == 0 then
print("no functions matched. quitting...")
return
end
print("Sleeping for %d seconds..." % args.duration)
pcall(utils.posix.sleep, args.duration)
print("Tracing...")
local counts = b:get_table("counts")
local stack_traces = b:get_table("stack_traces")
for k, v in counts:items() do
for addr in stack_traces:walk(tonumber(k.stack_id)) do
print(" %-16x %s" % {addr, ksym:lookup(addr)})
end
print(" %-16s %s" % {"-", ffi.string(k.name)})
print(" %d\n" % tonumber(v))
end
end
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
assert(arg[1], "usage: strlen_count PID")
local program = string.gsub([[
#include <uapi/linux/ptrace.h>
int printarg(struct pt_regs *ctx) {
if (!ctx->di)
return 0;
u32 pid = bpf_get_current_pid_tgid();
if (pid != PID)
return 0;
char str[128] = {};
bpf_probe_read(&str, sizeof(str), (void *)ctx->di);
bpf_trace_printk("strlen(\"%s\")\n", &str);
return 0;
};
]], "PID", arg[1])
return function(BPF)
local b = BPF:new{text=program, debug=0}
b:attach_uprobe{name="c", sym="strlen", fn_name="printarg"}
local pipe = b:pipe()
while true do
local task, pid, cpu, flags, ts, msg = pipe:trace_fields()
print("%-18.9f %-16s %-6d %s" % {ts, task, pid, msg})
end
end
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local program = [[
#include <uapi/linux/ptrace.h>
#include <linux/sched.h>
struct key_t {
u32 prev_pid;
u32 curr_pid;
};
// map_type, key_type, leaf_type, table_name, num_entry
BPF_TABLE("hash", struct key_t, u64, stats, 1024);
int count_sched(struct pt_regs *ctx, struct task_struct *prev) {
struct key_t key = {};
u64 zero = 0, *val;
key.curr_pid = bpf_get_current_pid_tgid();
key.prev_pid = prev->pid;
val = stats.lookup_or_init(&key, &zero);
(*val)++;
return 0;
}
]]
return function(BPF)
local b = BPF:new{text=program, debug=0}
b:attach_kprobe{event="finish_task_switch", fn_name="count_sched"}
print("Press any key...")
io.read()
local t = b:get_table("stats")
for k, v in t:items() do
print("task_switch[%d -> %d] = %d" % {k.prev_pid, k.curr_pid, tonumber(v)})
end
end
Lua Tools for BCC
-----------------
This directory contains Lua tooling for [BCC](https://github.com/iovisor/bcc)
(the BPF Compiler Collection).
BCC is a toolkit for creating userspace and kernel tracing programs. By
default, it comes with a library `libbcc`, some example tooling and a Python
frontend for the library.
Here we present an alternate frontend for `libbcc` implemented in LuaJIT. This
lets you write the userspace part of your tracer in Lua instead of Python.
Since LuaJIT is a JIT compiled language, tracers implemented in `bcc-lua`
exhibit significantly reduced overhead compared to their Python equivalents.
This is particularly noticeable in tracers that actively use the table APIs to
get information from the kernel.
If your tracer makes extensive use of `BPF_MAP_TYPE_PERF_EVENT_ARRAY` or
`BPF_MAP_TYPE_HASH`, you may find the performance characteristics of this
implementation very appealing, as LuaJIT can compile to native code a lot of
the callchain to process the events, and this wrapper has been designed to
benefit from such JIT compilation.
## Quickstart Guide
The following instructions assume Ubuntu 14.04 LTS.
1. Install a **very new kernel**. It has to be new and shiny for this to work. 4.3+
```
VER=4.4.2-040402
PREFIX=http://kernel.ubuntu.com/~kernel-ppa/mainline/v4.4.2-wily/
REL=201602171633
wget ${PREFIX}/linux-headers-${VER}-generic_${VER}.${REL}_amd64.deb
wget ${PREFIX}/linux-headers-${VER}_${VER}.${REL}_all.deb
wget ${PREFIX}/linux-image-${VER}-generic_${VER}.${REL}_amd64.deb
sudo dpkg -i linux-*${VER}.${REL}*.deb
```
2. Install the `libbcc` binary packages and `luajit`
```
sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys D4284CDD
echo "deb http://52.8.15.63/apt trusty main" | sudo tee /etc/apt/sources.list.d/iovisor.list
sudo apt-get update
sudo apt-get install libbcc luajit
```
3. Test one of the examples to ensure `libbcc` is properly installed
```
sudo ./bcc-probe examples/lua/task_switch.lua
```
#!/usr/bin/env luajit
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
function setup_path()
local str = require("debug").getinfo(2, "S").source:sub(2)
local script_path = str:match("(.*/)").."/?.lua;"
package.path = script_path..package.path
end
setup_path()
local BCC = require("bcc.init")
local BPF = BCC.BPF
BPF.script_root(arg[1])
local utils = {
argparse = require("bcc.vendor.argparse"),
posix = require("bcc.vendor.posix"),
sym = BCC.sym
}
local tracefile = table.remove(arg, 1)
local command = dofile(tracefile)
local res, err = pcall(command, BPF, utils)
if not res then
io.stderr:write("[ERROR] "..err.."\n")
end
BPF.cleanup_probes()
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ffi = require("ffi")
local libbcc = require("bcc.libbcc")
local TracerPipe = require("bcc.tracerpipe")
local Table = require("bcc.table")
local LD = require("bcc.ld")
local Bpf = class("BPF")
Bpf.static.open_kprobes = {}
Bpf.static.open_uprobes = {}
Bpf.static.process_symbols = {}
Bpf.static.KPROBE_LIMIT = 1000
Bpf.static.tracer_pipe = nil
Bpf.static.DEFAULT_CFLAGS = {
'-D__HAVE_BUILTIN_BSWAP16__',
'-D__HAVE_BUILTIN_BSWAP32__',
'-D__HAVE_BUILTIN_BSWAP64__',
}
function Bpf.static.check_probe_quota(n)
local cur = table.count(Bpf.static.open_kprobes) + table.count(Bpf.static.open_uprobes)
assert(cur + n <= Bpf.static.KPROBE_LIMIT, "number of open probes would exceed quota")
end
function Bpf.static.cleanup_probes()
local function detach_all(probe_type, all_probes)
for key, probe in pairs(all_probes) do
libbcc.perf_reader_free(probe)
if type(key) == "string" then
local desc = string.format("-:%s/%s", probe_type, key)
log.info("detaching %s", desc)
if probe_type == "kprobes" then
libbcc.bpf_detach_kprobe(desc)
elseif probe_type == "uprobes" then
libbcc.bpf_detach_uprobe(desc)
end
end
all_probes[key] = nil
end
end
detach_all("kprobes", Bpf.static.open_kprobes)
detach_all("uprobes", Bpf.static.open_uprobes)
if Bpf.static.tracer_pipe ~= nil then
Bpf.static.tracer_pipe:close()
end
end
function Bpf.static.num_open_uprobes()
return table.count(Bpf.static.open_uprobes)
end
function Bpf.static.num_open_kprobes()
return table.count(Bpf.static.open_kprobes)
end
function Bpf.static.usymaddr(pid, addr, refresh)
local proc_sym = Bpf.static.process_symbols[pid]
if proc_sym == nil then
proc_sym = ProcSymbols(pid)
Bpf.static.process_symbols[pid] = proc_sym
elseif refresh then
proc_sym.refresh()
end
return proc_sym.decode_addr(addr)
end
Bpf.static.SCRIPT_ROOT = "./"
function Bpf.static.script_root(root)
local dir, file = root:match'(.*/)(.*)'
Bpf.static.SCRIPT_ROOT = dir or "./"
return Bpf
end
local function _find_file(script_root, filename)
if filename == nil then
return nil
end
if os.exists(filename) then
return filename
end
if not filename:starts("/") then
filename = script_root .. filename
if os.exists(filename) then
return filename
end
end
assert(nil, "failed to find file "..filename.." (root="..script_root..")")
end
function Bpf:initialize(args)
self.do_debug = args.debug or false
self.funcs = {}
self.tables = {}
local cflags = table.join(Bpf.DEFAULT_CFLAGS, args.cflags)
local cflags_ary = ffi.new("const char *[?]", #cflags, cflags)
local llvm_debug = args.debug or 0
assert(type(llvm_debug) == "number")
if args.text then
log.info("\n%s\n", args.text)
self.module = libbcc.bpf_module_create_c_from_string(args.text, llvm_debug, cflags_ary, #cflags)
elseif args.src_file then
local src = _find_file(Bpf.SCRIPT_ROOT, args.src_file)
if src:ends(".b") then
local hdr = _find_file(Bpf.SCRIPT_ROOT, args.hdr_file)
self.module = libbcc.bpf_module_create_b(src, hdr, llvm_debug)
else
self.module = libbcc.bpf_module_create_c(src, llvm_debug, cflags_ary, #cflags)
end
end
assert(self.module ~= nil, "failed to compile BPF module")
end
function Bpf:load_funcs(prog_type)
prog_type = prog_type or "BPF_PROG_TYPE_KPROBE"
local result = {}
local fn_count = tonumber(libbcc.bpf_num_functions(self.module))
for i = 0,fn_count-1 do
local name = ffi.string(libbcc.bpf_function_name(self.module, i))
table.insert(result, self:load_func(name, prog_type))
end
return result
end
function Bpf:load_func(fn_name, prog_type)
if self.funcs[fn_name] ~= nil then
return self.funcs[fn_name]
end
assert(libbcc.bpf_function_start(self.module, fn_name) ~= nil,
"unknown program: "..fn_name)
local fd = libbcc.bpf_prog_load(prog_type,
libbcc.bpf_function_start(self.module, fn_name),
libbcc.bpf_function_size(self.module, fn_name),
libbcc.bpf_module_license(self.module),
libbcc.bpf_module_kern_version(self.module), nil, 0)
assert(fd >= 0, "failed to load BPF program "..fn_name)
log.info("loaded %s (%d)", fn_name, fd)
local fn = {bpf=self, name=fn_name, fd=fd}
self.funcs[fn_name] = fn
return fn
end
function Bpf:dump_func(fn_name)
local start = libbcc.bpf_function_start(self.module, fn_name)
assert(start ~= nil, "unknown program")
local len = libbcc.bpf_function_size(self.module, fn_name)
return ffi.string(start, tonumber(len))
end
function Bpf:attach_uprobe(args)
Bpf.check_probe_quota(1)
local path, addr = LD.check_path_symbol(args.name, args.sym, args.addr)
local fn = self:load_func(args.fn_name, 'BPF_PROG_TYPE_KPROBE')
local ptype = args.retprobe and "r" or "p"
local ev_name = string.format("%s_%s_0x%x", ptype, path:gsub("[^%a%d]", "_"), addr)
local desc = string.format("%s:uprobes/%s %s:0x%x", ptype, ev_name, path, addr)
log.info(desc)
local res = libbcc.bpf_attach_uprobe(fn.fd, ev_name, desc,
args.pid or -1,
args.cpu or 0,
args.group_fd or -1, nil, nil) -- TODO; reader callback
assert(res ~= nil, "failed to attach BPF to uprobe")
self:probe_store("uprobe", ev_name, res)
return self
end
function Bpf:attach_kprobe(args)
-- TODO: allow the caller to glob multiple functions together
Bpf.check_probe_quota(1)
local fn = self:load_func(args.fn_name, 'BPF_PROG_TYPE_KPROBE')
local event = args.event or ""
local ptype = args.retprobe and "r" or "p"
local ev_name = string.format("%s_%s", ptype, event:gsub("[%+%.]", "_"))
local desc = string.format("%s:kprobes/%s %s", ptype, ev_name, event)
log.info(desc)
local res = libbcc.bpf_attach_kprobe(fn.fd, ev_name, desc,
args.pid or -1,
args.cpu or 0,
args.group_fd or -1, nil, nil) -- TODO; reader callback
assert(res ~= nil, "failed to attach BPF to kprobe")
self:probe_store("kprobe", ev_name, res)
return self
end
function Bpf:pipe()
if Bpf.tracer_pipe == nil then
Bpf.tracer_pipe = TracerPipe:new()
end
return Bpf.tracer_pipe
end
function Bpf:get_table(name, key_type, leaf_type)
if self.tables[name] == nil then
self.tables[name] = Table(self, name, key_type, leaf_type)
end
return self.tables[name]
end
function Bpf:probe_store(t, id, reader)
if t == "kprobe" then
Bpf.open_kprobes[id] = reader
elseif t == "uprobe" then
Bpf.open_uprobes[id] = reader
else
error("unknown probe type '%s'" % t)
end
log.info("%s -> %s", id, reader)
end
function Bpf:probe_lookup(t, id)
if t == "kprobe" then
return Bpf.open_kprobes[id]
elseif t == "uprobe" then
return Bpf.open_uprobes[id]
else
return nil
end
end
function Bpf:_kprobe_array()
local kprobe_count = table.count(Bpf.open_kprobes)
local readers = ffi.new("struct perf_reader*[?]", kprobe_count)
local n = 0
for _, r in pairs(Bpf.open_kprobes) do
readers[n] = r
n = n + 1
end
assert(n == kprobe_count)
return readers, n
end
function Bpf:kprobe_poll_loop()
local probes, probe_count = self:_kprobe_array()
return pcall(function()
while true do
libbcc.perf_reader_poll(probe_count, probes, -1)
end
end)
end
function Bpf:kprobe_poll(timeout)
local probes, probe_count = self:_kprobe_array()
libbcc.perf_reader_poll(probe_count, probes, timeout or -1)
end
return Bpf
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
require("bcc.vendor.strict")
require("bcc.vendor.helpers")
class = require("bcc.vendor.middleclass")
return {
BPF = require("bcc.bpf"),
sym = require("bcc.sym"),
}
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ffi = require("ffi")
local _find_library_cache = {}
local function _find_library(name)
if _find_library_cache[name] ~= nil then
return _find_library_cache[name]
end
local arch = ffi.arch
local abi_type = "libc6"
if ffi.abi("64bit") then
if arch == "x64" then
abi_type = abi_type .. ",x86-64"
elseif arch == "ppc" or arch == "mips" then
abi_type = abi_type .. ",64bit"
end
end
local pattern = "%s+lib" .. name:escape() .. "%.%S+ %(" .. abi_type:escape() .. ".-%) => (%S+)"
local f = assert(io.popen("/sbin/ldconfig -p"))
local path = nil
for line in f:lines() do
path = line:match(pattern)
if path then break end
end
f:close()
if path then
_find_library_cache[name] = path
end
return path
end
local _find_load_address_cache = {}
local function _find_load_address(path)
if _find_load_address_cache[path] ~= nil then
return _find_load_address_cache[path]
end
local addr = os.spawn(
[[/usr/bin/objdump -x %s | awk '$1 == "LOAD" && $3 ~ /^[0x]*$/ { print $5 }']],
path)
if addr then
addr = tonumber(addr, 16)
_find_load_address_cache[path] = addr
end
return addr
end
local _find_symbol_cache = {}
local function _find_symbol(path, sym)
assert(path and sym)
if _find_symbol_cache[path] == nil then
_find_symbol_cache[path] = {}
end
local symbols = _find_symbol_cache[path]
if symbols[sym] ~= nil then
return symbols[sym]
end
local addr = os.spawn(
[[/usr/bin/objdump -tT %s | awk -v sym=%s '$NF == sym && $4 == ".text" { print $1; exit }']],
path, sym)
if addr then
addr = tonumber(addr, 16)
symbols[sym] = addr
end
return addr
end
local function _check_path_symbol(name, sym, addr)
assert(name)
local path = name:sub(1,1) == "/" and name or _find_library(name)
assert(path, "could not find library "..name)
-- TODO: realpath
local load_addr = _find_load_address(path)
assert(load_addr, "could not find load address for "..path)
if addr == nil and sym ~= nil then
addr = _find_symbol(path, sym)
end
assert(addr, "could not find address of symbol "..sym)
return path, (addr - load_addr)
end
return {
check_path_symbol=_check_path_symbol,
find_symbol=_find_symbol,
find_load_address=_find_load_address,
find_library=_find_library
}
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ffi = require("ffi")
ffi.cdef[[
enum bpf_prog_type {
BPF_PROG_TYPE_UNSPEC,
BPF_PROG_TYPE_SOCKET_FILTER,
BPF_PROG_TYPE_KPROBE,
BPF_PROG_TYPE_SCHED_CLS,
BPF_PROG_TYPE_SCHED_ACT,
};
int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size, int max_entries);
int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags);
int bpf_lookup_elem(int fd, void *key, void *value);
int bpf_delete_elem(int fd, void *key);
int bpf_get_next_key(int fd, void *key, void *next_key);
int bpf_prog_load(enum bpf_prog_type prog_type, const struct bpf_insn *insns, int insn_len,
const char *license, unsigned kern_version, char *log_buf, unsigned log_buf_size);
int bpf_attach_socket(int sockfd, int progfd);
/* create RAW socket and bind to interface 'name' */
int bpf_open_raw_sock(const char *name);
typedef void (*perf_reader_cb)(void *cb_cookie, int pid, uint64_t callchain_num, void *callchain);
typedef void (*perf_reader_raw_cb)(void *cb_cookie, void *raw, int raw_size);
void * bpf_attach_kprobe(int progfd, const char *event, const char *event_desc,
int pid, int cpu, int group_fd, perf_reader_cb cb, void *cb_cookie);
int bpf_detach_kprobe(const char *event_desc);
void * bpf_attach_uprobe(int progfd, const char *event, const char *event_desc,
int pid, int cpu, int group_fd, perf_reader_cb cb, void *cb_cookie);
int bpf_detach_uprobe(const char *event_desc);
void * bpf_open_perf_buffer(perf_reader_raw_cb raw_cb, void *cb_cookie, int pid, int cpu);
]]
ffi.cdef[[
void * bpf_module_create_b(const char *filename, const char *proto_filename, unsigned flags);
void * bpf_module_create_c(const char *filename, unsigned flags, const char *cflags[], int ncflags);
void * bpf_module_create_c_from_string(const char *text, unsigned flags, const char *cflags[], int ncflags);
void bpf_module_destroy(void *program);
char * bpf_module_license(void *program);
unsigned bpf_module_kern_version(void *program);
size_t bpf_num_functions(void *program);
const char * bpf_function_name(void *program, size_t id);
void * bpf_function_start_id(void *program, size_t id);
void * bpf_function_start(void *program, const char *name);
size_t bpf_function_size_id(void *program, size_t id);
size_t bpf_function_size(void *program, const char *name);
size_t bpf_num_tables(void *program);
size_t bpf_table_id(void *program, const char *table_name);
int bpf_table_fd(void *program, const char *table_name);
int bpf_table_fd_id(void *program, size_t id);
int bpf_table_type(void *program, const char *table_name);
int bpf_table_type_id(void *program, size_t id);
size_t bpf_table_max_entries(void *program, const char *table_name);
size_t bpf_table_max_entries_id(void *program, size_t id);
const char * bpf_table_name(void *program, size_t id);
const char * bpf_table_key_desc(void *program, const char *table_name);
const char * bpf_table_key_desc_id(void *program, size_t id);
const char * bpf_table_leaf_desc(void *program, const char *table_name);
const char * bpf_table_leaf_desc_id(void *program, size_t id);
size_t bpf_table_key_size(void *program, const char *table_name);
size_t bpf_table_key_size_id(void *program, size_t id);
size_t bpf_table_leaf_size(void *program, const char *table_name);
size_t bpf_table_leaf_size_id(void *program, size_t id);
int bpf_table_key_snprintf(void *program, size_t id, char *buf, size_t buflen, const void *key);
int bpf_table_leaf_snprintf(void *program, size_t id, char *buf, size_t buflen, const void *leaf);
int bpf_table_key_sscanf(void *program, size_t id, const char *buf, void *key);
int bpf_table_leaf_sscanf(void *program, size_t id, const char *buf, void *leaf);
]]
ffi.cdef[[
struct perf_reader;
struct perf_reader * perf_reader_new(perf_reader_cb cb, perf_reader_raw_cb raw_cb, void *cb_cookie);
void perf_reader_free(void *ptr);
int perf_reader_mmap(struct perf_reader *reader, unsigned type, unsigned long sample_type);
int perf_reader_poll(int num_readers, struct perf_reader **readers, int timeout);
int perf_reader_fd(struct perf_reader *reader);
void perf_reader_set_fd(struct perf_reader *reader, int fd);
]]
local libbcc = ffi.load("bcc")
return libbcc
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ProcSymbols = class("ProcSymbols")
function ProcSymbols:initialize(pid)
self.pid = pid
self:refresh()
end
function ProcSymbols:_get_exe()
return os.spawn("readlink -f /proc/%d/exe", self.pid)
end
function ProcSymbols:_get_start_time()
return tonumber(os.spawn("cut -d' ' -f 22 /proc/%d/stat", self.pid))
end
function ProcSymbols:_get_code_ranges()
local function is_binary_segment(parts)
if #parts ~= 6 then return false end
if parts[6]:starts("[") then return false end
if parts[2]:find("x") == nil then return false end
return true
end
local ranges = {}
local cmd = string.format("/proc/%d/maps", self.pid)
for line in io.lines(cmd) do
local parts = line:split()
if is_binary_segment(parts) then
local binary = parts[6]
local range = parts[1]:split("-", true)
assert(#range == 2)
ranges[binary] = {tonumber(range[1], 16), tonumber(range[2], 16)}
end
end
return ranges
end
function ProcSymbols:refresh()
self.code_ranges = self:_get_code_ranges()
self.ranges_cache = {}
self.exe = self:_get_exe()
self.start_time = self:_get_start_time()
end
function ProcSymbols:_check_pid_wrap()
local new_exe = self:_get_exe()
local new_time = self:_get_start_time()
if self.exe ~= new_exe or self.start_time ~= new_time then
self:refresh()
end
end
function ProcSymbols:_get_sym_ranges(binary)
if self.ranges_cache[binary] ~= nil then
return self.ranges_cache[binary]
end
local function is_function_sym(parts)
return #parts == 6 and parts[4] == ".text" and parts[3] == "F"
end
local sym_ranges = {}
local proc = assert(io.popen("objdump -t "..binary))
for line in proc:lines() do
local parts = line:split()
if is_function_sym(parts) then
local sym_start = tonumber(parts[1], 16)
local sym_len = tonumber(parts[5], 16)
local sym_name = parts[6]
sym_ranges[sym_name] = {sym_start, sym_len}
end
end
proc:close()
self.ranges_cache[binary] = sym_ranges
return sym_ranges
end
function ProcSymbols:_decode_sym(binary, offset)
local sym_ranges = self:_get_sym_ranges(binary)
for name, range in pairs(sym_ranges) do
local start = range[1]
local length = range[2]
if offset >= start and offset <= (start + length) then
return string.format("%s+0x%x", name, offset - start)
end
end
return string.format("%x", offset)
end
function ProcSymbols:lookup(addr)
self:_check_pid_wrap()
for binary, range in pairs(self.code_ranges) do
local start = range[1]
local tend = range[2]
if addr >= start and addr <= tend then
local offset = binary:ends(".so") and (addr - start) or addr
return string.format("%s [%s]", self:_decode_sym(binary, offset), binary)
end
end
return string.format("%x", addr)
end
local KSymbols = class("KSymbols")
KSymbols.static.KALLSYMS = "/proc/kallsyms"
function KSymbols:initialize()
self.ksyms = {}
self.ksym_names = {}
self.loaded = false
end
function KSymbols:_load()
if self.loaded then return end
local first_line = true
for line in io.lines(KSymbols.KALLSYMS) do
if not first_line then
local cols = line:split()
local name = cols[3]
local addr = tonumber(cols[1], 16)
table.insert(self.ksyms, {name, addr})
self.ksym_names[name] = #self.ksyms
end
first_line = false
end
self.loaded = true
end
function KSymbols:_addr2index(addr)
self:_load()
return table.bsearch(self.ksyms, addr, function(v) return v[2] end)
end
function KSymbols:lookup(addr, with_offset)
local idx = self:_addr2index(addr)
if idx == nil then
return "[unknown]"
end
if with_offset then
local offset = addr - self.ksyms[idx][2]
return "%s %x" % {self.ksyms[idx][1], offset}
else
return self.ksyms[idx][1]
end
end
function KSymbols:refresh()
-- NOOP
end
return { ProcSymbols=ProcSymbols, KSymbols=KSymbols }
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ffi = require("ffi")
local libbcc = require("bcc.libbcc")
local Posix = require("bcc.vendor.posix")
local BaseTable = class("BaseTable")
BaseTable.static.BPF_MAP_TYPE_HASH = 1
BaseTable.static.BPF_MAP_TYPE_ARRAY = 2
BaseTable.static.BPF_MAP_TYPE_PROG_ARRAY = 3
BaseTable.static.BPF_MAP_TYPE_PERF_EVENT_ARRAY = 4
BaseTable.static.BPF_MAP_TYPE_PERCPU_HASH = 5
BaseTable.static.BPF_MAP_TYPE_PERCPU_ARRAY = 6
BaseTable.static.BPF_MAP_TYPE_STACK_TRACE = 7
function BaseTable:initialize(t_type, bpf, map_id, map_fd, key_type, leaf_type)
assert(t_type == libbcc.bpf_table_type_id(bpf.module, map_id))
self.t_type = t_type
self.bpf = bpf
self.map_id = map_id
self.map_fd = map_fd
self.c_key = ffi.typeof(key_type.."[1]")
self.c_leaf = ffi.typeof(leaf_type.."[1]")
end
function BaseTable:key_sprintf(key)
local pkey = self.c_key(key)
local buf_len = ffi.sizeof(self.c_key) * 8
local pbuf = ffi.new("char[?]", buf_len)
local res = libbcc.bpf_table_key_snprintf(
self.bpf.module, self.map_id, pbuf, buf_len, pkey)
assert(res == 0, "could not print key")
return ffi.string(pbuf)
end
function BaseTable:leaf_sprintf(leaf)
local pleaf = self.c_leaf(leaf)
local buf_len = ffi.sizeof(self.c_leaf) * 8
local pbuf = ffi.new("char[?]", buf_len)
local res = libbcc.bpf_table_leaf_snprintf(
self.bpf.module, self.map_id, pbuf, buf_len, pleaf)
assert(res == 0, "could not print leaf")
return ffi.string(pbuf)
end
function BaseTable:key_scanf(key_str)
local pkey = self.c_key()
local res = libbcc.bpf_table_key_sscanf(
self.bpf.module, self.map_id, key_str, pkey)
assert(res == 0, "could not scanf key")
return pkey[0]
end
function BaseTable:leaf_scanf(leaf_str)
local pleaf = self.c_leaf()
local res = libbcc.bpf_table_leaf_sscanf(
self.bpf.module, self.map_id, leaf_str, pleaf)
assert(res == 0, "could not scanf leaf")
return pleaf[0]
end
function BaseTable:get(key)
local pkey = self.c_key(key)
local pvalue = self.c_leaf()
if libbcc.bpf_lookup_elem(self.map_fd, pkey, pvalue) < 0 then
return nil
end
return pvalue[0]
end
function BaseTable:set(key, value)
local pkey = self.c_key(key)
local pvalue = self.c_leaf(value)
assert(libbcc.bpf_update_elem(self.map_fd, pkey, pvalue, 0) == 0, "could not update table")
end
function BaseTable:_empty_key()
local pkey = self.c_key()
local pvalue = self.c_leaf()
for _, v in ipairs({0x0, 0x55, 0xff}) do
ffi.fill(pkey, ffi.sizeof(pkey[0]), v)
if libbcc.bpf_lookup_elem(self.map_fd, pkey, pvalue) < 0 then
return pkey
end
end
error("failed to find an empty key for table iteration")
end
function BaseTable:keys()
local pkey = self:_empty_key()
return function()
local pkey_next = self.c_key()
if libbcc.bpf_get_next_key(self.map_fd, pkey, pkey_next) < 0 then
return nil
end
pkey = pkey_next
return pkey[0]
end
end
function BaseTable:items()
local pkey = self:_empty_key()
return function()
local pkey_next = self.c_key()
local pvalue = self.c_leaf()
if libbcc.bpf_get_next_key(self.map_fd, pkey, pkey_next) < 0 then
return nil
end
pkey = pkey_next
assert(libbcc.bpf_lookup_elem(self.map_fd, pkey, pvalue) == 0)
return pkey[0], pvalue[0]
end
end
local HashTable = class("HashTable", BaseTable)
function HashTable:initialize(bpf, map_id, map_fd, key_type, leaf_type)
BaseTable.initialize(self, BaseTable.BPF_MAP_TYPE_HASH, bpf, map_id, map_fd, key_type, leaf_type)
end
function HashTable:delete(key)
local pkey = self.c_key(key)
return libbcc.bpf_delete_elem(self.map_fd, pkey) == 0
end
function HashTable:size()
local n = 0
self:each(function() n = n + 1 end)
return n
end
local BaseArray = class("BaseArray", BaseTable)
function BaseArray:initialize(t_type, bpf, map_id, map_fd, key_type, leaf_type)
BaseTable.initialize(self, t_type, bpf, map_id, map_fd, key_type, leaf_type)
self.max_entries = tonumber(libbcc.bpf_table_max_entries_id(self.bpf.module, self.map_id))
end
function BaseArray:_normalize_key(key)
assert(type(key) == "number", "invalid key (expected a number")
if key < 0 then
key = self.max_entries + key
end
assert(key < self.max_entries, string.format("out of range (%d >= %d)", key, self.max_entries))
return key
end
function BaseArray:get(key)
return BaseTable.get(self, self:_normalize_key(key))
end
function BaseArray:set(key, value)
return BaseTable.set(self, self:_normalize_key(key), value)
end
function BaseArray:delete(key)
assert(nil, "unsupported")
end
function BaseArray:items(with_index)
local pkey = self.c_key()
local max = self.max_entries
local n = 0
-- TODO
return function()
local pvalue = self.c_leaf()
if n == max then
return nil
end
pkey[0] = n
n = n + 1
if libbcc.bpf_lookup_elem(self.map_fd, pkey, pvalue) ~= 0 then
return nil
end
if with_index then
return n, pvalue[0] -- return 1-based index
else
return pvalue[0]
end
end
end
local Array = class("Array", BaseArray)
function Array:initialize(bpf, map_id, map_fd, key_type, leaf_type)
BaseArray.initialize(self, BaseTable.BPF_MAP_TYPE_ARRAY, bpf, map_id, map_fd, key_type, leaf_type)
end
local PerfEventArray = class("PerfEventArray", BaseArray)
function PerfEventArray:initialize(bpf, map_id, map_fd, key_type, leaf_type)
BaseArray.initialize(self, BaseTable.BPF_MAP_TYPE_PERF_EVENT_ARRAY, bpf, map_id, map_fd, key_type, leaf_type)
self._callbacks = {}
end
local function _perf_id(id, cpu)
return string.format("perf_event_array:%d:%d", tonumber(id), cpu or 0)
end
function PerfEventArray:_open_perf_buffer(cpu, callback, ctype)
local _cb = ffi.cast("perf_reader_raw_cb",
function (cookie, data, size)
callback(cpu, ctype(data)[0])
end)
local reader = libbcc.bpf_open_perf_buffer(_cb, nil, -1, cpu)
assert(reader, "failed to open perf buffer")
local fd = libbcc.perf_reader_fd(reader)
self:set(cpu, fd)
self.bpf:probe_store("kprobe", _perf_id(self.map_id, cpu), reader)
self._callbacks[cpu] = _cb
end
function PerfEventArray:open_perf_buffer(callback, data_type)
assert(data_type, "a data type is needed for callback conversion")
local ctype = ffi.typeof(data_type.."*")
for i = 0, Posix.cpu_count() - 1 do
self:_open_perf_buffer(i, callback, ctype)
end
end
local StackTrace = class("StackTrace", BaseTable)
StackTrace.static.MAX_STACK = 127
function StackTrace:initialize(bpf, map_id, map_fd, key_type, leaf_type)
BaseTable.initialize(self, BaseTable.BPF_MAP_TYPE_STACK_TRACE, bpf, map_id, map_fd, key_type, leaf_type)
self._stackp = self.c_leaf() -- FIXME: not threadsafe
end
function StackTrace:walk(id)
local pkey = self.c_key(id)
local pstack = self._stackp
local i = 0
if libbcc.bpf_lookup_elem(self.map_fd, pkey, pstack) < 0 then
return nil
end
return function()
if i >= StackTrace.MAX_STACK then
return nil
end
local addr = tonumber(pstack[0].ip[i])
if addr == 0 then
return nil
end
i = i + 1
return addr
end
end
function StackTrace:get(id, resolver)
local stack = {}
for addr in self:walk(id) do
table.insert(stack, resolver and resolver(addr) or addr)
end
return stack
end
local function _decode_table_type(desc)
local json = require("bcc.vendor.json")
local json_desc = ffi.string(desc)
local function _dec(t)
if type(t) == "string" then
return t
end
local fields = {}
local struct = t[3] or "struct"
for _, value in ipairs(t[2]) do
local f = nil
if #value == 2 then
f = string.format("%s %s;", _dec(value[2]), value[1])
elseif #value == 3 then
if type(value[3]) == "table" then
f = string.format("%s %s[%d];", _dec(value[2]), value[1], value[3][1])
elseif type(value[3]) == "number" then
local t = _dec(value[2])
assert(t == "int" or t == "unsigned int",
"bitfields can only appear in [unsigned] int types")
f = string.format("%s %s:%d;", t, value[1], value[3])
end
end
assert(f ~= nil, "failed to decode type "..json_desc)
table.insert(fields, f)
end
assert(struct == "struct" or struct == "union", "unknown complex type: "..struct)
return string.format("%s { %s }", struct, table.concat(fields, " "))
end
return _dec(json.parse(json_desc))
end
local function NewTable(bpf, name, key_type, leaf_type)
local id = libbcc.bpf_table_id(bpf.module, name)
local fd = libbcc.bpf_table_fd(bpf.module, name)
if fd < 0 then
return nil
end
local t_type = libbcc.bpf_table_type_id(bpf.module, id)
local table = nil
if t_type == BaseTable.BPF_MAP_TYPE_HASH then
table = HashTable
elseif t_type == BaseTable.BPF_MAP_TYPE_ARRAY then
table = Array
elseif t_type == BaseTable.BPF_MAP_TYPE_PERF_EVENT_ARRAY then
table = PerfEventArray
elseif t_type == BaseTable.BPF_MAP_TYPE_STACK_TRACE then
table = StackTrace
end
assert(table, "unsupported table type %d" % t_type)
if key_type == nil then
local desc = libbcc.bpf_table_key_desc(bpf.module, name)
assert(desc, "Failed to load BPF table description for "..name)
key_type = _decode_table_type(desc)
end
if leaf_type == nil then
local desc = libbcc.bpf_table_leaf_desc(bpf.module, name)
assert(desc, "Failed to load BPF table description for "..name)
leaf_type = _decode_table_type(desc)
end
log.info("key = %s value = %s", key_type, leaf_type)
return table:new(bpf, id, fd, key_type, leaf_type)
end
return NewTable
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local TracerPipe = class("TracerPipe")
TracerPipe.static.TRACEFS = "/sys/kernel/debug/tracing"
TracerPipe.static.fields = "%s+(.-)%-(%d+)%s+%[(%d+)%]%s+(....)%s+([%d%.]+):.-:%s+(.+)"
function TracerPipe:close()
if self.pipe ~= nil then
self.pipe:close()
end
end
function TracerPipe:open()
if self.pipe == nil then
self.pipe = assert(io.open(TracerPipe.TRACEFS .. "/trace_pipe"))
end
return self.pipe
end
function TracerPipe:readline()
return self:open():read()
end
function TracerPipe:trace_fields()
while true do
local line = self:readline()
if not line and self.nonblocking then
return nil
end
if not line:starts("CPU:") then
local task, pid, cpu, flags, ts, msg = line:match(TracerPipe.fields)
if task ~= nil then
return task, tonumber(pid), tonumber(cpu), flags, tonumber(ts), msg
end
end
end
end
function TracerPipe:initialize(nonblocking)
self.nonblocking = nonblocking
end
return TracerPipe
-- The MIT License (MIT)
-- Copyright (c) 2013 - 2015 Peter Melnichenko
-- Permission is hereby granted, free of charge, to any person obtaining a copy of
-- this software and associated documentation files (the "Software"), to deal in
-- the Software without restriction, including without limitation the rights to
-- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
-- the Software, and to permit persons to whom the Software is furnished to do so,
-- subject to the following conditions:
-- The above copyright notice and this permission notice shall be included in all
-- copies or substantial portions of the Software.
-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
-- FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
-- COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
-- IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-- CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
local function deep_update(t1, t2)
for k, v in pairs(t2) do
if type(v) == "table" then
v = deep_update({}, v)
end
t1[k] = v
end
return t1
end
-- A property is a tuple {name, callback}.
-- properties.args is number of properties that can be set as arguments
-- when calling an object.
local function class(prototype, properties, parent)
-- Class is the metatable of its instances.
local cl = {}
cl.__index = cl
if parent then
cl.__prototype = deep_update(deep_update({}, parent.__prototype), prototype)
else
cl.__prototype = prototype
end
if properties then
local names = {}
-- Create setter methods and fill set of property names.
for _, property in ipairs(properties) do
local name, callback = property[1], property[2]
cl[name] = function(self, value)
if not callback(self, value) then
self["_" .. name] = value
end
return self
end
names[name] = true
end
function cl.__call(self, ...)
-- When calling an object, if the first argument is a table,
-- interpret keys as property names, else delegate arguments
-- to corresponding setters in order.
if type((...)) == "table" then
for name, value in pairs((...)) do
if names[name] then
self[name](self, value)
end
end
else
local nargs = select("#", ...)
for i, property in ipairs(properties) do
if i > nargs or i > properties.args then
break
end
local arg = select(i, ...)
if arg ~= nil then
self[property[1]](self, arg)
end
end
end
return self
end
end
-- If indexing class fails, fallback to its parent.
local class_metatable = {}
class_metatable.__index = parent
function class_metatable.__call(self, ...)
-- Calling a class returns its instance.
-- Arguments are delegated to the instance.
local object = deep_update({}, self.__prototype)
setmetatable(object, self)
return object(...)
end
return setmetatable(cl, class_metatable)
end
local function typecheck(name, types, value)
for _, type_ in ipairs(types) do
if type(value) == type_ then
return true
end
end
error(("bad property '%s' (%s expected, got %s)"):format(name, table.concat(types, " or "), type(value)))
end
local function typechecked(name, ...)
local types = {...}
return {name, function(_, value) typecheck(name, types, value) end}
end
local multiname = {"name", function(self, value)
typecheck("name", {"string"}, value)
for alias in value:gmatch("%S+") do
self._name = self._name or alias
table.insert(self._aliases, alias)
end
-- Do not set _name as with other properties.
return true
end}
local function parse_boundaries(str)
if tonumber(str) then
return tonumber(str), tonumber(str)
end
if str == "*" then
return 0, math.huge
end
if str == "+" then
return 1, math.huge
end
if str == "?" then
return 0, 1
end
if str:match "^%d+%-%d+$" then
local min, max = str:match "^(%d+)%-(%d+)$"
return tonumber(min), tonumber(max)
end
if str:match "^%d+%+$" then
local min = str:match "^(%d+)%+$"
return tonumber(min), math.huge
end
end
local function boundaries(name)
return {name, function(self, value)
typecheck(name, {"number", "string"}, value)
local min, max = parse_boundaries(value)
if not min then
error(("bad property '%s'"):format(name))
end
self["_min" .. name], self["_max" .. name] = min, max
end}
end
local actions = {}
local option_action = {"action", function(_, value)
typecheck("action", {"function", "string"}, value)
if type(value) == "string" and not actions[value] then
error(("unknown action '%s'"):format(value))
end
end}
local option_init = {"init", function(self)
self._has_init = true
end}
local option_default = {"default", function(self, value)
if type(value) ~= "string" then
self._init = value
self._has_init = true
return true
end
end}
local add_help = {"add_help", function(self, value)
typecheck("add_help", {"boolean", "string", "table"}, value)
if self._has_help then
table.remove(self._options)
self._has_help = false
end
if value then
local help = self:flag()
:description "Show this help message and exit."
:action(function()
print(self:get_help())
os.exit(0)
end)
if value ~= true then
help = help(value)
end
if not help._name then
help "-h" "--help"
end
self._has_help = true
end
end}
local Parser = class({
_arguments = {},
_options = {},
_commands = {},
_mutexes = {},
_require_command = true,
_handle_options = true
}, {
args = 3,
typechecked("name", "string"),
typechecked("description", "string"),
typechecked("epilog", "string"),
typechecked("usage", "string"),
typechecked("help", "string"),
typechecked("require_command", "boolean"),
typechecked("handle_options", "boolean"),
typechecked("action", "function"),
typechecked("command_target", "string"),
add_help
})
local Command = class({
_aliases = {}
}, {
args = 3,
multiname,
typechecked("description", "string"),
typechecked("epilog", "string"),
typechecked("target", "string"),
typechecked("usage", "string"),
typechecked("help", "string"),
typechecked("require_command", "boolean"),
typechecked("handle_options", "boolean"),
typechecked("action", "function"),
typechecked("command_target", "string"),
add_help
}, Parser)
local Argument = class({
_minargs = 1,
_maxargs = 1,
_mincount = 1,
_maxcount = 1,
_defmode = "unused",
_show_default = true
}, {
args = 5,
typechecked("name", "string"),
typechecked("description", "string"),
option_default,
typechecked("convert", "function", "table"),
boundaries("args"),
typechecked("target", "string"),
typechecked("defmode", "string"),
typechecked("show_default", "boolean"),
typechecked("argname", "string", "table"),
option_action,
option_init
})
local Option = class({
_aliases = {},
_mincount = 0,
_overwrite = true
}, {
args = 6,
multiname,
typechecked("description", "string"),
option_default,
typechecked("convert", "function", "table"),
boundaries("args"),
boundaries("count"),
typechecked("target", "string"),
typechecked("defmode", "string"),
typechecked("show_default", "boolean"),
typechecked("overwrite", "boolean"),
typechecked("argname", "string", "table"),
option_action,
option_init
}, Argument)
function Argument:_get_argument_list()
local buf = {}
local i = 1
while i <= math.min(self._minargs, 3) do
local argname = self:_get_argname(i)
if self._default and self._defmode:find "a" then
argname = "[" .. argname .. "]"
end
table.insert(buf, argname)
i = i+1
end
while i <= math.min(self._maxargs, 3) do
table.insert(buf, "[" .. self:_get_argname(i) .. "]")
i = i+1
if self._maxargs == math.huge then
break
end
end
if i < self._maxargs then
table.insert(buf, "...")
end
return buf
end
function Argument:_get_usage()
local usage = table.concat(self:_get_argument_list(), " ")
if self._default and self._defmode:find "u" then
if self._maxargs > 1 or (self._minargs == 1 and not self._defmode:find "a") then
usage = "[" .. usage .. "]"
end
end
return usage
end
function actions.store_true(result, target)
result[target] = true
end
function actions.store_false(result, target)
result[target] = false
end
function actions.store(result, target, argument)
result[target] = argument
end
function actions.count(result, target, _, overwrite)
if not overwrite then
result[target] = result[target] + 1
end
end
function actions.append(result, target, argument, overwrite)
result[target] = result[target] or {}
table.insert(result[target], argument)
if overwrite then
table.remove(result[target], 1)
end
end
function actions.concat(result, target, arguments, overwrite)
if overwrite then
error("'concat' action can't handle too many invocations")
end
result[target] = result[target] or {}
for _, argument in ipairs(arguments) do
table.insert(result[target], argument)
end
end
function Argument:_get_action()
local action, init
if self._maxcount == 1 then
if self._maxargs == 0 then
action, init = "store_true", nil
else
action, init = "store", nil
end
else
if self._maxargs == 0 then
action, init = "count", 0
else
action, init = "append", {}
end
end
if self._action then
action = self._action
end
if self._has_init then
init = self._init
end
if type(action) == "string" then
action = actions[action]
end
return action, init
end
-- Returns placeholder for `narg`-th argument.
function Argument:_get_argname(narg)
local argname = self._argname or self:_get_default_argname()
if type(argname) == "table" then
return argname[narg]
else
return argname
end
end
function Argument:_get_default_argname()
return "<" .. self._name .. ">"
end
function Option:_get_default_argname()
return "<" .. self:_get_default_target() .. ">"
end
-- Returns label to be shown in the help message.
function Argument:_get_label()
return self._name
end
function Option:_get_label()
local variants = {}
local argument_list = self:_get_argument_list()
table.insert(argument_list, 1, nil)
for _, alias in ipairs(self._aliases) do
argument_list[1] = alias
table.insert(variants, table.concat(argument_list, " "))
end
return table.concat(variants, ", ")
end
function Command:_get_label()
return table.concat(self._aliases, ", ")
end
function Argument:_get_description()
if self._default and self._show_default then
if self._description then
return ("%s (default: %s)"):format(self._description, self._default)
else
return ("default: %s"):format(self._default)
end
else
return self._description or ""
end
end
function Command:_get_description()
return self._description or ""
end
function Option:_get_usage()
local usage = self:_get_argument_list()
table.insert(usage, 1, self._name)
usage = table.concat(usage, " ")
if self._mincount == 0 or self._default then
usage = "[" .. usage .. "]"
end
return usage
end
function Argument:_get_default_target()
return self._name
end
function Option:_get_default_target()
local res
for _, alias in ipairs(self._aliases) do
if alias:sub(1, 1) == alias:sub(2, 2) then
res = alias:sub(3)
break
end
end
res = res or self._name:sub(2)
return (res:gsub("-", "_"))
end
function Option:_is_vararg()
return self._maxargs ~= self._minargs
end
function Parser:_get_fullname()
local parent = self._parent
local buf = {self._name}
while parent do
table.insert(buf, 1, parent._name)
parent = parent._parent
end
return table.concat(buf, " ")
end
function Parser:_update_charset(charset)
charset = charset or {}
for _, command in ipairs(self._commands) do
command:_update_charset(charset)
end
for _, option in ipairs(self._options) do
for _, alias in ipairs(option._aliases) do
charset[alias:sub(1, 1)] = true
end
end
return charset
end
function Parser:argument(...)
local argument = Argument(...)
table.insert(self._arguments, argument)
return argument
end
function Parser:option(...)
local option = Option(...)
if self._has_help then
table.insert(self._options, #self._options, option)
else
table.insert(self._options, option)
end
return option
end
function Parser:flag(...)
return self:option():args(0)(...)
end
function Parser:command(...)
local command = Command():add_help(true)(...)
command._parent = self
table.insert(self._commands, command)
return command
end
function Parser:mutex(...)
local options = {...}
for i, option in ipairs(options) do
assert(getmetatable(option) == Option, ("bad argument #%d to 'mutex' (Option expected)"):format(i))
end
table.insert(self._mutexes, options)
return self
end
local max_usage_width = 70
local usage_welcome = "Usage: "
function Parser:get_usage()
if self._usage then
return self._usage
end
local lines = {usage_welcome .. self:_get_fullname()}
local function add(s)
if #lines[#lines]+1+#s <= max_usage_width then
lines[#lines] = lines[#lines] .. " " .. s
else
lines[#lines+1] = (" "):rep(#usage_welcome) .. s
end
end
-- This can definitely be refactored into something cleaner
local mutex_options = {}
local vararg_mutexes = {}
-- First, put mutexes which do not contain vararg options and remember those which do
for _, mutex in ipairs(self._mutexes) do
local buf = {}
local is_vararg = false
for _, option in ipairs(mutex) do
if option:_is_vararg() then
is_vararg = true
end
table.insert(buf, option:_get_usage())
mutex_options[option] = true
end
local repr = "(" .. table.concat(buf, " | ") .. ")"
if is_vararg then
table.insert(vararg_mutexes, repr)
else
add(repr)
end
end
-- Second, put regular options
for _, option in ipairs(self._options) do
if not mutex_options[option] and not option:_is_vararg() then
add(option:_get_usage())
end
end
-- Put positional arguments
for _, argument in ipairs(self._arguments) do
add(argument:_get_usage())
end
-- Put mutexes containing vararg options
for _, mutex_repr in ipairs(vararg_mutexes) do
add(mutex_repr)
end
for _, option in ipairs(self._options) do
if not mutex_options[option] and option:_is_vararg() then
add(option:_get_usage())
end
end
if #self._commands > 0 then
if self._require_command then
add("<command>")
else
add("[<command>]")
end
add("...")
end
return table.concat(lines, "\n")
end
local margin_len = 3
local margin_len2 = 25
local margin = (" "):rep(margin_len)
local margin2 = (" "):rep(margin_len2)
local function make_two_columns(s1, s2)
if s2 == "" then
return margin .. s1
end
s2 = s2:gsub("\n", "\n" .. margin2)
if #s1 < (margin_len2-margin_len) then
return margin .. s1 .. (" "):rep(margin_len2-margin_len-#s1) .. s2
else
return margin .. s1 .. "\n" .. margin2 .. s2
end
end
function Parser:get_help()
if self._help then
return self._help
end
local blocks = {self:get_usage()}
if self._description then
table.insert(blocks, self._description)
end
local labels = {"Arguments:", "Options:", "Commands:"}
for i, elements in ipairs{self._arguments, self._options, self._commands} do
if #elements > 0 then
local buf = {labels[i]}
for _, element in ipairs(elements) do
table.insert(buf, make_two_columns(element:_get_label(), element:_get_description()))
end
table.insert(blocks, table.concat(buf, "\n"))
end
end
if self._epilog then
table.insert(blocks, self._epilog)
end
return table.concat(blocks, "\n\n")
end
local function get_tip(context, wrong_name)
local context_pool = {}
local possible_name
local possible_names = {}
for name in pairs(context) do
if type(name) == "string" then
for i = 1, #name do
possible_name = name:sub(1, i - 1) .. name:sub(i + 1)
if not context_pool[possible_name] then
context_pool[possible_name] = {}
end
table.insert(context_pool[possible_name], name)
end
end
end
for i = 1, #wrong_name + 1 do
possible_name = wrong_name:sub(1, i - 1) .. wrong_name:sub(i + 1)
if context[possible_name] then
possible_names[possible_name] = true
elseif context_pool[possible_name] then
for _, name in ipairs(context_pool[possible_name]) do
possible_names[name] = true
end
end
end
local first = next(possible_names)
if first then
if next(possible_names, first) then
local possible_names_arr = {}
for name in pairs(possible_names) do
table.insert(possible_names_arr, "'" .. name .. "'")
end
table.sort(possible_names_arr)
return "\nDid you mean one of these: " .. table.concat(possible_names_arr, " ") .. "?"
else
return "\nDid you mean '" .. first .. "'?"
end
else
return ""
end
end
local ElementState = class({
invocations = 0
})
function ElementState:__call(state, element)
self.state = state
self.result = state.result
self.element = element
self.target = element._target or element:_get_default_target()
self.action, self.result[self.target] = element:_get_action()
return self
end
function ElementState:error(fmt, ...)
self.state:error(fmt, ...)
end
function ElementState:convert(argument)
local converter = self.element._convert
if converter then
local ok, err
if type(converter) == "function" then
ok, err = converter(argument)
else
ok = converter[argument]
end
if ok == nil then
self:error(err and "%s" or "malformed argument '%s'", err or argument)
end
argument = ok
end
return argument
end
function ElementState:default(mode)
return self.element._defmode:find(mode) and self.element._default
end
local function bound(noun, min, max, is_max)
local res = ""
if min ~= max then
res = "at " .. (is_max and "most" or "least") .. " "
end
local number = is_max and max or min
return res .. tostring(number) .. " " .. noun .. (number == 1 and "" or "s")
end
function ElementState:invoke(alias)
self.open = true
self.name = ("%s '%s'"):format(alias and "option" or "argument", alias or self.element._name)
self.overwrite = false
if self.invocations >= self.element._maxcount then
if self.element._overwrite then
self.overwrite = true
else
self:error("%s must be used %s", self.name, bound("time", self.element._mincount, self.element._maxcount, true))
end
else
self.invocations = self.invocations + 1
end
self.args = {}
if self.element._maxargs <= 0 then
self:close()
end
return self.open
end
function ElementState:pass(argument)
argument = self:convert(argument)
table.insert(self.args, argument)
if #self.args >= self.element._maxargs then
self:close()
end
return self.open
end
function ElementState:complete_invocation()
while #self.args < self.element._minargs do
self:pass(self.element._default)
end
end
function ElementState:close()
if self.open then
self.open = false
if #self.args < self.element._minargs then
if self:default("a") then
self:complete_invocation()
else
if #self.args == 0 then
if getmetatable(self.element) == Argument then
self:error("missing %s", self.name)
elseif self.element._maxargs == 1 then
self:error("%s requires an argument", self.name)
end
end
self:error("%s requires %s", self.name, bound("argument", self.element._minargs, self.element._maxargs))
end
end
local args = self.args
if self.element._maxargs <= 1 then
args = args[1]
end
if self.element._maxargs == 1 and self.element._minargs == 0 and self.element._mincount ~= self.element._maxcount then
args = self.args
end
self.action(self.result, self.target, args, self.overwrite)
end
end
local ParseState = class({
result = {},
options = {},
arguments = {},
argument_i = 1,
element_to_mutexes = {},
mutex_to_used_option = {},
command_actions = {}
})
function ParseState:__call(parser, error_handler)
self.parser = parser
self.error_handler = error_handler
self.charset = parser:_update_charset()
self:switch(parser)
return self
end
function ParseState:error(fmt, ...)
self.error_handler(self.parser, fmt:format(...))
end
function ParseState:switch(parser)
self.parser = parser
if parser._action then
table.insert(self.command_actions, {action = parser._action, name = parser._name})
end
for _, option in ipairs(parser._options) do
option = ElementState(self, option)
table.insert(self.options, option)
for _, alias in ipairs(option.element._aliases) do
self.options[alias] = option
end
end
for _, mutex in ipairs(parser._mutexes) do
for _, option in ipairs(mutex) do
if not self.element_to_mutexes[option] then
self.element_to_mutexes[option] = {}
end
table.insert(self.element_to_mutexes[option], mutex)
end
end
for _, argument in ipairs(parser._arguments) do
argument = ElementState(self, argument)
table.insert(self.arguments, argument)
argument:invoke()
end
self.handle_options = parser._handle_options
self.argument = self.arguments[self.argument_i]
self.commands = parser._commands
for _, command in ipairs(self.commands) do
for _, alias in ipairs(command._aliases) do
self.commands[alias] = command
end
end
end
function ParseState:get_option(name)
local option = self.options[name]
if not option then
self:error("unknown option '%s'%s", name, get_tip(self.options, name))
else
return option
end
end
function ParseState:get_command(name)
local command = self.commands[name]
if not command then
if #self.commands > 0 then
self:error("unknown command '%s'%s", name, get_tip(self.commands, name))
else
self:error("too many arguments")
end
else
return command
end
end
function ParseState:invoke(option, name)
self:close()
if self.element_to_mutexes[option.element] then
for _, mutex in ipairs(self.element_to_mutexes[option.element]) do
local used_option = self.mutex_to_used_option[mutex]
if used_option and used_option ~= option then
self:error("option '%s' can not be used together with %s", name, used_option.name)
else
self.mutex_to_used_option[mutex] = option
end
end
end
if option:invoke(name) then
self.option = option
end
end
function ParseState:pass(arg)
if self.option then
if not self.option:pass(arg) then
self.option = nil
end
elseif self.argument then
if not self.argument:pass(arg) then
self.argument_i = self.argument_i + 1
self.argument = self.arguments[self.argument_i]
end
else
local command = self:get_command(arg)
self.result[command._target or command._name] = true
if self.parser._command_target then
self.result[self.parser._command_target] = command._name
end
self:switch(command)
end
end
function ParseState:close()
if self.option then
self.option:close()
self.option = nil
end
end
function ParseState:finalize()
self:close()
for i = self.argument_i, #self.arguments do
local argument = self.arguments[i]
if #argument.args == 0 and argument:default("u") then
argument:complete_invocation()
else
argument:close()
end
end
if self.parser._require_command and #self.commands > 0 then
self:error("a command is required")
end
for _, option in ipairs(self.options) do
local name = option.name or ("option '%s'"):format(option.element._name)
if option.invocations == 0 then
if option:default("u") then
option:invoke(name)
option:complete_invocation()
option:close()
end
end
local mincount = option.element._mincount
if option.invocations < mincount then
if option:default("a") then
while option.invocations < mincount do
option:invoke(name)
option:close()
end
elseif option.invocations == 0 then
self:error("missing %s", name)
else
self:error("%s must be used %s", name, bound("time", mincount, option.element._maxcount))
end
end
end
for i = #self.command_actions, 1, -1 do
self.command_actions[i].action(self.result, self.command_actions[i].name)
end
end
function ParseState:parse(args)
for _, arg in ipairs(args) do
local plain = true
if self.handle_options then
local first = arg:sub(1, 1)
if self.charset[first] then
if #arg > 1 then
plain = false
if arg:sub(2, 2) == first then
if #arg == 2 then
self:close()
self.handle_options = false
else
local equals = arg:find "="
if equals then
local name = arg:sub(1, equals - 1)
local option = self:get_option(name)
if option.element._maxargs <= 0 then
self:error("option '%s' does not take arguments", name)
end
self:invoke(option, name)
self:pass(arg:sub(equals + 1))
else
local option = self:get_option(arg)
self:invoke(option, arg)
end
end
else
for i = 2, #arg do
local name = first .. arg:sub(i, i)
local option = self:get_option(name)
self:invoke(option, name)
if i ~= #arg and option.element._maxargs > 0 then
self:pass(arg:sub(i + 1))
break
end
end
end
end
end
end
if plain then
self:pass(arg)
end
end
self:finalize()
return self.result
end
function Parser:error(msg)
io.stderr:write(("%s\n\nError: %s\n"):format(self:get_usage(), msg))
os.exit(1)
end
-- Compatibility with strict.lua and other checkers:
local default_cmdline = rawget(_G, "arg") or {}
function Parser:_parse(args, error_handler)
return ParseState(self, error_handler):parse(args or default_cmdline)
end
function Parser:parse(args)
return self:_parse(args, self.error)
end
local function xpcall_error_handler(err)
return tostring(err) .. "\noriginal " .. debug.traceback("", 2):sub(2)
end
function Parser:pparse(args)
local parse_error
local ok, result = xpcall(function()
return self:_parse(args, function(_, err)
parse_error = err
error(err, 0)
end)
end, xpcall_error_handler)
if ok then
return true, result
elseif not parse_error then
error(result, 0)
else
return false, parse_error
end
end
return function(...)
return Parser(default_cmdline[0]):add_help(true)(...)
end
function string.starts(String,Start)
return string.sub(String,1,string.len(Start))==Start
end
function string.ends(String,End)
return End=='' or string.sub(String,-string.len(End))==End
end
function string.escape(s)
return s:gsub('[%-%.%+%[%]%(%)%$%^%%%?%*]','%%%1')
end
--- split a string into a list of strings separated by a delimiter.
-- @param s The input string
-- @param re A Lua string pattern; defaults to '%s+'
-- @param plain don't use Lua patterns
-- @param n optional maximum number of splits
-- @return a list-like table
-- @raise error if s is not a string
function string.split(s,re,plain,n)
local find,sub,append = string.find, string.sub, table.insert
local i1,ls = 1,{}
if not re then re = '%s+' end
if re == '' then return {s} end
while true do
local i2,i3 = find(s,re,i1,plain)
if not i2 then
local last = sub(s,i1)
if last ~= '' then append(ls,last) end
if #ls == 1 and ls[1] == '' then
return {}
else
return ls
end
end
append(ls,sub(s,i1,i2-1))
if n and #ls == n then
ls[#ls] = sub(s,i1)
return ls
end
i1 = i3+1
end
end
function table.count(T)
local count = 0
for _ in pairs(T) do count = count + 1 end
return count
end
function table.bsearch(list, value, mkval)
local low = 1
local high = #list
while low <= high do
local mid = math.floor((low+high)/2)
local this = mkval and mkval(list[mid]) or list[mid]
if this > value then
high = mid - 1
elseif this < value then
low = mid + 1
else
return mid
end
end
return nil
end
function table.join(a, b)
assert(a)
if b == nil or #b == 0 then
return a
end
local res = {}
for _, v in ipairs(a) do
table.insert(res, v)
end
for _, v in ipairs(b) do
table.insert(res, v)
end
return res
end
function table.build(iterator_fn, build_fn)
build_fn = (build_fn or function(arg) return arg end)
local res = {}
while true do
local vars = {iterator_fn()}
if vars[1] == nil then break end
table.insert(res, build_fn(vars))
end
return res
end
function table.values(T)
local V = {}
for k, v in pairs(T) do
table.insert(V, v)
end
return V
end
function table.tuples(T)
local i = 0
local n = table.getn(t)
return function ()
i = i + 1
if i <= n then return t[i][1], t[i][2] end
end
end
getmetatable("").__mod = function(a, b)
if not b then
return a
elseif type(b) == "table" then
return string.format(a, unpack(b))
else
return string.format(a, b)
end
end
function os.exists(path)
local f=io.open(path,"r")
if f~=nil then
io.close(f)
return true
else
return false
end
end
function os.spawn(...)
local cmd = string.format(...)
local proc = assert(io.popen(cmd))
local out = proc:read("*a")
proc:close()
return out
end
local function logline(...)
if not log.enabled then
return
end
local c_green = "\27[32m"
local c_grey = "\27[1;30m"
local c_clear = "\27[0m"
local msg = string.format(...)
local info = debug.getinfo(2, "Sln")
local line = string.format("%s[%s:%s]%s %s", c_grey,
info.short_src:match("^.+/(.+)$"), info.currentline, c_clear, info.name)
io.stderr:write(
string.format("%s[%s]%s %s: %s\n", c_green,
os.date("%H:%M:%S"), c_clear, line, msg))
end
log = { info = logline, enabled = true }
--[[ json.lua
A compact pure-Lua JSON library.
This code is in the public domain:
https://gist.github.com/tylerneylon/59f4bcf316be525b30ab
The main functions are: json.stringify, json.parse.
## json.stringify:
This expects the following to be true of any tables being encoded:
* They only have string or number keys. Number keys must be represented as
strings in json; this is part of the json spec.
* They are not recursive. Such a structure cannot be specified in json.
A Lua table is considered to be an array if and only if its set of keys is a
consecutive sequence of positive integers starting at 1. Arrays are encoded like
so: `[2, 3, false, "hi"]`. Any other type of Lua table is encoded as a json
object, encoded like so: `{"key1": 2, "key2": false}`.
Because the Lua nil value cannot be a key, and as a table value is considerd
equivalent to a missing key, there is no way to express the json "null" value in
a Lua table. The only way this will output "null" is if your entire input obj is
nil itself.
An empty Lua table, {}, could be considered either a json object or array -
it's an ambiguous edge case. We choose to treat this as an object as it is the
more general type.
To be clear, none of the above considerations is a limitation of this code.
Rather, it is what we get when we completely observe the json specification for
as arbitrary a Lua object as json is capable of expressing.
## json.parse:
This function parses json, with the exception that it does not pay attention to
\u-escaped unicode code points in strings.
It is difficult for Lua to return null as a value. In order to prevent the loss
of keys with a null value in a json string, this function uses the one-off
table value json.null (which is just an empty table) to indicate null values.
This way you can check if a value is null with the conditional
`val == json.null`.
If you have control over the data and are using Lua, I would recommend just
avoiding null values in your data to begin with.
--]]
local json = {}
-- Internal functions.
local function kind_of(obj)
if type(obj) ~= 'table' then return type(obj) end
local i = 1
for _ in pairs(obj) do
if obj[i] ~= nil then i = i + 1 else return 'table' end
end
if i == 1 then return 'table' else return 'array' end
end
local function escape_str(s)
local in_char = {'\\', '"', '/', '\b', '\f', '\n', '\r', '\t'}
local out_char = {'\\', '"', '/', 'b', 'f', 'n', 'r', 't'}
for i, c in ipairs(in_char) do
s = s:gsub(c, '\\' .. out_char[i])
end
return s
end
-- Returns pos, did_find; there are two cases:
-- 1. Delimiter found: pos = pos after leading space + delim; did_find = true.
-- 2. Delimiter not found: pos = pos after leading space; did_find = false.
-- This throws an error if err_if_missing is true and the delim is not found.
local function skip_delim(str, pos, delim, err_if_missing)
pos = pos + #str:match('^%s*', pos)
if str:sub(pos, pos) ~= delim then
if err_if_missing then
error('Expected ' .. delim .. ' near position ' .. pos)
end
return pos, false
end
return pos + 1, true
end
-- Expects the given pos to be the first character after the opening quote.
-- Returns val, pos; the returned pos is after the closing quote character.
local function parse_str_val(str, pos, val)
val = val or ''
local early_end_error = 'End of input found while parsing string.'
if pos > #str then error(early_end_error) end
local c = str:sub(pos, pos)
if c == '"' then return val, pos + 1 end
if c ~= '\\' then return parse_str_val(str, pos + 1, val .. c) end
-- We must have a \ character.
local esc_map = {b = '\b', f = '\f', n = '\n', r = '\r', t = '\t'}
local nextc = str:sub(pos + 1, pos + 1)
if not nextc then error(early_end_error) end
return parse_str_val(str, pos + 2, val .. (esc_map[nextc] or nextc))
end
-- Returns val, pos; the returned pos is after the number's final character.
local function parse_num_val(str, pos)
local num_str = str:match('^-?%d+%.?%d*[eE]?[+-]?%d*', pos)
local val = tonumber(num_str)
if not val then error('Error parsing number at position ' .. pos .. '.') end
return val, pos + #num_str
end
-- Public values and functions.
function json.stringify(obj, as_key)
local s = {} -- We'll build the string as an array of strings to be concatenated.
local kind = kind_of(obj) -- This is 'array' if it's an array or type(obj) otherwise.
if kind == 'array' then
if as_key then error('Can\'t encode array as key.') end
s[#s + 1] = '['
for i, val in ipairs(obj) do
if i > 1 then s[#s + 1] = ', ' end
s[#s + 1] = json.stringify(val)
end
s[#s + 1] = ']'
elseif kind == 'table' then
if as_key then error('Can\'t encode table as key.') end
s[#s + 1] = '{'
for k, v in pairs(obj) do
if #s > 1 then s[#s + 1] = ', ' end
s[#s + 1] = json.stringify(k, true)
s[#s + 1] = ':'
s[#s + 1] = json.stringify(v)
end
s[#s + 1] = '}'
elseif kind == 'string' then
return '"' .. escape_str(obj) .. '"'
elseif kind == 'number' then
if as_key then return '"' .. tostring(obj) .. '"' end
return tostring(obj)
elseif kind == 'boolean' then
return tostring(obj)
elseif kind == 'nil' then
return 'null'
else
error('Unjsonifiable type: ' .. kind .. '.')
end
return table.concat(s)
end
json.null = {} -- This is a one-off table to represent the null value.
function json.parse(str, pos, end_delim)
pos = pos or 1
if pos > #str then error('Reached unexpected end of input.') end
local pos = pos + #str:match('^%s*', pos) -- Skip whitespace.
local first = str:sub(pos, pos)
if first == '{' then -- Parse an object.
local obj, key, delim_found = {}, true, true
pos = pos + 1
while true do
key, pos = json.parse(str, pos, '}')
if key == nil then return obj, pos end
if not delim_found then error('Comma missing between object items.') end
pos = skip_delim(str, pos, ':', true) -- true -> error if missing.
obj[key], pos = json.parse(str, pos)
pos, delim_found = skip_delim(str, pos, ',')
end
elseif first == '[' then -- Parse an array.
local arr, val, delim_found = {}, true, true
pos = pos + 1
while true do
val, pos = json.parse(str, pos, ']')
if val == nil then return arr, pos end
if not delim_found then error('Comma missing between array items.') end
arr[#arr + 1] = val
pos, delim_found = skip_delim(str, pos, ',')
end
elseif first == '"' then -- Parse a string.
return parse_str_val(str, pos + 1)
elseif first == '-' or first:match('%d') then -- Parse a number.
return parse_num_val(str, pos)
elseif first == end_delim then -- End of an object or array.
return nil, pos + 1
else -- Parse true, false, or null.
local literals = {['true'] = true, ['false'] = false, ['null'] = json.null}
for lit_str, lit_val in pairs(literals) do
local lit_end = pos + #lit_str - 1
if str:sub(pos, lit_end) == lit_str then return lit_val, lit_end + 1 end
end
local pos_info_str = 'position ' .. pos .. ': ' .. str:sub(pos, pos + 10)
error('Invalid json syntax starting at ' .. pos_info_str)
end
end
return json
local middleclass = {
_VERSION = 'middleclass v4.0.0',
_DESCRIPTION = 'Object Orientation for Lua',
_URL = 'https://github.com/kikito/middleclass',
_LICENSE = [[
MIT LICENSE
Copyright (c) 2011 Enrique García Cota
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
]]
}
local function _createIndexWrapper(aClass, f)
if f == nil then
return aClass.__instanceDict
else
return function(self, name)
local value = aClass.__instanceDict[name]
if value ~= nil then
return value
elseif type(f) == "function" then
return (f(self, name))
else
return f[name]
end
end
end
end
local function _propagateInstanceMethod(aClass, name, f)
f = name == "__index" and _createIndexWrapper(aClass, f) or f
aClass.__instanceDict[name] = f
for subclass in pairs(aClass.subclasses) do
if rawget(subclass.__declaredMethods, name) == nil then
_propagateInstanceMethod(subclass, name, f)
end
end
end
local function _declareInstanceMethod(aClass, name, f)
aClass.__declaredMethods[name] = f
if f == nil and aClass.super then
f = aClass.super.__instanceDict[name]
end
_propagateInstanceMethod(aClass, name, f)
end
local function _tostring(self) return "class " .. self.name end
local function _call(self, ...) return self:new(...) end
local function _createClass(name, super)
local dict = {}
dict.__index = dict
local aClass = { name = name, super = super, static = {},
__instanceDict = dict, __declaredMethods = {},
subclasses = setmetatable({}, {__mode='k'}) }
if super then
setmetatable(aClass.static, { __index = function(_,k) return rawget(dict,k) or super.static[k] end })
else
setmetatable(aClass.static, { __index = function(_,k) return rawget(dict,k) end })
end
setmetatable(aClass, { __index = aClass.static, __tostring = _tostring,
__call = _call, __newindex = _declareInstanceMethod })
return aClass
end
local function _includeMixin(aClass, mixin)
assert(type(mixin) == 'table', "mixin must be a table")
for name,method in pairs(mixin) do
if name ~= "included" and name ~= "static" then aClass[name] = method end
end
for name,method in pairs(mixin.static or {}) do
aClass.static[name] = method
end
if type(mixin.included)=="function" then mixin:included(aClass) end
return aClass
end
local DefaultMixin = {
__tostring = function(self) return "instance of " .. tostring(self.class) end,
initialize = function(self, ...) end,
isInstanceOf = function(self, aClass)
return type(self) == 'table' and
type(self.class) == 'table' and
type(aClass) == 'table' and
( aClass == self.class or
type(aClass.isSubclassOf) == 'function' and
self.class:isSubclassOf(aClass) )
end,
static = {
allocate = function(self)
assert(type(self) == 'table', "Make sure that you are using 'Class:allocate' instead of 'Class.allocate'")
return setmetatable({ class = self }, self.__instanceDict)
end,
new = function(self, ...)
assert(type(self) == 'table', "Make sure that you are using 'Class:new' instead of 'Class.new'")
local instance = self:allocate()
instance:initialize(...)
return instance
end,
subclass = function(self, name)
assert(type(self) == 'table', "Make sure that you are using 'Class:subclass' instead of 'Class.subclass'")
assert(type(name) == "string", "You must provide a name(string) for your class")
local subclass = _createClass(name, self)
for methodName, f in pairs(self.__instanceDict) do
_propagateInstanceMethod(subclass, methodName, f)
end
subclass.initialize = function(instance, ...) return self.initialize(instance, ...) end
self.subclasses[subclass] = true
self:subclassed(subclass)
return subclass
end,
subclassed = function(self, other) end,
isSubclassOf = function(self, other)
return type(other) == 'table' and
type(self) == 'table' and
type(self.super) == 'table' and
( self.super == other or
type(self.super.isSubclassOf) == 'function' and
self.super:isSubclassOf(other) )
end,
include = function(self, ...)
assert(type(self) == 'table', "Make sure you that you are using 'Class:include' instead of 'Class.include'")
for _,mixin in ipairs({...}) do _includeMixin(self, mixin) end
return self
end
}
}
function middleclass.class(name, super)
assert(type(name) == 'string', "A name (string) is needed for the new class")
return super and super:subclass(name) or _includeMixin(_createClass(name), DefaultMixin)
end
setmetatable(middleclass, { __call = function(_, ...) return middleclass.class(...) end })
return middleclass
--[[
Copyright 2016 GitHub, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]
local ffi = require("ffi")
ffi.cdef[[
typedef int clockid_t;
typedef long time_t;
struct timespec {
time_t tv_sec;
long tv_nsec;
};
int clock_gettime(clockid_t clk_id, struct timespec *tp);
int clock_nanosleep(clockid_t clock_id, int flags,
const struct timespec *request, struct timespec *remain);
int get_nprocs(void);
]]
local CLOCK = {
REALTIME = 0,
MONOTONIC = 1,
PROCESS_CPUTIME_ID = 2,
THREAD_CPUTIME_ID = 3,
MONOTONIC_RAW = 4,
REALTIME_COARSE = 5,
MONOTONIC_COARSE = 6,
}
local function time_ns(clock)
local ts = ffi.new("struct timespec[1]")
assert(ffi.C.clock_gettime(clock or CLOCK.MONOTONIC_RAW, ts) == 0,
"clock_gettime() failed: "..ffi.errno())
return tonumber(ts[0].tv_sec * 1e9 + ts[0].tv_nsec)
end
local function sleep(seconds, clock)
local s, ns = math.modf(seconds)
local ts = ffi.new("struct timespec[1]")
ts[0].tv_sec = s
ts[0].tv_nsec = ns / 1e9
ffi.C.clock_nanosleep(clock or CLOCK.MONOTONIC, 0, ts, nil)
end
local function cpu_count()
return tonumber(ffi.C.get_nprocs())
end
return {
time_ns=time_ns,
sleep=sleep,
CLOCK=CLOCK,
cpu_count=cpu_count,
}
--[[
Copyright (C) 2009 Steve Donovan, David Manura.
This code is licensed under the MIT License
https://github.com/stevedonovan/Penlight
--]]
--- Checks uses of undeclared global variables.
-- All global variables must be 'declared' through a regular assignment
-- (even assigning `nil` will do) in a main chunk before being used
-- anywhere or assigned to inside a function. Existing metatables `__newindex` and `__index`
-- metamethods are respected.
--
-- You can set any table to have strict behaviour using `strict.module`. Creating a new
-- module with `strict.closed_module` makes the module immune to monkey-patching, if
-- you don't wish to encourage monkey business.
--
-- If the global `PENLIGHT_NO_GLOBAL_STRICT` is defined, then this module won't make the
-- global environment strict - if you just want to explicitly set table strictness.
--
-- @module pl.strict
require 'debug' -- for Lua 5.2
local getinfo, error, rawset, rawget = debug.getinfo, error, rawset, rawget
local strict = {}
local function what ()
local d = getinfo(3, "S")
return d and d.what or "C"
end
--- make an existing table strict.
-- @string name name of table (optional)
-- @tab[opt] mod table - if `nil` then we'll return a new table
-- @tab[opt] predeclared - table of variables that are to be considered predeclared.
-- @return the given table, or a new table
function strict.module (name,mod,predeclared)
local mt, old_newindex, old_index, old_index_type, global, closed
if predeclared then
global = predeclared.__global
closed = predeclared.__closed
end
if type(mod) == 'table' then
mt = getmetatable(mod)
if mt and rawget(mt,'__declared') then return end -- already patched...
else
mod = {}
end
if mt == nil then
mt = {}
setmetatable(mod, mt)
else
old_newindex = mt.__newindex
old_index = mt.__index
old_index_type = type(old_index)
end
mt.__declared = predeclared or {}
mt.__newindex = function(t, n, v)
if old_newindex then
old_newindex(t, n, v)
if rawget(t,n)~=nil then return end
end
if not mt.__declared[n] then
if global then
local w = what()
if w ~= "main" and w ~= "C" then
error("assign to undeclared global '"..n.."'", 2)
end
end
mt.__declared[n] = true
end
rawset(t, n, v)
end
mt.__index = function(t,n)
if not mt.__declared[n] and what() ~= "C" then
if old_index then
if old_index_type == "table" then
local fallback = old_index[n]
if fallback ~= nil then
return fallback
end
else
local res = old_index(t, n)
if res then return res end
end
end
local msg = "variable '"..n.."' is not declared"
if name then
msg = msg .. " in '"..name.."'"
end
error(msg, 2)
end
return rawget(t, n)
end
return mod
end
--- make all tables in a table strict.
-- So `strict.make_all_strict(_G)` prevents monkey-patching
-- of any global table
-- @tab T
function strict.make_all_strict (T)
for k,v in pairs(T) do
if type(v) == 'table' and v ~= T then
strict.module(k,v)
end
end
end
--- make a new module table which is closed to further changes.
function strict.closed_module (mod,name)
local M = {}
mod = mod or {}
local mt = getmetatable(mod)
if not mt then
mt = {}
setmetatable(mod,mt)
end
mt.__newindex = function(t,k,v)
M[k] = v
end
return strict.module(name,M)
end
if not rawget(_G,'PENLIGHT_NO_GLOBAL_STRICT') then
strict.module(nil,_G,{_PROMPT=true,__global=true})
end
return strict
......@@ -6,3 +6,4 @@ set(TEST_WRAPPER ${CMAKE_CURRENT_BINARY_DIR}/wrapper.sh)
add_subdirectory(cc)
add_subdirectory(python)
add_subdirectory(lua)
find_program(LUAJIT luajit)
if(LUAJIT)
add_test(NAME lua_test_clang WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMAND ${TEST_WRAPPER} lua_test_clang sudo ${LUAJIT} test_clang.lua)
add_test(NAME lua_test_uprobes WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMAND ${TEST_WRAPPER} lua_test_uprobes sudo ${LUAJIT} test_uprobes.lua)
add_test(NAME lua_test_dump WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMAND ${TEST_WRAPPER} lua_test_dump sudo ${LUAJIT} test_dump.lua)
endif()
--[[
luaunit.lua
Description: A unit testing framework
Homepage: https://github.com/bluebird75/luaunit
Development by Philippe Fremy <phil@freehackers.org>
Based on initial work of Ryu, Gwang (http://www.gpgstudy.com/gpgiki/LuaUnit)
License: BSD License, see LICENSE.txt
Version: 3.2
]]--
require("math")
local M={}
-- private exported functions (for testing)
M.private = {}
M.VERSION='3.2'
--[[ Some people like assertEquals( actual, expected ) and some people prefer
assertEquals( expected, actual ).
]]--
M.ORDER_ACTUAL_EXPECTED = true
M.PRINT_TABLE_REF_IN_ERROR_MSG = false
M.TABLE_EQUALS_KEYBYCONTENT = true
M.LINE_LENGTH=80
-- set this to false to debug luaunit
local STRIP_LUAUNIT_FROM_STACKTRACE=true
M.VERBOSITY_DEFAULT = 10
M.VERBOSITY_LOW = 1
M.VERBOSITY_QUIET = 0
M.VERBOSITY_VERBOSE = 20
-- set EXPORT_ASSERT_TO_GLOBALS to have all asserts visible as global values
-- EXPORT_ASSERT_TO_GLOBALS = true
-- we need to keep a copy of the script args before it is overriden
local cmdline_argv = rawget(_G, "arg")
M.FAILURE_PREFIX = 'LuaUnit test FAILURE: ' -- prefix string for failed tests
M.USAGE=[[Usage: lua <your_test_suite.lua> [options] [testname1 [testname2] ... ]
Options:
-h, --help: Print this help
--version: Print version information
-v, --verbose: Increase verbosity
-q, --quiet: Set verbosity to minimum
-e, --error: Stop on first error
-f, --failure: Stop on first failure or error
-o, --output OUTPUT: Set output type to OUTPUT
Possible values: text, tap, junit, nil
-n, --name NAME: For junit only, mandatory name of xml file
-p, --pattern PATTERN: Execute all test names matching the Lua PATTERN
May be repeated to include severals patterns
Make sure you escape magic chars like +? with %
testname1, testname2, ... : tests to run in the form of testFunction,
TestClass or TestClass.testMethod
]]
----------------------------------------------------------------
--
-- general utility functions
--
----------------------------------------------------------------
local crossTypeOrdering = {
number = 1,
boolean = 2,
string = 3,
table = 4,
other = 5
}
local crossTypeComparison = {
number = function(a, b) return a < b end,
string = function(a, b) return a < b end,
other = function(a, b) return tostring(a) < tostring(b) end,
}
local function crossTypeSort(a, b)
local type_a, type_b = type(a), type(b)
if type_a == type_b then
local func = crossTypeComparison[type_a] or crossTypeComparison.other
return func(a, b)
end
type_a = crossTypeOrdering[type_a] or crossTypeOrdering.other
type_b = crossTypeOrdering[type_b] or crossTypeOrdering.other
return type_a < type_b
end
local function __genSortedIndex( t )
-- Returns a sequence consisting of t's keys, sorted.
local sortedIndex = {}
for key,_ in pairs(t) do
table.insert(sortedIndex, key)
end
table.sort(sortedIndex, crossTypeSort)
return sortedIndex
end
M.private.__genSortedIndex = __genSortedIndex
local function sortedNext(state, control)
-- Equivalent of the next() function of table iteration, but returns the
-- keys in sorted order (see __genSortedIndex and crossTypeSort).
-- The state is a temporary variable during iteration and contains the
-- sorted key table (state.sortedIdx). It also stores the last index (into
-- the keys) used by the iteration, to find the next one quickly.
local key
--print("sortedNext: control = "..tostring(control) )
if control == nil then
-- start of iteration
state.lastIdx = 1
key = state.sortedIdx[1]
return key, state.t[key]
end
-- normally, we expect the control variable to match the last key used
if control ~= state.sortedIdx[state.lastIdx] then
-- strange, we have to find the next value by ourselves
-- the key table is sorted in crossTypeSort() order! -> use bisection
local count = #state.sortedIdx
local lower, upper = 1, count
repeat
state.lastIdx = math.modf((lower + upper) / 2)
key = state.sortedIdx[state.lastIdx]
if key == control then break; end -- key found (and thus prev index)
if crossTypeSort(key, control) then
-- key < control, continue search "right" (towards upper bound)
lower = state.lastIdx + 1
else
-- key > control, continue search "left" (towards lower bound)
upper = state.lastIdx - 1
end
until lower > upper
if lower > upper then -- only true if the key wasn't found, ...
state.lastIdx = count -- ... so ensure no match for the code below
end
end
-- proceed by retrieving the next value (or nil) from the sorted keys
state.lastIdx = state.lastIdx + 1
key = state.sortedIdx[state.lastIdx]
if key then
return key, state.t[key]
end
-- getting here means returning `nil`, which will end the iteration
end
local function sortedPairs(tbl)
-- Equivalent of the pairs() function on tables. Allows to iterate in
-- sorted order. As required by "generic for" loops, this will return the
-- iterator (function), an "invariant state", and the initial control value.
-- (see http://www.lua.org/pil/7.2.html)
return sortedNext, {t = tbl, sortedIdx = __genSortedIndex(tbl)}, nil
end
M.private.sortedPairs = sortedPairs
local function strsplit(delimiter, text)
-- Split text into a list consisting of the strings in text,
-- separated by strings matching delimiter (which may be a pattern).
-- example: strsplit(",%s*", "Anna, Bob, Charlie,Dolores")
if string.find("", delimiter, 1, true) then -- this would result in endless loops
error("delimiter matches empty string!")
end
local list, pos, first, last = {}, 1
while true do
first, last = text:find(delimiter, pos, true)
if first then -- found?
table.insert(list, text:sub(pos, first - 1))
pos = last + 1
else
table.insert(list, text:sub(pos))
break
end
end
return list
end
M.private.strsplit = strsplit
local function hasNewLine( s )
-- return true if s has a newline
return (string.find(s, '\n', 1, true) ~= nil)
end
M.private.hasNewLine = hasNewLine
local function prefixString( prefix, s )
-- Prefix all the lines of s with prefix
return prefix .. table.concat(strsplit('\n', s), '\n' .. prefix)
end
M.private.prefixString = prefixString
local function strMatch(s, pattern, start, final )
-- return true if s matches completely the pattern from index start to index end
-- return false in every other cases
-- if start is nil, matches from the beginning of the string
-- if final is nil, matches to the end of the string
start = start or 1
final = final or string.len(s)
local foundStart, foundEnd = string.find(s, pattern, start, false)
return foundStart == start and foundEnd == final
end
M.private.strMatch = strMatch
local function xmlEscape( s )
-- Return s escaped for XML attributes
-- escapes table:
-- " &quot;
-- ' &apos;
-- < &lt;
-- > &gt;
-- & &amp;
return string.gsub( s, '.', {
['&'] = "&amp;",
['"'] = "&quot;",
["'"] = "&apos;",
['<'] = "&lt;",
['>'] = "&gt;",
} )
end
M.private.xmlEscape = xmlEscape
local function xmlCDataEscape( s )
-- Return s escaped for CData section, escapes: "]]>"
return string.gsub( s, ']]>', ']]&gt;' )
end
M.private.xmlCDataEscape = xmlCDataEscape
local function stripLuaunitTrace( stackTrace )
--[[
-- Example of a traceback:
<<stack traceback:
example_with_luaunit.lua:130: in function 'test2_withFailure'
./luaunit.lua:1449: in function <./luaunit.lua:1449>
[C]: in function 'xpcall'
./luaunit.lua:1449: in function 'protectedCall'
./luaunit.lua:1508: in function 'execOneFunction'
./luaunit.lua:1596: in function 'runSuiteByInstances'
./luaunit.lua:1660: in function 'runSuiteByNames'
./luaunit.lua:1736: in function 'runSuite'
example_with_luaunit.lua:140: in main chunk
[C]: in ?>>
Other example:
<<stack traceback:
./luaunit.lua:545: in function 'assertEquals'
example_with_luaunit.lua:58: in function 'TestToto.test7'
./luaunit.lua:1517: in function <./luaunit.lua:1517>
[C]: in function 'xpcall'
./luaunit.lua:1517: in function 'protectedCall'
./luaunit.lua:1578: in function 'execOneFunction'
./luaunit.lua:1677: in function 'runSuiteByInstances'
./luaunit.lua:1730: in function 'runSuiteByNames'
./luaunit.lua:1806: in function 'runSuite'
example_with_luaunit.lua:140: in main chunk
[C]: in ?>>
<<stack traceback:
luaunit2/example_with_luaunit.lua:124: in function 'test1_withFailure'
luaunit2/luaunit.lua:1532: in function <luaunit2/luaunit.lua:1532>
[C]: in function 'xpcall'
luaunit2/luaunit.lua:1532: in function 'protectedCall'
luaunit2/luaunit.lua:1591: in function 'execOneFunction'
luaunit2/luaunit.lua:1679: in function 'runSuiteByInstances'
luaunit2/luaunit.lua:1743: in function 'runSuiteByNames'
luaunit2/luaunit.lua:1819: in function 'runSuite'
luaunit2/example_with_luaunit.lua:140: in main chunk
[C]: in ?>>
-- first line is "stack traceback": KEEP
-- next line may be luaunit line: REMOVE
-- next lines are call in the program under testOk: REMOVE
-- next lines are calls from luaunit to call the program under test: KEEP
-- Strategy:
-- keep first line
-- remove lines that are part of luaunit
-- kepp lines until we hit a luaunit line
]]
local function isLuaunitInternalLine( s )
-- return true if line of stack trace comes from inside luaunit
return s:find('[/\\]luaunit%.lua:%d+: ') ~= nil
end
-- print( '<<'..stackTrace..'>>' )
local t = strsplit( '\n', stackTrace )
-- print( prettystr(t) )
local idx = 2
-- remove lines that are still part of luaunit
while t[idx] and isLuaunitInternalLine( t[idx] ) do
-- print('Removing : '..t[idx] )
table.remove(t, idx)
end
-- keep lines until we hit luaunit again
while t[idx] and (not isLuaunitInternalLine(t[idx])) do
-- print('Keeping : '..t[idx] )
idx = idx + 1
end
-- remove remaining luaunit lines
while t[idx] do
-- print('Removing : '..t[idx] )
table.remove(t, idx)
end
-- print( prettystr(t) )
return table.concat( t, '\n')
end
M.private.stripLuaunitTrace = stripLuaunitTrace
local function prettystr_sub(v, indentLevel, keeponeline, printTableRefs, recursionTable )
local type_v = type(v)
if "string" == type_v then
if keeponeline then v = v:gsub("\n", "\\n") end
-- use clever delimiters according to content:
-- enclose with single quotes if string contains ", but no '
if v:find('"', 1, true) and not v:find("'", 1, true) then
return "'" .. v .. "'"
end
-- use double quotes otherwise, escape embedded "
return '"' .. v:gsub('"', '\\"') .. '"'
elseif "table" == type_v then
--if v.__class__ then
-- return string.gsub( tostring(v), 'table', v.__class__ )
--end
return M.private._table_tostring(v, indentLevel, printTableRefs, recursionTable)
end
return tostring(v)
end
local function prettystr( v, keeponeline )
--[[ Better string conversion, to display nice variable content:
For strings, if keeponeline is set to true, string is displayed on one line, with visible \n
* string are enclosed with " by default, or with ' if string contains a "
* if table is a class, display class name
* tables are expanded
]]--
local recursionTable = {}
local s = prettystr_sub(v, 1, keeponeline, M.PRINT_TABLE_REF_IN_ERROR_MSG, recursionTable)
if recursionTable.recursionDetected and not M.PRINT_TABLE_REF_IN_ERROR_MSG then
-- some table contain recursive references,
-- so we must recompute the value by including all table references
-- else the result looks like crap
recursionTable = {}
s = prettystr_sub(v, 1, keeponeline, true, recursionTable)
end
return s
end
M.prettystr = prettystr
local function prettystrPadded(value1, value2, suffix_a, suffix_b)
--[[
This function helps with the recurring task of constructing the "expected
vs. actual" error messages. It takes two arbitrary values and formats
corresponding strings with prettystr().
To keep the (possibly complex) output more readable in case the resulting
strings contain line breaks, they get automatically prefixed with additional
newlines. Both suffixes are optional (default to empty strings), and get
appended to the "value1" string. "suffix_a" is used if line breaks were
encountered, "suffix_b" otherwise.
Returns the two formatted strings (including padding/newlines).
]]
local str1, str2 = prettystr(value1), prettystr(value2)
if hasNewLine(str1) or hasNewLine(str2) then
-- line break(s) detected, add padding
return "\n" .. str1 .. (suffix_a or ""), "\n" .. str2
end
return str1 .. (suffix_b or ""), str2
end
M.private.prettystrPadded = prettystrPadded
local function _table_keytostring(k)
-- like prettystr but do not enclose with "" if the string is just alphanumerical
-- this is better for displaying table keys who are often simple strings
if "string" == type(k) and k:match("^[_%a][_%w]*$") then
return k
end
return prettystr(k)
end
M.private._table_keytostring = _table_keytostring
local TABLE_TOSTRING_SEP = ", "
local TABLE_TOSTRING_SEP_LEN = string.len(TABLE_TOSTRING_SEP)
local function _table_tostring( tbl, indentLevel, printTableRefs, recursionTable )
printTableRefs = printTableRefs or M.PRINT_TABLE_REF_IN_ERROR_MSG
recursionTable = recursionTable or {}
recursionTable[tbl] = true
local result, dispOnMultLines = {}, false
local entry, count, seq_index = nil, 0, 1
for k, v in sortedPairs( tbl ) do
if k == seq_index then
-- for the sequential part of tables, we'll skip the "<key>=" output
entry = ''
seq_index = seq_index + 1
else
entry = _table_keytostring( k ) .. "="
end
if recursionTable[v] then -- recursion detected!
recursionTable.recursionDetected = true
entry = entry .. "<"..tostring(v)..">"
else
entry = entry ..
prettystr_sub( v, indentLevel+1, true, printTableRefs, recursionTable )
end
count = count + 1
result[count] = entry
end
-- set dispOnMultLines if the maximum LINE_LENGTH would be exceeded
local totalLength = 0
for k, v in ipairs( result ) do
totalLength = totalLength + string.len( v )
if totalLength >= M.LINE_LENGTH then
dispOnMultLines = true
break
end
end
if not dispOnMultLines then
-- adjust with length of separator(s):
-- two items need 1 sep, three items two seps, ... plus len of '{}'
if count > 0 then
totalLength = totalLength + TABLE_TOSTRING_SEP_LEN * (count - 1)
end
dispOnMultLines = totalLength + 2 >= M.LINE_LENGTH
end
-- now reformat the result table (currently holding element strings)
if dispOnMultLines then
local indentString = string.rep(" ", indentLevel - 1)
result = {"{\n ", indentString,
table.concat(result, ",\n " .. indentString), "\n",
indentString, "}"}
else
result = {"{", table.concat(result, TABLE_TOSTRING_SEP), "}"}
end
if printTableRefs then
table.insert(result, 1, "<"..tostring(tbl).."> ") -- prepend table ref
end
return table.concat(result)
end
M.private._table_tostring = _table_tostring -- prettystr_sub() needs it
local function _table_contains(t, element)
if t then
for _, value in pairs(t) do
if type(value) == type(element) then
if type(element) == 'table' then
-- if we wanted recursive items content comparison, we could use
-- _is_table_items_equals(v, expected) but one level of just comparing
-- items is sufficient
if M.private._is_table_equals( value, element ) then
return true
end
else
if value == element then
return true
end
end
end
end
end
return false
end
local function _is_table_items_equals(actual, expected )
if (type(actual) == 'table') and (type(expected) == 'table') then
for k,v in pairs(actual) do
if not _table_contains(expected, v) then
return false
end
end
for k,v in pairs(expected) do
if not _table_contains(actual, v) then
return false
end
end
return true
elseif type(actual) ~= type(expected) then
return false
elseif actual == expected then
return true
end
return false
end
local function _is_table_equals(actual, expected)
if (type(actual) == 'table') and (type(expected) == 'table') then
if (#actual ~= #expected) then
return false
end
local actualTableKeys = {}
for k,v in pairs(actual) do
if M.TABLE_EQUALS_KEYBYCONTENT and type(k) == "table" then
-- If the keys are tables, things get a bit tricky here as we
-- can have _is_table_equals(k1, k2) and t[k1] ~= t[k2]. So we
-- collect actual's table keys, group them by length for
-- performance, and then for each table key in expected we look
-- it up in actualTableKeys.
if not actualTableKeys[#k] then actualTableKeys[#k] = {} end
table.insert(actualTableKeys[#k], k)
else
if not _is_table_equals(v, expected[k]) then
return false
end
end
end
for k,v in pairs(expected) do
if M.TABLE_EQUALS_KEYBYCONTENT and type(k) == "table" then
local candidates = actualTableKeys[#k]
if not candidates then return false end
local found
for i, candidate in pairs(candidates) do
if _is_table_equals(candidate, k) then
found = candidate
-- Remove the candidate we matched against from the list
-- of candidates, so each key in actual can only match
-- one key in expected.
candidates[i] = nil
break
end
end
if not(found and _is_table_equals(actual[found], v)) then return false end
else
if not _is_table_equals(v, actual[k]) then
return false
end
end
end
if M.TABLE_EQUALS_KEYBYCONTENT then
for _, keys in pairs(actualTableKeys) do
-- if there are any keys left in any actualTableKeys[i] then
-- that is a key in actual with no matching key in expected,
-- and so the tables aren't equal.
if next(keys) then return false end
end
end
return true
elseif type(actual) ~= type(expected) then
return false
elseif actual == expected then
return true
end
return false
end
M.private._is_table_equals = _is_table_equals
local function failure(msg, level)
-- raise an error indicating a test failure
-- for error() compatibility we adjust "level" here (by +1), to report the
-- calling context
error(M.FAILURE_PREFIX .. msg, (level or 1) + 1)
end
local function fail_fmt(level, ...)
-- failure with printf-style formatted message and given error level
failure(string.format(...), (level or 1) + 1)
end
M.private.fail_fmt = fail_fmt
local function error_fmt(level, ...)
-- printf-style error()
error(string.format(...), (level or 1) + 1)
end
----------------------------------------------------------------
--
-- assertions
--
----------------------------------------------------------------
local function errorMsgEquality(actual, expected)
if not M.ORDER_ACTUAL_EXPECTED then
expected, actual = actual, expected
end
if type(expected) == 'string' or type(expected) == 'table' then
expected, actual = prettystrPadded(expected, actual)
return string.format("expected: %s\nactual: %s", expected, actual)
end
return string.format("expected: %s, actual: %s",
prettystr(expected), prettystr(actual))
end
function M.assertError(f, ...)
-- assert that calling f with the arguments will raise an error
-- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
if pcall( f, ... ) then
failure( "Expected an error when calling function but no error generated", 2 )
end
end
function M.assertTrue(value)
if not value then
failure("expected: true, actual: " ..prettystr(value), 2)
end
end
function M.assertFalse(value)
if value then
failure("expected: false, actual: " ..prettystr(value), 2)
end
end
function M.assertIsNil(value)
if value ~= nil then
failure("expected: nil, actual: " ..prettystr(value), 2)
end
end
function M.assertNotIsNil(value)
if value == nil then
failure("expected non nil value, received nil", 2)
end
end
function M.assertEquals(actual, expected)
if type(actual) == 'table' and type(expected) == 'table' then
if not _is_table_equals(actual, expected) then
failure( errorMsgEquality(actual, expected), 2 )
end
elseif type(actual) ~= type(expected) then
failure( errorMsgEquality(actual, expected), 2 )
elseif actual ~= expected then
failure( errorMsgEquality(actual, expected), 2 )
end
end
-- Help Lua in corner cases like almostEquals(1.1, 1.0, 0.1), which by default
-- may not work. We need to give margin a small boost; EPSILON defines the
-- default value to use for this:
local EPSILON = 0.00000000001
function M.almostEquals( actual, expected, margin, margin_boost )
if type(actual) ~= 'number' or type(expected) ~= 'number' or type(margin) ~= 'number' then
error_fmt(3, 'almostEquals: must supply only number arguments.\nArguments supplied: %s, %s, %s',
prettystr(actual), prettystr(expected), prettystr(margin))
end
if margin <= 0 then
error('almostEquals: margin must be positive, current value is ' .. margin, 3)
end
local realmargin = margin + (margin_boost or EPSILON)
return math.abs(expected - actual) <= realmargin
end
function M.assertAlmostEquals( actual, expected, margin )
-- check that two floats are close by margin
if not M.almostEquals(actual, expected, margin) then
if not M.ORDER_ACTUAL_EXPECTED then
expected, actual = actual, expected
end
fail_fmt(2, 'Values are not almost equal\nExpected: %s with margin of %s, received: %s',
expected, margin, actual)
end
end
function M.assertNotEquals(actual, expected)
if type(actual) ~= type(expected) then
return
end
if type(actual) == 'table' and type(expected) == 'table' then
if not _is_table_equals(actual, expected) then
return
end
elseif actual ~= expected then
return
end
fail_fmt(2, 'Received the not expected value: %s', prettystr(actual))
end
function M.assertNotAlmostEquals( actual, expected, margin )
-- check that two floats are not close by margin
if M.almostEquals(actual, expected, margin) then
if not M.ORDER_ACTUAL_EXPECTED then
expected, actual = actual, expected
end
fail_fmt(2, 'Values are almost equal\nExpected: %s with a difference above margin of %s, received: %s',
expected, margin, actual)
end
end
function M.assertStrContains( str, sub, useRe )
-- this relies on lua string.find function
-- a string always contains the empty string
if not string.find(str, sub, 1, not useRe) then
sub, str = prettystrPadded(sub, str, '\n')
fail_fmt(2, 'Error, %s %s was not found in string %s',
useRe and 'regexp' or 'substring', sub, str)
end
end
function M.assertStrIContains( str, sub )
-- this relies on lua string.find function
-- a string always contains the empty string
if not string.find(str:lower(), sub:lower(), 1, true) then
sub, str = prettystrPadded(sub, str, '\n')
fail_fmt(2, 'Error, substring %s was not found (case insensitively) in string %s',
sub, str)
end
end
function M.assertNotStrContains( str, sub, useRe )
-- this relies on lua string.find function
-- a string always contains the empty string
if string.find(str, sub, 1, not useRe) then
sub, str = prettystrPadded(sub, str, '\n')
fail_fmt(2, 'Error, %s %s was found in string %s',
useRe and 'regexp' or 'substring', sub, str)
end
end
function M.assertNotStrIContains( str, sub )
-- this relies on lua string.find function
-- a string always contains the empty string
if string.find(str:lower(), sub:lower(), 1, true) then
sub, str = prettystrPadded(sub, str, '\n')
fail_fmt(2, 'Error, substring %s was found (case insensitively) in string %s',
sub, str)
end
end
function M.assertStrMatches( str, pattern, start, final )
-- Verify a full match for the string
-- for a partial match, simply use assertStrContains with useRe set to true
if not strMatch( str, pattern, start, final ) then
pattern, str = prettystrPadded(pattern, str, '\n')
fail_fmt(2, 'Error, pattern %s was not matched by string %s',
pattern, str)
end
end
function M.assertErrorMsgEquals( expectedMsg, func, ... )
-- assert that calling f with the arguments will raise an error
-- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
local no_error, error_msg = pcall( func, ... )
if no_error then
failure( 'No error generated when calling function but expected error: "'..expectedMsg..'"', 2 )
end
if error_msg ~= expectedMsg then
error_msg, expectedMsg = prettystrPadded(error_msg, expectedMsg)
fail_fmt(2, 'Exact error message expected: %s\nError message received: %s\n',
expectedMsg, error_msg)
end
end
function M.assertErrorMsgContains( partialMsg, func, ... )
-- assert that calling f with the arguments will raise an error
-- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
local no_error, error_msg = pcall( func, ... )
if no_error then
failure( 'No error generated when calling function but expected error containing: '..prettystr(partialMsg), 2 )
end
if not string.find( error_msg, partialMsg, nil, true ) then
error_msg, partialMsg = prettystrPadded(error_msg, partialMsg)
fail_fmt(2, 'Error message does not contain: %s\nError message received: %s\n',
partialMsg, error_msg)
end
end
function M.assertErrorMsgMatches( expectedMsg, func, ... )
-- assert that calling f with the arguments will raise an error
-- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
local no_error, error_msg = pcall( func, ... )
if no_error then
failure( 'No error generated when calling function but expected error matching: "'..expectedMsg..'"', 2 )
end
if not strMatch( error_msg, expectedMsg ) then
expectedMsg, error_msg = prettystrPadded(expectedMsg, error_msg)
fail_fmt(2, 'Error message does not match: %s\nError message received: %s\n',
expectedMsg, error_msg)
end
end
--[[
Add type assertion functions to the module table M. Each of these functions
takes a single parameter "value", and checks that its Lua type matches the
expected string (derived from the function name):
M.assertIsXxx(value) -> ensure that type(value) conforms to "xxx"
]]
for _, funcName in ipairs(
{'assertIsNumber', 'assertIsString', 'assertIsTable', 'assertIsBoolean',
'assertIsFunction', 'assertIsUserdata', 'assertIsThread'}
) do
local typeExpected = funcName:match("^assertIs([A-Z]%a*)$")
-- Lua type() always returns lowercase, also make sure the match() succeeded
typeExpected = typeExpected and typeExpected:lower()
or error("bad function name '"..funcName.."' for type assertion")
M[funcName] = function(value)
if type(value) ~= typeExpected then
fail_fmt(2, 'Expected: a %s value, actual: type %s, value %s',
typeExpected, type(value), prettystrPadded(value))
end
end
end
--[[
Add non-type assertion functions to the module table M. Each of these functions
takes a single parameter "value", and checks that its Lua type differs from the
expected string (derived from the function name):
M.assertNotIsXxx(value) -> ensure that type(value) is not "xxx"
]]
for _, funcName in ipairs(
{'assertNotIsNumber', 'assertNotIsString', 'assertNotIsTable', 'assertNotIsBoolean',
'assertNotIsFunction', 'assertNotIsUserdata', 'assertNotIsThread'}
) do
local typeUnexpected = funcName:match("^assertNotIs([A-Z]%a*)$")
-- Lua type() always returns lowercase, also make sure the match() succeeded
typeUnexpected = typeUnexpected and typeUnexpected:lower()
or error("bad function name '"..funcName.."' for type assertion")
M[funcName] = function(value)
if type(value) == typeUnexpected then
fail_fmt(2, 'Not expected: a %s type, actual: value %s',
typeUnexpected, prettystrPadded(value))
end
end
end
function M.assertIs(actual, expected)
if actual ~= expected then
if not M.ORDER_ACTUAL_EXPECTED then
actual, expected = expected, actual
end
expected, actual = prettystrPadded(expected, actual, '\n', ', ')
fail_fmt(2, 'Expected object and actual object are not the same\nExpected: %sactual: %s',
expected, actual)
end
end
function M.assertNotIs(actual, expected)
if actual == expected then
if not M.ORDER_ACTUAL_EXPECTED then
expected = actual
end
fail_fmt(2, 'Expected object and actual object are the same object: %s',
prettystrPadded(expected))
end
end
function M.assertItemsEquals(actual, expected)
-- checks that the items of table expected
-- are contained in table actual. Warning, this function
-- is at least O(n^2)
if not _is_table_items_equals(actual, expected ) then
expected, actual = prettystrPadded(expected, actual)
fail_fmt(2, 'Contents of the tables are not identical:\nExpected: %s\nActual: %s',
expected, actual)
end
end
----------------------------------------------------------------
-- Compatibility layer
----------------------------------------------------------------
-- for compatibility with LuaUnit v2.x
function M.wrapFunctions(...)
io.stderr:write( [[Use of WrapFunction() is no longer needed.
Just prefix your test function names with "test" or "Test" and they
will be picked up and run by LuaUnit.]] )
-- In LuaUnit version <= 2.1 , this function was necessary to include
-- a test function inside the global test suite. Nowadays, the functions
-- are simply run directly as part of the test discovery process.
-- so just do nothing !
--[[
local testClass, testFunction
testClass = {}
local function storeAsMethod(idx, testName)
testFunction = _G[testName]
testClass[testName] = testFunction
end
for i,v in ipairs({...}) do
storeAsMethod( i, v )
end
return testClass
]]
end
local list_of_funcs = {
-- { official function name , alias }
-- general assertions
{ 'assertEquals' , 'assert_equals' },
{ 'assertItemsEquals' , 'assert_items_equals' },
{ 'assertNotEquals' , 'assert_not_equals' },
{ 'assertAlmostEquals' , 'assert_almost_equals' },
{ 'assertNotAlmostEquals' , 'assert_not_almost_equals' },
{ 'assertTrue' , 'assert_true' },
{ 'assertFalse' , 'assert_false' },
{ 'assertStrContains' , 'assert_str_contains' },
{ 'assertStrIContains' , 'assert_str_icontains' },
{ 'assertNotStrContains' , 'assert_not_str_contains' },
{ 'assertNotStrIContains' , 'assert_not_str_icontains' },
{ 'assertStrMatches' , 'assert_str_matches' },
{ 'assertError' , 'assert_error' },
{ 'assertErrorMsgEquals' , 'assert_error_msg_equals' },
{ 'assertErrorMsgContains' , 'assert_error_msg_contains' },
{ 'assertErrorMsgMatches' , 'assert_error_msg_matches' },
{ 'assertIs' , 'assert_is' },
{ 'assertNotIs' , 'assert_not_is' },
{ 'wrapFunctions' , 'WrapFunctions' },
{ 'wrapFunctions' , 'wrap_functions' },
-- type assertions: assertIsXXX -> assert_is_xxx
{ 'assertIsNumber' , 'assert_is_number' },
{ 'assertIsString' , 'assert_is_string' },
{ 'assertIsTable' , 'assert_is_table' },
{ 'assertIsBoolean' , 'assert_is_boolean' },
{ 'assertIsNil' , 'assert_is_nil' },
{ 'assertIsFunction' , 'assert_is_function' },
{ 'assertIsThread' , 'assert_is_thread' },
{ 'assertIsUserdata' , 'assert_is_userdata' },
-- type assertions: assertIsXXX -> assertXxx
{ 'assertIsNumber' , 'assertNumber' },
{ 'assertIsString' , 'assertString' },
{ 'assertIsTable' , 'assertTable' },
{ 'assertIsBoolean' , 'assertBoolean' },
{ 'assertIsNil' , 'assertNil' },
{ 'assertIsFunction' , 'assertFunction' },
{ 'assertIsThread' , 'assertThread' },
{ 'assertIsUserdata' , 'assertUserdata' },
-- type assertions: assertIsXXX -> assert_xxx (luaunit v2 compat)
{ 'assertIsNumber' , 'assert_number' },
{ 'assertIsString' , 'assert_string' },
{ 'assertIsTable' , 'assert_table' },
{ 'assertIsBoolean' , 'assert_boolean' },
{ 'assertIsNil' , 'assert_nil' },
{ 'assertIsFunction' , 'assert_function' },
{ 'assertIsThread' , 'assert_thread' },
{ 'assertIsUserdata' , 'assert_userdata' },
-- type assertions: assertNotIsXXX -> assert_not_is_xxx
{ 'assertNotIsNumber' , 'assert_not_is_number' },
{ 'assertNotIsString' , 'assert_not_is_string' },
{ 'assertNotIsTable' , 'assert_not_is_table' },
{ 'assertNotIsBoolean' , 'assert_not_is_boolean' },
{ 'assertNotIsNil' , 'assert_not_is_nil' },
{ 'assertNotIsFunction' , 'assert_not_is_function' },
{ 'assertNotIsThread' , 'assert_not_is_thread' },
{ 'assertNotIsUserdata' , 'assert_not_is_userdata' },
-- type assertions: assertNotIsXXX -> assertNotXxx (luaunit v2 compat)
{ 'assertNotIsNumber' , 'assertNotNumber' },
{ 'assertNotIsString' , 'assertNotString' },
{ 'assertNotIsTable' , 'assertNotTable' },
{ 'assertNotIsBoolean' , 'assertNotBoolean' },
{ 'assertNotIsNil' , 'assertNotNil' },
{ 'assertNotIsFunction' , 'assertNotFunction' },
{ 'assertNotIsThread' , 'assertNotThread' },
{ 'assertNotIsUserdata' , 'assertNotUserdata' },
-- type assertions: assertNotIsXXX -> assert_not_xxx
{ 'assertNotIsNumber' , 'assert_not_number' },
{ 'assertNotIsString' , 'assert_not_string' },
{ 'assertNotIsTable' , 'assert_not_table' },
{ 'assertNotIsBoolean' , 'assert_not_boolean' },
{ 'assertNotIsNil' , 'assert_not_nil' },
{ 'assertNotIsFunction' , 'assert_not_function' },
{ 'assertNotIsThread' , 'assert_not_thread' },
{ 'assertNotIsUserdata' , 'assert_not_userdata' },
-- all assertions with Coroutine duplicate Thread assertions
{ 'assertIsThread' , 'assertIsCoroutine' },
{ 'assertIsThread' , 'assertCoroutine' },
{ 'assertIsThread' , 'assert_is_coroutine' },
{ 'assertIsThread' , 'assert_coroutine' },
{ 'assertNotIsThread' , 'assertNotIsCoroutine' },
{ 'assertNotIsThread' , 'assertNotCoroutine' },
{ 'assertNotIsThread' , 'assert_not_is_coroutine' },
{ 'assertNotIsThread' , 'assert_not_coroutine' },
}
-- Create all aliases in M
for _,v in ipairs( list_of_funcs ) do
funcname, alias = v[1], v[2]
M[alias] = M[funcname]
if EXPORT_ASSERT_TO_GLOBALS then
_G[funcname] = M[funcname]
_G[alias] = M[funcname]
end
end
----------------------------------------------------------------
--
-- Outputters
--
----------------------------------------------------------------
----------------------------------------------------------------
-- class TapOutput
----------------------------------------------------------------
local TapOutput = { __class__ = 'TapOutput' } -- class
local TapOutput_MT = { __index = TapOutput } -- metatable
-- For a good reference for TAP format, check: http://testanything.org/tap-specification.html
function TapOutput:new()
return setmetatable( { verbosity = M.VERBOSITY_LOW }, TapOutput_MT)
end
function TapOutput:startSuite()
print("1.."..self.result.testCount)
print('# Started on '..self.result.startDate)
end
function TapOutput:startClass(className)
if className ~= '[TestFunctions]' then
print('# Starting class: '..className)
end
end
function TapOutput:startTest(testName) end
function TapOutput:addFailure( node )
io.stdout:write("not ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
if self.verbosity > M.VERBOSITY_LOW then
print( prefixString( ' ', node.msg ) )
end
if self.verbosity > M.VERBOSITY_DEFAULT then
print( prefixString( ' ', node.stackTrace ) )
end
end
TapOutput.addError = TapOutput.addFailure
function TapOutput:endTest( node )
if node:isPassed() then
io.stdout:write("ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
end
end
function TapOutput:endClass() end
function TapOutput:endSuite()
print( '# '..M.LuaUnit.statusLine( self.result ) )
return self.result.notPassedCount
end
-- class TapOutput end
----------------------------------------------------------------
-- class JUnitOutput
----------------------------------------------------------------
-- See directory junitxml for more information about the junit format
local JUnitOutput = { __class__ = 'JUnitOutput' } -- class
local JUnitOutput_MT = { __index = JUnitOutput } -- metatable
function JUnitOutput:new()
return setmetatable(
{ testList = {}, verbosity = M.VERBOSITY_LOW }, JUnitOutput_MT)
end
function JUnitOutput:startSuite()
-- open xml file early to deal with errors
if self.fname == nil then
error('With Junit, an output filename must be supplied with --name!')
end
if string.sub(self.fname,-4) ~= '.xml' then
self.fname = self.fname..'.xml'
end
self.fd = io.open(self.fname, "w")
if self.fd == nil then
error("Could not open file for writing: "..self.fname)
end
print('# XML output to '..self.fname)
print('# Started on '..self.result.startDate)
end
function JUnitOutput:startClass(className)
if className ~= '[TestFunctions]' then
print('# Starting class: '..className)
end
end
function JUnitOutput:startTest(testName)
print('# Starting test: '..testName)
end
function JUnitOutput:addFailure( node )
print('# Failure: ' .. node.msg)
-- print('# ' .. node.stackTrace)
end
function JUnitOutput:addError( node )
print('# Error: ' .. node.msg)
-- print('# ' .. node.stackTrace)
end
function JUnitOutput:endTest( node )
end
function JUnitOutput:endClass()
end
function JUnitOutput:endSuite()
print( '# '..M.LuaUnit.statusLine(self.result))
-- XML file writing
self.fd:write('<?xml version="1.0" encoding="UTF-8" ?>\n')
self.fd:write('<testsuites>\n')
self.fd:write(string.format(
' <testsuite name="LuaUnit" id="00001" package="" hostname="localhost" tests="%d" timestamp="%s" time="%0.3f" errors="%d" failures="%d">\n',
self.result.runCount, self.result.startIsodate, self.result.duration, self.result.errorCount, self.result.failureCount ))
self.fd:write(" <properties>\n")
self.fd:write(string.format(' <property name="Lua Version" value="%s"/>\n', _VERSION ) )
self.fd:write(string.format(' <property name="LuaUnit Version" value="%s"/>\n', M.VERSION) )
-- XXX please include system name and version if possible
self.fd:write(" </properties>\n")
for i,node in ipairs(self.result.tests) do
self.fd:write(string.format(' <testcase classname="%s" name="%s" time="%0.3f">\n',
node.className, node.testName, node.duration ) )
if node:isNotPassed() then
self.fd:write(node:statusXML())
end
self.fd:write(' </testcase>\n')
end
-- Next two lines are needed to validate junit ANT xsd, but really not useful in general:
self.fd:write(' <system-out/>\n')
self.fd:write(' <system-err/>\n')
self.fd:write(' </testsuite>\n')
self.fd:write('</testsuites>\n')
self.fd:close()
return self.result.notPassedCount
end
-- class TapOutput end
----------------------------------------------------------------
-- class TextOutput
----------------------------------------------------------------
--[[
-- Python Non verbose:
For each test: . or F or E
If some failed tests:
==============
ERROR / FAILURE: TestName (testfile.testclass)
---------
Stack trace
then --------------
then "Ran x tests in 0.000s"
then OK or FAILED (failures=1, error=1)
-- Python Verbose:
testname (filename.classname) ... ok
testname (filename.classname) ... FAIL
testname (filename.classname) ... ERROR
then --------------
then "Ran x tests in 0.000s"
then OK or FAILED (failures=1, error=1)
-- Ruby:
Started
.
Finished in 0.002695 seconds.
1 tests, 2 assertions, 0 failures, 0 errors
-- Ruby:
>> ruby tc_simple_number2.rb
Loaded suite tc_simple_number2
Started
F..
Finished in 0.038617 seconds.
1) Failure:
test_failure(TestSimpleNumber) [tc_simple_number2.rb:16]:
Adding doesn't work.
<3> expected but was
<4>.
3 tests, 4 assertions, 1 failures, 0 errors
-- Java Junit
.......F.
Time: 0,003
There was 1 failure:
1) testCapacity(junit.samples.VectorTest)junit.framework.AssertionFailedError
at junit.samples.VectorTest.testCapacity(VectorTest.java:87)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
FAILURES!!!
Tests run: 8, Failures: 1, Errors: 0
-- Maven
# mvn test
-------------------------------------------------------
T E S T S
-------------------------------------------------------
Running math.AdditionTest
Tests run: 2, Failures: 1, Errors: 0, Skipped: 0, Time elapsed:
0.03 sec <<< FAILURE!
Results :
Failed tests:
testLireSymbole(math.AdditionTest)
Tests run: 2, Failures: 1, Errors: 0, Skipped: 0
-- LuaUnit
---- non verbose
* display . or F or E when running tests
---- verbose
* display test name + ok/fail
----
* blank line
* number) ERROR or FAILURE: TestName
Stack trace
* blank line
* number) ERROR or FAILURE: TestName
Stack trace
then --------------
then "Ran x tests in 0.000s (%d not selected, %d skipped)"
then OK or FAILED (failures=1, error=1)
]]
local TextOutput = { __class__ = 'TextOutput' } -- class
local TextOutput_MT = { __index = TextOutput } -- metatable
function TextOutput:new()
return setmetatable(
{ errorList = {}, verbosity = M.VERBOSITY_DEFAULT }, TextOutput_MT )
end
function TextOutput:startSuite()
if self.verbosity > M.VERBOSITY_DEFAULT then
print( 'Started on '.. self.result.startDate )
end
end
function TextOutput:startClass(className)
-- display nothing when starting a new class
end
function TextOutput:startTest(testName)
if self.verbosity > M.VERBOSITY_DEFAULT then
io.stdout:write( " ", self.result.currentNode.testName, " ... " )
end
end
function TextOutput:addFailure( node )
-- nothing
end
function TextOutput:addError( node )
-- nothing
end
function TextOutput:endTest( node )
if node:isPassed() then
if self.verbosity > M.VERBOSITY_DEFAULT then
io.stdout:write("Ok\n")
else
io.stdout:write(".")
end
else
if self.verbosity > M.VERBOSITY_DEFAULT then
print( node.status )
print( node.msg )
--[[
-- find out when to do this:
if self.verbosity > M.VERBOSITY_DEFAULT then
print( node.stackTrace )
end
]]
else
-- write only the first character of status
io.stdout:write(string.sub(node.status, 1, 1))
end
end
end
function TextOutput:endClass()
-- nothing
end
function TextOutput:displayOneFailedTest( index, failure )
print(index..") "..failure.testName )
print( failure.msg )
print( failure.stackTrace )
print()
end
function TextOutput:displayFailedTests()
if self.result.notPassedCount == 0 then return end
print("Failed tests:")
print("-------------")
for i,v in ipairs(self.result.notPassed) do
self:displayOneFailedTest( i, v )
end
end
function TextOutput:endSuite()
if self.verbosity > M.VERBOSITY_DEFAULT then
print("=========================================================")
else
print()
end
self:displayFailedTests()
print( M.LuaUnit.statusLine( self.result ) )
local ignoredString = ""
if self.result.notPassedCount == 0 then
print('OK')
end
end
-- class TextOutput end
----------------------------------------------------------------
-- class NilOutput
----------------------------------------------------------------
local function nopCallable()
--print(42)
return nopCallable
end
local NilOutput = { __class__ = 'NilOuptut' } -- class
local NilOutput_MT = { __index = nopCallable } -- metatable
function NilOutput:new()
return setmetatable( { __class__ = 'NilOutput' }, NilOutput_MT )
end
----------------------------------------------------------------
--
-- class LuaUnit
--
----------------------------------------------------------------
M.LuaUnit = {
outputType = TextOutput,
verbosity = M.VERBOSITY_DEFAULT,
__class__ = 'LuaUnit'
}
local LuaUnit_MT = { __index = M.LuaUnit }
if EXPORT_ASSERT_TO_GLOBALS then
LuaUnit = M.LuaUnit
end
function M.LuaUnit:new()
return setmetatable( {}, LuaUnit_MT )
end
-----------------[[ Utility methods ]]---------------------
function M.LuaUnit.asFunction(aObject)
-- return "aObject" if it is a function, and nil otherwise
if 'function' == type(aObject) then return aObject end
end
function M.LuaUnit.isClassMethod(aName)
-- return true if aName contains a class + a method name in the form class:method
return string.find(aName, '.', nil, true) ~= nil
end
function M.LuaUnit.splitClassMethod(someName)
-- return a pair className, methodName for a name in the form class:method
-- return nil if not a class + method name
-- name is class + method
local hasMethod, methodName, className
hasMethod = string.find(someName, '.', nil, true )
if not hasMethod then return nil end
methodName = string.sub(someName, hasMethod+1)
className = string.sub(someName,1,hasMethod-1)
return className, methodName
end
function M.LuaUnit.isMethodTestName( s )
-- return true is the name matches the name of a test method
-- default rule is that is starts with 'Test' or with 'test'
return string.sub(s, 1, 4):lower() == 'test'
end
function M.LuaUnit.isTestName( s )
-- return true is the name matches the name of a test
-- default rule is that is starts with 'Test' or with 'test'
return string.sub(s, 1, 4):lower() == 'test'
end
function M.LuaUnit.collectTests()
-- return a list of all test names in the global namespace
-- that match LuaUnit.isTestName
local testNames = {}
for k, v in pairs(_G) do
if M.LuaUnit.isTestName( k ) then
table.insert( testNames , k )
end
end
table.sort( testNames )
return testNames
end
function M.LuaUnit.parseCmdLine( cmdLine )
-- parse the command line
-- Supported command line parameters:
-- --verbose, -v: increase verbosity
-- --quiet, -q: silence output
-- --error, -e: treat errors as fatal (quit program)
-- --output, -o, + name: select output type
-- --pattern, -p, + pattern: run test matching pattern, may be repeated
-- --name, -n, + fname: name of output file for junit, default to stdout
-- [testnames, ...]: run selected test names
--
-- Returns a table with the following fields:
-- verbosity: nil, M.VERBOSITY_DEFAULT, M.VERBOSITY_QUIET, M.VERBOSITY_VERBOSE
-- output: nil, 'tap', 'junit', 'text', 'nil'
-- testNames: nil or a list of test names to run
-- pattern: nil or a list of patterns
local result = {}
local state = nil
local SET_OUTPUT = 1
local SET_PATTERN = 2
local SET_FNAME = 3
if cmdLine == nil then
return result
end
local function parseOption( option )
if option == '--help' or option == '-h' then
result['help'] = true
return
elseif option == '--version' then
result['version'] = true
return
elseif option == '--verbose' or option == '-v' then
result['verbosity'] = M.VERBOSITY_VERBOSE
return
elseif option == '--quiet' or option == '-q' then
result['verbosity'] = M.VERBOSITY_QUIET
return
elseif option == '--error' or option == '-e' then
result['quitOnError'] = true
return
elseif option == '--failure' or option == '-f' then
result['quitOnFailure'] = true
return
elseif option == '--output' or option == '-o' then
state = SET_OUTPUT
return state
elseif option == '--name' or option == '-n' then
state = SET_FNAME
return state
elseif option == '--pattern' or option == '-p' then
state = SET_PATTERN
return state
end
error('Unknown option: '..option,3)
end
local function setArg( cmdArg, state )
if state == SET_OUTPUT then
result['output'] = cmdArg
return
elseif state == SET_FNAME then
result['fname'] = cmdArg
return
elseif state == SET_PATTERN then
if result['pattern'] then
table.insert( result['pattern'], cmdArg )
else
result['pattern'] = { cmdArg }
end
return
end
error('Unknown parse state: '.. state)
end
for i, cmdArg in ipairs(cmdLine) do
if state ~= nil then
setArg( cmdArg, state, result )
state = nil
else
if cmdArg:sub(1,1) == '-' then
state = parseOption( cmdArg )
else
if result['testNames'] then
table.insert( result['testNames'], cmdArg )
else
result['testNames'] = { cmdArg }
end
end
end
end
if result['help'] then
M.LuaUnit.help()
end
if result['version'] then
M.LuaUnit.version()
end
if state ~= nil then
error('Missing argument after '..cmdLine[ #cmdLine ],2 )
end
return result
end
function M.LuaUnit.help()
print(M.USAGE)
os.exit(0)
end
function M.LuaUnit.version()
print('LuaUnit v'..M.VERSION..' by Philippe Fremy <phil@freehackers.org>')
os.exit(0)
end
function M.LuaUnit.patternInclude( patternFilter, expr )
-- check if any of patternFilter is contained in expr. If so, return true.
-- return false if None of the patterns are contained in expr
-- if patternFilter is nil, return true (no filtering)
if patternFilter == nil then
return true
end
for i,pattern in ipairs(patternFilter) do
if string.find(expr, pattern) then
return true
end
end
return false
end
----------------------------------------------------------------
-- class NodeStatus
----------------------------------------------------------------
local NodeStatus = { __class__ = 'NodeStatus' } -- class
local NodeStatus_MT = { __index = NodeStatus } -- metatable
M.NodeStatus = NodeStatus
-- values of status
NodeStatus.PASS = 'PASS'
NodeStatus.FAIL = 'FAIL'
NodeStatus.ERROR = 'ERROR'
function NodeStatus:new( number, testName, className )
local t = { number = number, testName = testName, className = className }
setmetatable( t, NodeStatus_MT )
t:pass()
return t
end
function NodeStatus:pass()
self.status = self.PASS
-- useless but we know it's the field we want to use
self.msg = nil
self.stackTrace = nil
end
function NodeStatus:fail(msg, stackTrace)
self.status = self.FAIL
self.msg = msg
self.stackTrace = stackTrace
end
function NodeStatus:error(msg, stackTrace)
self.status = self.ERROR
self.msg = msg
self.stackTrace = stackTrace
end
function NodeStatus:isPassed()
return self.status == NodeStatus.PASS
end
function NodeStatus:isNotPassed()
-- print('hasFailure: '..prettystr(self))
return self.status ~= NodeStatus.PASS
end
function NodeStatus:isFailure()
return self.status == NodeStatus.FAIL
end
function NodeStatus:isError()
return self.status == NodeStatus.ERROR
end
function NodeStatus:statusXML()
if self:isError() then
return table.concat(
{' <error type="', xmlEscape(self.msg), '">\n',
' <![CDATA[', xmlCDataEscape(self.stackTrace),
']]></error>\n'})
elseif self:isFailure() then
return table.concat(
{' <failure type="', xmlEscape(self.msg), '">\n',
' <![CDATA[', xmlCDataEscape(self.stackTrace),
']]></failure>\n'})
end
return ' <passed/>\n' -- (not XSD-compliant! normally shouldn't get here)
end
--------------[[ Output methods ]]-------------------------
function M.LuaUnit.statusLine(result)
-- return status line string according to results
local s = string.format('Ran %d tests in %0.3f seconds, %d successes',
result.runCount, result.duration, result.passedCount )
if result.notPassedCount > 0 then
if result.failureCount > 0 then
s = s..string.format(', %d failures', result.failureCount )
end
if result.errorCount > 0 then
s = s..string.format(', %d errors', result.errorCount )
end
else
s = s..', 0 failures'
end
if result.nonSelectedCount > 0 then
s = s..string.format(", %d non-selected", result.nonSelectedCount )
end
return s
end
function M.LuaUnit:startSuite(testCount, nonSelectedCount)
self.result = {}
self.result.testCount = testCount
self.result.nonSelectedCount = nonSelectedCount
self.result.passedCount = 0
self.result.runCount = 0
self.result.currentTestNumber = 0
self.result.currentClassName = ""
self.result.currentNode = nil
self.result.suiteStarted = true
self.result.startTime = os.clock()
self.result.startDate = os.date(os.getenv('LUAUNIT_DATEFMT'))
self.result.startIsodate = os.date('%Y-%m-%dT%H:%M:%S')
self.result.patternFilter = self.patternFilter
self.result.tests = {}
self.result.failures = {}
self.result.errors = {}
self.result.notPassed = {}
self.outputType = self.outputType or TextOutput
self.output = self.outputType:new()
self.output.runner = self
self.output.result = self.result
self.output.verbosity = self.verbosity
self.output.fname = self.fname
self.output:startSuite()
end
function M.LuaUnit:startClass( className )
self.result.currentClassName = className
self.output:startClass( className )
end
function M.LuaUnit:startTest( testName )
self.result.currentTestNumber = self.result.currentTestNumber + 1
self.result.runCount = self.result.runCount + 1
self.result.currentNode = NodeStatus:new(
self.result.currentTestNumber,
testName,
self.result.currentClassName
)
self.result.currentNode.startTime = os.clock()
table.insert( self.result.tests, self.result.currentNode )
self.output:startTest( testName )
end
function M.LuaUnit:addStatus( err )
-- "err" is expected to be a table / result from protectedCall()
if err.status == NodeStatus.PASS then return end
local node = self.result.currentNode
--[[ As a first approach, we will report only one error or one failure for one test.
However, we can have the case where the test is in failure, and the teardown is in error.
In such case, it's a good idea to report both a failure and an error in the test suite. This is
what Python unittest does for example. However, it mixes up counts so need to be handled carefully: for
example, there could be more (failures + errors) count that tests. What happens to the current node ?
We will do this more intelligent version later.
]]
-- if the node is already in failure/error, just don't report the new error (see above)
if node.status ~= NodeStatus.PASS then return end
table.insert( self.result.notPassed, node )
if err.status == NodeStatus.FAIL then
node:fail( err.msg, err.trace )
table.insert( self.result.failures, node )
self.output:addFailure( node )
elseif err.status == NodeStatus.ERROR then
node:error( err.msg, err.trace )
table.insert( self.result.errors, node )
self.output:addError( node )
end
end
function M.LuaUnit:endTest()
local node = self.result.currentNode
-- print( 'endTest() '..prettystr(node))
-- print( 'endTest() '..prettystr(node:isNotPassed()))
node.duration = os.clock() - node.startTime
node.startTime = nil
self.output:endTest( node )
if node:isPassed() then
self.result.passedCount = self.result.passedCount + 1
elseif node:isError() then
if self.quitOnError or self.quitOnFailure then
-- Runtime error - abort test execution as requested by
-- "--error" option. This is done by setting a special
-- flag that gets handled in runSuiteByInstances().
print("\nERROR during LuaUnit test execution:\n" .. node.msg)
self.result.aborted = true
end
elseif node:isFailure() then
if self.quitOnFailure then
-- Failure - abort test execution as requested by
-- "--failure" option. This is done by setting a special
-- flag that gets handled in runSuiteByInstances().
print("\nFailure during LuaUnit test execution:\n" .. node.msg)
self.result.aborted = true
end
end
self.result.currentNode = nil
end
function M.LuaUnit:endClass()
self.output:endClass()
end
function M.LuaUnit:endSuite()
if self.result.suiteStarted == false then
error('LuaUnit:endSuite() -- suite was already ended' )
end
self.result.duration = os.clock()-self.result.startTime
self.result.suiteStarted = false
-- Expose test counts for outputter's endSuite(). This could be managed
-- internally instead, but unit tests (and existing use cases) might
-- rely on these fields being present.
self.result.notPassedCount = #self.result.notPassed
self.result.failureCount = #self.result.failures
self.result.errorCount = #self.result.errors
self.output:endSuite()
end
function M.LuaUnit:setOutputType(outputType)
-- default to text
-- tap produces results according to TAP format
if outputType:upper() == "NIL" then
self.outputType = NilOutput
return
end
if outputType:upper() == "TAP" then
self.outputType = TapOutput
return
end
if outputType:upper() == "JUNIT" then
self.outputType = JUnitOutput
return
end
if outputType:upper() == "TEXT" then
self.outputType = TextOutput
return
end
error( 'No such format: '..outputType,2)
end
--------------[[ Runner ]]-----------------
function M.LuaUnit:protectedCall(classInstance, methodInstance, prettyFuncName)
-- if classInstance is nil, this is just a function call
-- else, it's method of a class being called.
local function err_handler(e)
-- transform error into a table, adding the traceback information
return {
status = NodeStatus.ERROR,
msg = e,
trace = string.sub(debug.traceback("", 3), 2)
}
end
local ok, err
if classInstance then
-- stupid Lua < 5.2 does not allow xpcall with arguments so let's use a workaround
ok, err = xpcall( function () methodInstance(classInstance) end, err_handler )
else
ok, err = xpcall( function () methodInstance() end, err_handler )
end
if ok then
return {status = NodeStatus.PASS}
end
-- determine if the error was a failed test:
-- We do this by stripping the failure prefix from the error message,
-- while keeping track of the gsub() count. A non-zero value -> failure
local failed
err.msg, failed = err.msg:gsub(M.FAILURE_PREFIX, "", 1)
if failed > 0 then
err.status = NodeStatus.FAIL
end
-- reformat / improve the stack trace
if prettyFuncName then -- we do have the real method name
err.trace = err.trace:gsub("in (%a+) 'methodInstance'", "in %1 '"..prettyFuncName.."'")
end
if STRIP_LUAUNIT_FROM_STACKTRACE then
err.trace = stripLuaunitTrace(err.trace)
end
return err -- return the error "object" (table)
end
function M.LuaUnit:execOneFunction(className, methodName, classInstance, methodInstance)
-- When executing a test function, className and classInstance must be nil
-- When executing a class method, all parameters must be set
if type(methodInstance) ~= 'function' then
error( tostring(methodName)..' must be a function, not '..type(methodInstance))
end
local prettyFuncName
if className == nil then
className = '[TestFunctions]'
prettyFuncName = methodName
else
prettyFuncName = className..'.'..methodName
end
if self.lastClassName ~= className then
if self.lastClassName ~= nil then
self:endClass()
end
self:startClass( className )
self.lastClassName = className
end
self:startTest(prettyFuncName)
-- run setUp first (if any)
if classInstance then
local func = self.asFunction( classInstance.setUp )
or self.asFunction( classInstance.Setup )
or self.asFunction( classInstance.setup )
or self.asFunction( classInstance.SetUp )
if func then
self:addStatus(self:protectedCall(classInstance, func, className..'.setUp'))
end
end
-- run testMethod()
if self.result.currentNode:isPassed() then
self:addStatus(self:protectedCall(classInstance, methodInstance, prettyFuncName))
end
-- lastly, run tearDown (if any)
if classInstance then
local func = self.asFunction( classInstance.tearDown )
or self.asFunction( classInstance.TearDown )
or self.asFunction( classInstance.teardown )
or self.asFunction( classInstance.Teardown )
if func then
self:addStatus(self:protectedCall(classInstance, func, className..'.tearDown'))
end
end
self:endTest()
end
function M.LuaUnit.expandOneClass( result, className, classInstance )
-- add all test methods of classInstance to result
for methodName, methodInstance in sortedPairs(classInstance) do
if M.LuaUnit.asFunction(methodInstance) and M.LuaUnit.isMethodTestName( methodName ) then
table.insert( result, { className..'.'..methodName, classInstance } )
end
end
end
function M.LuaUnit.expandClasses( listOfNameAndInst )
-- expand all classes (provided as {className, classInstance}) to a list of {className.methodName, classInstance}
-- functions and methods remain untouched
local result = {}
for i,v in ipairs( listOfNameAndInst ) do
local name, instance = v[1], v[2]
if M.LuaUnit.asFunction(instance) then
table.insert( result, { name, instance } )
else
if type(instance) ~= 'table' then
error( 'Instance must be a table or a function, not a '..type(instance)..', value '..prettystr(instance))
end
if M.LuaUnit.isClassMethod( name ) then
local className, methodName = M.LuaUnit.splitClassMethod( name )
local methodInstance = instance[methodName]
if methodInstance == nil then
error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
end
table.insert( result, { name, instance } )
else
M.LuaUnit.expandOneClass( result, name, instance )
end
end
end
return result
end
function M.LuaUnit.applyPatternFilter( patternFilter, listOfNameAndInst )
local included, excluded = {}, {}
for i, v in ipairs( listOfNameAndInst ) do
-- local name, instance = v[1], v[2]
if M.LuaUnit.patternInclude( patternFilter, v[1] ) then
table.insert( included, v )
else
table.insert( excluded, v )
end
end
return included, excluded
end
function M.LuaUnit:runSuiteByInstances( listOfNameAndInst )
-- Run an explicit list of tests. All test instances and names must be supplied.
-- each test must be one of:
-- * { function name, function instance }
-- * { class name, class instance }
-- * { class.method name, class instance }
local expandedList, filteredList, filteredOutList, className, methodName, methodInstance
expandedList = self.expandClasses( listOfNameAndInst )
filteredList, filteredOutList = self.applyPatternFilter( self.patternFilter, expandedList )
self:startSuite( #filteredList, #filteredOutList )
for i,v in ipairs( filteredList ) do
local name, instance = v[1], v[2]
if M.LuaUnit.asFunction(instance) then
self:execOneFunction( nil, name, nil, instance )
else
if type(instance) ~= 'table' then
error( 'Instance must be a table or a function, not a '..type(instance)..', value '..prettystr(instance))
else
assert( M.LuaUnit.isClassMethod( name ) )
className, methodName = M.LuaUnit.splitClassMethod( name )
methodInstance = instance[methodName]
if methodInstance == nil then
error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
end
self:execOneFunction( className, methodName, instance, methodInstance )
end
end
if self.result.aborted then break end -- "--error" or "--failure" option triggered
end
if self.lastClassName ~= nil then
self:endClass()
end
self:endSuite()
if self.result.aborted then
print("LuaUnit ABORTED (as requested by --error or --failure option)")
os.exit(-2)
end
end
function M.LuaUnit:runSuiteByNames( listOfName )
-- Run an explicit list of test names
local className, methodName, instanceName, instance, methodInstance
local listOfNameAndInst = {}
for i,name in ipairs( listOfName ) do
if M.LuaUnit.isClassMethod( name ) then
className, methodName = M.LuaUnit.splitClassMethod( name )
instanceName = className
instance = _G[instanceName]
if instance == nil then
error( "No such name in global space: "..instanceName )
end
if type(instance) ~= 'table' then
error( 'Instance of '..instanceName..' must be a table, not '..type(instance))
end
methodInstance = instance[methodName]
if methodInstance == nil then
error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
end
else
-- for functions and classes
instanceName = name
instance = _G[instanceName]
end
if instance == nil then
error( "No such name in global space: "..instanceName )
end
if (type(instance) ~= 'table' and type(instance) ~= 'function') then
error( 'Name must match a function or a table: '..instanceName )
end
table.insert( listOfNameAndInst, { name, instance } )
end
self:runSuiteByInstances( listOfNameAndInst )
end
function M.LuaUnit.run(...)
-- Run some specific test classes.
-- If no arguments are passed, run the class names specified on the
-- command line. If no class name is specified on the command line
-- run all classes whose name starts with 'Test'
--
-- If arguments are passed, they must be strings of the class names
-- that you want to run or generic command line arguments (-o, -p, -v, ...)
local runner = M.LuaUnit.new()
return runner:runSuite(...)
end
function M.LuaUnit:runSuite( ... )
local args = {...}
if type(args[1]) == 'table' and args[1].__class__ == 'LuaUnit' then
-- run was called with the syntax M.LuaUnit:runSuite()
-- we support both M.LuaUnit.run() and M.LuaUnit:run()
-- strip out the first argument
table.remove(args,1)
end
if #args == 0 then
args = cmdline_argv
end
local no_error, val = pcall( M.LuaUnit.parseCmdLine, args )
if not no_error then
print(val) -- error message
print()
print(M.USAGE)
os.exit(-1)
end
local options = val
-- We expect these option fields to be either `nil` or contain
-- valid values, so it's safe to always copy them directly.
self.verbosity = options.verbosity
self.quitOnError = options.quitOnError
self.quitOnFailure = options.quitOnFailure
self.fname = options.fname
self.patternFilter = options.pattern
if options.output and options.output:lower() == 'junit' and options.fname == nil then
print('With junit output, a filename must be supplied with -n or --name')
os.exit(-1)
end
if options.output then
no_error, val = pcall(self.setOutputType, self, options.output)
if not no_error then
print(val) -- error message
print()
print(M.USAGE)
os.exit(-1)
end
end
self:runSuiteByNames( options.testNames or M.LuaUnit.collectTests() )
return self.result.notPassedCount
end
-- class LuaUnit
-- For compatbility with LuaUnit v2
M.run = M.LuaUnit.run
M.Run = M.LuaUnit.run
function M:setVerbosity( verbosity )
M.LuaUnit.verbosity = verbosity
end
M.set_verbosity = M.setVerbosity
M.SetVerbosity = M.setVerbosity
return M
require("test_helper")
TestClang = {}
function TestClang:test_probe_read1()
local text = [[
#include <linux/sched.h>
#include <uapi/linux/ptrace.h>
int count_sched(struct pt_regs *ctx, struct task_struct *prev) {
pid_t p = prev->pid;
return (p != -1);
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("count_sched", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_probe_read2()
local text = [[
#include <linux/sched.h>
#include <uapi/linux/ptrace.h>
int count_foo(struct pt_regs *ctx, unsigned long a, unsigned long b) {
return (a != b);
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("count_foo", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_probe_read_keys()
local text = [[
#include <uapi/linux/ptrace.h>
#include <linux/blkdev.h>
BPF_HASH(start, struct request *);
int do_request(struct pt_regs *ctx, struct request *req) {
u64 ts = bpf_ktime_get_ns();
start.update(&req, &ts);
return 0;
}
int do_completion(struct pt_regs *ctx, struct request *req) {
u64 *tsp = start.lookup(&req);
if (tsp != 0) {
start.delete(&req);
}
return 0;
}
]]
local b = BPF:new{text=text, debug=0}
local fns = b:load_funcs('BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_sscanf()
local text = [[
BPF_TABLE("hash", int, struct { u64 a; u64 b; u32 c:18; u32 d:14; struct { u32 a; u32 b; } s; }, stats, 10);
int foo(void *ctx) {
return 0;
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("foo", 'BPF_PROG_TYPE_KPROBE')
local t = b:get_table("stats")
local s1 = t:key_sprintf(2)
assert_equals(s1, "0x2")
local s2 = t:leaf_sprintf({{2, 3, 4, 1, {5, 6}}})
local l = t:leaf_scanf(s2)
assert_equals(tonumber(l.a), 2)
assert_equals(tonumber(l.b), 3)
assert_equals(tonumber(l.c), 4)
assert_equals(tonumber(l.d), 1)
assert_equals(tonumber(l.s.a), 5)
assert_equals(tonumber(l.s.b), 6)
end
function TestClang:test_sscanf_array()
local text = [[ BPF_TABLE("hash", int, struct { u32 a[3]; u32 b; }, stats, 10); ]]
local b = BPF:new{text=text, debug=0}
local t = b:get_table("stats")
local s1 = t:key_sprintf(2)
assert_equals(s1, "0x2")
local s2 = t:leaf_sprintf({{{1, 2, 3}, 4}})
assert_equals(s2, "{ [ 0x1 0x2 0x3 ] 0x4 }")
local l = t:leaf_scanf(s2)
assert_equals(l.a[0], 1)
assert_equals(l.a[1], 2)
assert_equals(l.a[2], 3)
assert_equals(l.b, 4)
end
function TestClang:test_iosnoop()
local text = [[
#include <linux/blkdev.h>
#include <uapi/linux/ptrace.h>
struct key_t {
struct request *req;
};
BPF_TABLE("hash", struct key_t, u64, start, 1024);
int do_request(struct pt_regs *ctx, struct request *req) {
struct key_t key = {};
bpf_trace_printk("traced start %d\\n", req->__data_len);
return 0;
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("do_request", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_blk_start_request()
local text = [[
#include <linux/blkdev.h>
#include <uapi/linux/ptrace.h>
int do_request(struct pt_regs *ctx, int req) {
bpf_trace_printk("req ptr: 0x%x\n", req);
return 0;
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("do_request", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_bpf_hash()
local text = [[
BPF_HASH(table1);
BPF_HASH(table2, u32);
BPF_HASH(table3, u32, int);
]]
local b = BPF:new{text=text, debug=0}
end
function TestClang:test_consecutive_probe_read()
local text = [[
#include <linux/fs.h>
#include <linux/mount.h>
BPF_HASH(table1, struct super_block *);
int trace_entry(struct pt_regs *ctx, struct file *file) {
if (!file) return 0;
struct vfsmount *mnt = file->f_path.mnt;
if (mnt) {
struct super_block *k = mnt->mnt_sb;
u64 zero = 0;
table1.update(&k, &zero);
k = mnt->mnt_sb;
table1.update(&k, &zero);
}
return 0;
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("trace_entry", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_nested_probe_read()
local text = [[
#include <linux/fs.h>
int trace_entry(struct pt_regs *ctx, struct file *file) {
if (!file) return 0;
const char *name = file->f_path.dentry->d_name.name;
bpf_trace_printk("%s\\n", name);
return 0;
}
]]
local b = BPF:new{text=text, debug=0}
local fn = b:load_func("trace_entry", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_char_array_probe()
local b = BPF:new{text=[[#include <linux/blkdev.h>
int kprobe__blk_update_request(struct pt_regs *ctx, struct request *req) {
bpf_trace_printk("%s\\n", req->rq_disk->disk_name);
return 0;
}]]}
end
function TestClang:test_probe_read_helper()
local b = BPF:new{text=[[
#include <linux/fs.h>
static void print_file_name(struct file *file) {
if (!file) return;
const char *name = file->f_path.dentry->d_name.name;
bpf_trace_printk("%s\\n", name);
}
static void print_file_name2(int unused, struct file *file) {
print_file_name(file);
}
int trace_entry1(struct pt_regs *ctx, struct file *file) {
print_file_name(file);
return 0;
}
int trace_entry2(struct pt_regs *ctx, int unused, struct file *file) {
print_file_name2(unused, file);
return 0;
}
]]}
local fn1 = b:load_func("trace_entry1", 'BPF_PROG_TYPE_KPROBE')
local fn2 = b:load_func("trace_entry2", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_probe_struct_assign()
local b = BPF:new{text = [[
#include <uapi/linux/ptrace.h>
struct args_t {
const char *filename;
int flags;
int mode;
};
int kprobe__sys_open(struct pt_regs *ctx, const char *filename,
int flags, int mode) {
struct args_t args = {};
args.filename = filename;
args.flags = flags;
args.mode = mode;
bpf_trace_printk("%s\\n", args.filename);
return 0;
};
]]}
end
function TestClang:test_task_switch()
local b = BPF:new{text=[[
#include <uapi/linux/ptrace.h>
#include <linux/sched.h>
struct key_t {
u32 prev_pid;
u32 curr_pid;
};
BPF_TABLE("hash", struct key_t, u64, stats, 1024);
int kprobe__finish_task_switch(struct pt_regs *ctx, struct task_struct *prev) {
struct key_t key = {};
u64 zero = 0, *val;
key.curr_pid = bpf_get_current_pid_tgid();
key.prev_pid = prev->pid;
val = stats.lookup_or_init(&key, &zero);
(*val)++;
return 0;
}
]]}
end
function TestClang:test_probe_simple_assign()
local b = BPF:new{text=[[
#include <uapi/linux/ptrace.h>
#include <linux/gfp.h>
struct leaf { size_t size; };
BPF_HASH(simple_map, u32, struct leaf);
int kprobe____kmalloc(struct pt_regs *ctx, size_t size) {
u32 pid = bpf_get_current_pid_tgid();
struct leaf* leaf = simple_map.lookup(&pid);
if (leaf)
leaf->size += size;
return 0;
}]]}
end
function TestClang:test_unop_probe_read()
local text = [[
#include <linux/blkdev.h>
int trace_entry(struct pt_regs *ctx, struct request *req) {
if (!(req->bio->bi_rw & 1))
return 1;
if (((req->bio->bi_rw)))
return 1;
return 0;
}
]]
local b = BPF:new{text=text}
local fn = b:load_func("trace_entry", 'BPF_PROG_TYPE_KPROBE')
end
function TestClang:test_complex_leaf_types()
local text = [[
struct list;
struct list {
struct list *selfp;
struct list *another_selfp;
struct list *selfp_array[2];
};
struct empty {
};
union emptyu {
struct empty *em1;
struct empty em2;
struct empty em3;
struct empty em4;
};
BPF_TABLE("array", int, struct list, t1, 1);
BPF_TABLE("array", int, struct list *, t2, 1);
BPF_TABLE("array", int, union emptyu, t3, 1);
]]
local b = BPF:new{text=text}
local ffi = require("ffi")
-- TODO: ptrs?
assert_equals(ffi.sizeof(b:get_table("t3").c_leaf), 8)
end
function TestClang:test_cflags()
local text = [[
#ifndef MYFLAG
#error "MYFLAG not set as expected"
#endif
]]
local b = BPF:new{text=text, cflags={"-DMYFLAG"}}
end
function TestClang:test_exported_maps()
local b1 = BPF{text=[[BPF_TABLE_PUBLIC("hash", int, int, table1, 10);]]}
local b2 = BPF{text=[[BPF_TABLE("extern", int, int, table1, 10);]]}
end
function TestClang:test_syntax_error()
assert_error_msg_contains(
"failed to compile BPF module",
BPF.new,
BPF, {text=[[int failure(void *ctx) { if (); return 0; }]]})
end
os.exit(LuaUnit.run())
require("test_helper")
function test_dump_func()
local raw = "\xb7\x00\x00\x00\x01\x00\x00\x00\x95\x00\x00\x00\x00\x00\x00\x00"
local b = BPF:new{text=[[int entry(void) { return 1; }]]}
assert_equals(b:dump_func("entry"), raw)
end
os.exit(LuaUnit.run())
function setup_path()
local str = require("debug").getinfo(2, "S").source:sub(2)
local cwd = str:match("(.*/)")
local bpf_path = cwd.."/../../src/lua/?.lua;"
local test_path = cwd.."/?.lua;"
package.path = bpf_path..test_path..package.path
end
setup_path()
USE_EXPECTED_ACTUAL_IN_ASSERT_EQUALS = false
EXPORT_ASSERT_TO_GLOBALS = true
require("luaunit")
BCC = require("bcc.init")
BPF = BCC.BPF
log.enabled = false
require("test_helper")
local ffi = require("ffi")
ffi.cdef[[
int getpid(void);
void malloc_stats(void);
]]
TestUprobes = {}
function TestUprobes:test_simple_library()
local text = [[
#include <uapi/linux/ptrace.h>
BPF_TABLE("array", int, u64, stats, 1);
static void incr(int idx) {
u64 *ptr = stats.lookup(&idx);
if (ptr)
++(*ptr);
}
int count(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
if (pid == PID)
incr(0);
return 0;
}]]
local pid = tonumber(ffi.C.getpid())
local text = text:gsub("PID", tostring(pid))
local b = BPF:new{text=text}
b:attach_uprobe{name="c", sym="malloc_stats", fn_name="count"}
b:attach_uprobe{name="c", sym="malloc_stats", fn_name="count", retprobe=true}
assert_equals(BPF.num_open_uprobes(), 2)
ffi.C.malloc_stats()
local stats = b:get_table("stats")
assert_equals(tonumber(stats:get(0)), 2)
end
function TestUprobes:test_simple_binary()
local text = [[
#include <uapi/linux/ptrace.h>
BPF_TABLE("array", int, u64, stats, 1);
static void incr(int idx) {
u64 *ptr = stats.lookup(&idx);
if (ptr)
++(*ptr);
}
int count(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
incr(0);
return 0;
}]]
local b = BPF:new{text=text}
b:attach_uprobe{name="/usr/bin/python", sym="main", fn_name="count"}
b:attach_uprobe{name="/usr/bin/python", sym="main", fn_name="count", retprobe=true}
os.spawn("/usr/bin/python -V")
local stats = b:get_table("stats")
assert_true(tonumber(stats:get(0)) >= 2)
end
function TestUprobes:teardown()
BPF.cleanup_probes()
end
os.exit(LuaUnit.run())
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment