Commit 8ed57a23 authored by Brenden Blanco's avatar Brenden Blanco

Add support for static helper functions

This adds support for static helper functions that can be reused. It is
not necessary to include pt_regs in the helper functions, even though
external pointers may be dereferenced. Arguments in the helpers can also
be reordered.
Signed-off-by: default avatarBrenden Blanco <bblanco@plumgrid.com>
parent 9ada11d1
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <clang/AST/ASTContext.h> #include <clang/AST/ASTContext.h>
#include <clang/AST/RecordLayout.h> #include <clang/AST/RecordLayout.h>
#include <clang/Frontend/CompilerInstance.h> #include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/MultiplexConsumer.h>
#include <clang/Rewrite/Core/Rewriter.h> #include <clang/Rewrite/Core/Rewriter.h>
#include "b_frontend_action.h" #include "b_frontend_action.h"
...@@ -36,6 +37,7 @@ const char *calling_conv_regs_x86[] = { ...@@ -36,6 +37,7 @@ const char *calling_conv_regs_x86[] = {
const char **calling_conv_regs = calling_conv_regs_x86; const char **calling_conv_regs = calling_conv_regs_x86;
using std::map; using std::map;
using std::set;
using std::string; using std::string;
using std::to_string; using std::to_string;
using std::unique_ptr; using std::unique_ptr;
...@@ -90,27 +92,107 @@ bool BMapDeclVisitor::VisitBuiltinType(const BuiltinType *T) { ...@@ -90,27 +92,107 @@ bool BMapDeclVisitor::VisitBuiltinType(const BuiltinType *T) {
return true; return true;
} }
class BProbeChecker : public clang::RecursiveASTVisitor<BProbeChecker> { class ProbeChecker : public clang::RecursiveASTVisitor<ProbeChecker> {
public: public:
explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs)
: needs_probe_(false), ptregs_(ptregs) {
if (arg)
TraverseStmt(arg);
}
bool VisitDeclRefExpr(clang::DeclRefExpr *E) { bool VisitDeclRefExpr(clang::DeclRefExpr *E) {
if (E->getDecl()->hasAttr<UnavailableAttr>()) if (ptregs_.find(E->getDecl()) != ptregs_.end())
return false; needs_probe_ = true;
return true; return true;
} }
bool needs_probe() const { return needs_probe_; }
private:
bool needs_probe_;
const set<Decl *> &ptregs_;
}; };
// Visit a piece of the AST and mark it as needing probe reads // Visit a piece of the AST and mark it as needing probe reads
class BProbeSetter : public clang::RecursiveASTVisitor<BProbeSetter> { class ProbeSetter : public clang::RecursiveASTVisitor<ProbeSetter> {
public: public:
explicit BProbeSetter(ASTContext &C) : C(C) {} explicit ProbeSetter(set<Decl *> *ptregs) : ptregs_(ptregs) {}
bool VisitDeclRefExpr(clang::DeclRefExpr *E) { bool VisitDeclRefExpr(clang::DeclRefExpr *E) {
E->getDecl()->addAttr(UnavailableAttr::CreateImplicit(C, "ptregs")); ptregs_->insert(E->getDecl());
return true; return true;
} }
private: private:
ASTContext &C; set<Decl *> *ptregs_;
}; };
ProbeVisitor::ProbeVisitor(Rewriter &rewriter) : rewriter_(rewriter) {}
bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) {
if (Expr *E = Decl->getInit()) {
if (ProbeChecker(E, ptregs_).needs_probe())
set_ptreg(Decl);
}
return true;
}
bool ProbeVisitor::VisitCallExpr(CallExpr *Call) {
if (FunctionDecl *F = dyn_cast<FunctionDecl>(Call->getCalleeDecl())) {
if (F->hasBody()) {
unsigned i = 0;
for (auto arg : Call->arguments()) {
if (ProbeChecker(arg, ptregs_).needs_probe())
ptregs_.insert(F->getParamDecl(i));
++i;
}
if (fn_visited_.find(F) == fn_visited_.end()) {
fn_visited_.insert(F);
TraverseDecl(F);
}
}
}
return true;
}
bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp())
return true;
// copy probe attribute from RHS to LHS if present
if (ProbeChecker(E->getRHS(), ptregs_).needs_probe()) {
ProbeSetter setter(&ptregs_);
setter.TraverseStmt(E->getLHS());
}
return true;
}
bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
if (memb_visited_.find(E) != memb_visited_.end()) return true;
// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
if (!ProbeChecker(E, ptregs_).needs_probe())
return true;
Expr *base;
SourceLocation rhs_start, op;
bool found = false;
for (MemberExpr *M = E; M; M = dyn_cast<MemberExpr>(M->getBase())) {
memb_visited_.insert(M);
rhs_start = M->getLocEnd();
base = M->getBase();
op = M->getOperatorLoc();
if (M->isArrow()) {
found = true;
break;
}
}
if (!found)
return true;
string rhs = rewriter_.getRewrittenText(SourceRange(rhs_start, E->getLocEnd()));
string base_type = base->getType()->getPointeeType().getAsString();
string pre, post;
pre = "({ typeof(" + E->getType().getAsString() + ") _val; memset(&_val, 0, sizeof(_val));";
pre += " bpf_probe_read(&_val, sizeof(_val), (u64)";
post = " + offsetof(" + base_type + ", " + rhs + ")";
post += "); _val; })";
rewriter_.InsertText(E->getLocStart(), pre);
rewriter_.ReplaceText(SourceRange(op, E->getLocEnd()), post);
return true;
}
BTypeVisitor::BTypeVisitor(ASTContext &C, Rewriter &rewriter, vector<TableDesc> &tables) BTypeVisitor::BTypeVisitor(ASTContext &C, Rewriter &rewriter, vector<TableDesc> &tables)
: C(C), rewriter_(rewriter), out_(llvm::errs()), tables_(tables) { : C(C), rewriter_(rewriter), out_(llvm::errs()), tables_(tables) {
} }
...@@ -141,6 +223,11 @@ bool BTypeVisitor::VisitFunctionDecl(FunctionDecl *D) { ...@@ -141,6 +223,11 @@ bool BTypeVisitor::VisitFunctionDecl(FunctionDecl *D) {
// for each trace argument, convert the variable from ptregs to something on stack // for each trace argument, convert the variable from ptregs to something on stack
if (CompoundStmt *S = dyn_cast<CompoundStmt>(D->getBody())) if (CompoundStmt *S = dyn_cast<CompoundStmt>(D->getBody()))
rewriter_.ReplaceText(S->getLBracLoc(), 1, preamble); rewriter_.ReplaceText(S->getLBracLoc(), 1, preamble);
} else if (D->hasBody() &&
rewriter_.getSourceMgr().getFileID(D->getLocStart())
== rewriter_.getSourceMgr().getMainFileID()) {
// rewritable functions that are static should be always treated as helper
rewriter_.InsertText(D->getLocStart(), "__attribute__((always_inline))\n");
} }
return true; return true;
} }
...@@ -282,37 +369,6 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -282,37 +369,6 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
return true; return true;
} }
bool BTypeVisitor::VisitMemberExpr(MemberExpr *E) {
if (visited_.find(E) != visited_.end()) return true;
// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
BProbeChecker checker;
if (checker.TraverseStmt(E))
return true;
Expr *base;
SourceLocation rhs_start, op;
for (MemberExpr *M = E; M; M = dyn_cast<MemberExpr>(M->getBase())) {
visited_.insert(M);
rhs_start = M->getLocEnd();
base = M->getBase();
op = M->getOperatorLoc();
if (M->isArrow())
break;
}
string rhs = rewriter_.getRewrittenText(SourceRange(rhs_start, E->getLocEnd()));
string base_type = base->getType()->getPointeeType().getAsString();
string pre, post;
pre = "({ typeof(" + E->getType().getAsString() + ") _val; memset(&_val, 0, sizeof(_val));";
pre += " bpf_probe_read(&_val, sizeof(_val), (u64)";
post = " + offsetof(" + base_type + ", " + rhs + ")";
post += "); _val; })";
rewriter_.InsertText(E->getLocStart(), pre);
rewriter_.ReplaceText(SourceRange(op, E->getLocEnd()), post);
return true;
}
bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) { bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp()) if (!E->isAssignmentOp())
return true; return true;
...@@ -340,12 +396,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) { ...@@ -340,12 +396,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
} }
} }
} }
// copy probe attribute from RHS to LHS if present
BProbeChecker checker;
if (!checker.TraverseStmt(E->getRHS())) {
BProbeSetter setter(C);
setter.TraverseStmt(E->getLHS());
}
return true; return true;
} }
bool BTypeVisitor::VisitImplicitCastExpr(ImplicitCastExpr *E) { bool BTypeVisitor::VisitImplicitCastExpr(ImplicitCastExpr *E) {
...@@ -453,11 +503,6 @@ bool BTypeVisitor::VisitVarDecl(VarDecl *Decl) { ...@@ -453,11 +503,6 @@ bool BTypeVisitor::VisitVarDecl(VarDecl *Decl) {
} }
} }
} }
if (Expr *E = Decl->getInit()) {
BProbeChecker checker;
if (!checker.TraverseStmt(E))
Decl->addAttr(UnavailableAttr::CreateImplicit(C, "ptregs"));
}
return true; return true;
} }
...@@ -465,9 +510,27 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, Rewriter &rewriter, vector<TableDesc ...@@ -465,9 +510,27 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, Rewriter &rewriter, vector<TableDesc
: visitor_(C, rewriter, tables) { : visitor_(C, rewriter, tables) {
} }
bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef D) { bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef Group) {
for (auto it : D) for (auto D : Group)
visitor_.TraverseDecl(it); visitor_.TraverseDecl(D);
return true;
}
ProbeConsumer::ProbeConsumer(clang::ASTContext &C, Rewriter &rewriter)
: visitor_(rewriter) {}
bool ProbeConsumer::HandleTopLevelDecl(clang::DeclGroupRef Group) {
for (auto D : Group) {
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (F->isExternallyVisible() && F->hasBody()) {
for (auto arg : F->parameters()) {
if (arg != F->getParamDecl(0))
visitor_.set_ptreg(arg);
}
visitor_.TraverseDecl(D);
}
}
}
return true; return true;
} }
...@@ -476,7 +539,6 @@ BFrontendAction::BFrontendAction(llvm::raw_ostream &os, unsigned flags) ...@@ -476,7 +539,6 @@ BFrontendAction::BFrontendAction(llvm::raw_ostream &os, unsigned flags)
} }
void BFrontendAction::EndSourceFileAction() { void BFrontendAction::EndSourceFileAction() {
// uncomment to see rewritten source
if (flags_ & 0x4) if (flags_ & 0x4)
rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(llvm::errs()); rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(llvm::errs());
rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(os_); rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(os_);
...@@ -485,7 +547,10 @@ void BFrontendAction::EndSourceFileAction() { ...@@ -485,7 +547,10 @@ void BFrontendAction::EndSourceFileAction() {
unique_ptr<ASTConsumer> BFrontendAction::CreateASTConsumer(CompilerInstance &Compiler, llvm::StringRef InFile) { unique_ptr<ASTConsumer> BFrontendAction::CreateASTConsumer(CompilerInstance &Compiler, llvm::StringRef InFile) {
rewriter_->setSourceMgr(Compiler.getSourceManager(), Compiler.getLangOpts()); rewriter_->setSourceMgr(Compiler.getSourceManager(), Compiler.getLangOpts());
return unique_ptr<ASTConsumer>(new BTypeConsumer(Compiler.getASTContext(), *rewriter_, *tables_)); vector<unique_ptr<ASTConsumer>> consumers;
consumers.push_back(unique_ptr<ASTConsumer>(new ProbeConsumer(Compiler.getASTContext(), *rewriter_)));
consumers.push_back(unique_ptr<ASTConsumer>(new BTypeConsumer(Compiler.getASTContext(), *rewriter_, *tables_)));
return unique_ptr<ASTConsumer>(new MultiplexConsumer(move(consumers)));
} }
} }
...@@ -66,7 +66,6 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> { ...@@ -66,7 +66,6 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
bool VisitFunctionDecl(clang::FunctionDecl *D); bool VisitFunctionDecl(clang::FunctionDecl *D);
bool VisitCallExpr(clang::CallExpr *Call); bool VisitCallExpr(clang::CallExpr *Call);
bool VisitVarDecl(clang::VarDecl *Decl); bool VisitVarDecl(clang::VarDecl *Decl);
bool VisitMemberExpr(clang::MemberExpr *E);
bool VisitBinaryOperator(clang::BinaryOperator *E); bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitImplicitCastExpr(clang::ImplicitCastExpr *E); bool VisitImplicitCastExpr(clang::ImplicitCastExpr *E);
...@@ -79,16 +78,41 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> { ...@@ -79,16 +78,41 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
std::set<clang::Expr *> visited_; std::set<clang::Expr *> visited_;
}; };
// Do a depth-first search to rewrite all pointers that need to be probed
class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
public:
explicit ProbeVisitor(clang::Rewriter &rewriter);
bool VisitVarDecl(clang::VarDecl *Decl);
bool VisitCallExpr(clang::CallExpr *Call);
bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E);
void set_ptreg(clang::Decl *D) { ptregs_.insert(D); }
private:
clang::Rewriter &rewriter_;
std::set<clang::Decl *> fn_visited_;
std::set<clang::Expr *> memb_visited_;
std::set<clang::Decl *> ptregs_;
};
// A helper class to the frontend action, walks the decls // A helper class to the frontend action, walks the decls
class BTypeConsumer : public clang::ASTConsumer { class BTypeConsumer : public clang::ASTConsumer {
public: public:
explicit BTypeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter, explicit BTypeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter,
std::vector<TableDesc> &tables); std::vector<TableDesc> &tables);
bool HandleTopLevelDecl(clang::DeclGroupRef D) override; bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
private: private:
BTypeVisitor visitor_; BTypeVisitor visitor_;
}; };
// A helper class to the frontend action, walks the decls
class ProbeConsumer : public clang::ASTConsumer {
public:
ProbeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter);
bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
private:
ProbeVisitor visitor_;
};
// Create a B program in 2 phases (everything else is normal C frontend): // Create a B program in 2 phases (everything else is normal C frontend):
// 1. Catch the map declarations and open the fd's // 1. Catch the map declarations and open the fd's
// 2. Capture the IR // 2. Capture the IR
......
...@@ -104,7 +104,6 @@ int pem(struct __sk_buff *skb) { ...@@ -104,7 +104,6 @@ int pem(struct __sk_buff *skb) {
return 1; return 1;
} }
static int br_common(struct __sk_buff *skb, int which_br) __attribute__((always_inline));
static int br_common(struct __sk_buff *skb, int which_br) { static int br_common(struct __sk_buff *skb, int which_br) {
u8 *cursor = 0; u8 *cursor = 0;
u16 proto; u16 proto;
......
...@@ -172,5 +172,40 @@ int kprobe__blk_update_request(struct pt_regs *ctx, struct request *req) { ...@@ -172,5 +172,40 @@ int kprobe__blk_update_request(struct pt_regs *ctx, struct request *req) {
return 0; return 0;
}""") }""")
def test_probe_read_helper(self):
b = BPF(text="""
#include <linux/fs.h>
static void print_file_name(struct file *file) {
if (!file) return;
const char *name = file->f_path.dentry->d_name.name;
bpf_trace_printk("%s\\n", name);
}
int trace_entry(struct pt_regs *ctx, struct file *file) {
print_file_name(file);
return 0;
}
""")
fn = b.load_func("trace_entry", BPF.KPROBE)
def test_probe_struct_assign(self):
b = BPF(text = """
#include <uapi/linux/ptrace.h>
struct args_t {
const char *filename;
int flags;
int mode;
};
int kprobe__sys_open(struct pt_regs *ctx, const char *filename,
int flags, int mode) {
struct args_t args = {};
args.filename = filename;
args.flags = flags;
args.mode = mode;
bpf_trace_printk("%s\\n", args.filename);
return 0;
};
""")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment