diff --git a/src/cc/frontends/clang/b_frontend_action.cc b/src/cc/frontends/clang/b_frontend_action.cc index d5c14d1535ced61c55e005da280ab1a55da460a6..cec7e46946dd7097c36af1ac4852bbf21ca2c890 100644 --- a/src/cc/frontends/clang/b_frontend_action.cc +++ b/src/cc/frontends/clang/b_frontend_action.cc @@ -214,11 +214,56 @@ ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m, bool track_helpers) : C(C), rewriter_(rewriter), m_(m), track_helpers_(track_helpers) {} +bool ProbeVisitor::assignsExtPtr(Expr *E, int *nbAddrOf) { + if (IsContextMemberExpr(E)) { + *nbAddrOf = 0; + return true; + } + + ProbeChecker checker = ProbeChecker(E, ptregs_, track_helpers_, + true); + if (checker.is_transitive()) { + // The negative of the number of dereferences is the number of addrof. In + // an assignment, if we went through n addrof before getting the external + // pointer, then we'll need n dereferences on the left-hand side variable + // to get to the external pointer. + *nbAddrOf = -checker.get_nb_derefs(); + return true; + } + + if (E->getStmtClass() == Stmt::CallExprClass) { + CallExpr *Call = dyn_cast<CallExpr>(E); + if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) { + StringRef memb_name = Memb->getMemberDecl()->getName(); + if (DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(Memb->getBase())) { + if (SectionAttr *A = Ref->getDecl()->getAttr<SectionAttr>()) { + if (!A->getName().startswith("maps")) + return false; + + if (memb_name == "lookup" || memb_name == "lookup_or_init") { + if (m_.find(Ref->getDecl()) != m_.end()) { + // Retrieved an ext. pointer from a map, mark LHS as ext. pointer. + // Pointers from maps always need a single dereference to get the + // actual value. The value may be an external pointer but cannot + // be a pointer to an external pointer as the verifier prohibits + // storing known pointers (to map values, context, the stack, or + // the packet) in maps. + *nbAddrOf = 1; + return true; + } + } + } + } + } + } + return false; +} bool ProbeVisitor::VisitVarDecl(VarDecl *D) { if (Expr *E = D->getInit()) { - ProbeChecker checker = ProbeChecker(E, ptregs_, track_helpers_, true); - if (checker.is_transitive() || IsContextMemberExpr(E)) { - tuple<Decl *, int> pt = make_tuple(D, checker.get_nb_derefs()); + int nbAddrOf; + if (assignsExtPtr(E, &nbAddrOf)) { + // The negative of the number of addrof is the number of dereferences. + tuple<Decl *, int> pt = make_tuple(D, -nbAddrOf); set_ptreg(pt); } } @@ -249,42 +294,11 @@ bool ProbeVisitor::VisitCallExpr(CallExpr *Call) { bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) { if (!E->isAssignmentOp()) return true; - // copy probe attribute from RHS to LHS if present - ProbeChecker checker = ProbeChecker(E->getRHS(), ptregs_, track_helpers_, - true); - if (checker.is_transitive()) { - // The negative of the number of dereferences is the number of addrof. In - // an assignment, if we went through n addrof before getting the external - // pointer, then we'll need n dereferences on the left-hand side variable - // to get to the external pointer. - ProbeSetter setter(&ptregs_, -checker.get_nb_derefs()); - setter.TraverseStmt(E->getLHS()); - } else if (E->getRHS()->getStmtClass() == Stmt::CallExprClass) { - CallExpr *Call = dyn_cast<CallExpr>(E->getRHS()); - if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) { - StringRef memb_name = Memb->getMemberDecl()->getName(); - if (DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(Memb->getBase())) { - if (SectionAttr *A = Ref->getDecl()->getAttr<SectionAttr>()) { - if (!A->getName().startswith("maps")) - return true; - if (memb_name == "lookup" || memb_name == "lookup_or_init") { - if (m_.find(Ref->getDecl()) != m_.end()) { - // Retrieved an ext. pointer from a map, mark LHS as ext. pointer. - // Pointers from maps always need a single dereference to get the - // actual value. The value may be an external pointer but cannot - // be a pointer to an external pointer as the verifier prohibits - // storing known pointers (to map values, context, the stack, or - // the packet) in maps. - ProbeSetter setter(&ptregs_, 1); - setter.TraverseStmt(E->getLHS()); - } - } - } - } - } - } else if (IsContextMemberExpr(E->getRHS())) { - ProbeSetter setter(&ptregs_); + // copy probe attribute from RHS to LHS if present + int nbAddrOf; + if (assignsExtPtr(E->getRHS(), &nbAddrOf)) { + ProbeSetter setter(&ptregs_, nbAddrOf); setter.TraverseStmt(E->getLHS()); } return true; @@ -355,7 +369,6 @@ bool ProbeVisitor::IsContextMemberExpr(Expr *E) { bool found = false; MemberExpr *M; for (M = Memb; M; M = dyn_cast<MemberExpr>(M->getBase())) { - memb_visited_.insert(M); rhs_start = M->getLocEnd(); base = M->getBase(); member = M->getMemberLoc(); diff --git a/src/cc/frontends/clang/b_frontend_action.h b/src/cc/frontends/clang/b_frontend_action.h index 490381d7ecd17648606eda3c9ec39cccf1c09628..7dc373c6b00bbedf446749bc458665cda9710563 100644 --- a/src/cc/frontends/clang/b_frontend_action.h +++ b/src/cc/frontends/clang/b_frontend_action.h @@ -99,6 +99,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> { void set_ctx(clang::Decl *D) { ctx_ = D; } std::set<std::tuple<clang::Decl *, int>> get_ptregs() { return ptregs_; } private: + bool assignsExtPtr(clang::Expr *E, int *nbAddrOf); bool IsContextMemberExpr(clang::Expr *E); clang::SourceRange expansionRange(clang::SourceRange range); template <unsigned N> diff --git a/tests/python/test_clang.py b/tests/python/test_clang.py index 191a65df83e1e11df73fbb290fe4c215b372477c..797db9555ca98793c5cdc92383eafcd9377c96ca 100755 --- a/tests/python/test_clang.py +++ b/tests/python/test_clang.py @@ -538,7 +538,7 @@ int process(struct xdp_md *ctx) { t = b["act"] self.assertEqual(len(t), 32); - def test_ext_ptr_maps(self): + def test_ext_ptr_maps1(self): bpf_text = """ #include <uapi/linux/ptrace.h> #include <net/sock.h> @@ -568,6 +568,35 @@ int trace_exit(struct pt_regs *ctx) { b.load_func("trace_entry", BPF.KPROBE) b.load_func("trace_exit", BPF.KPROBE) + def test_ext_ptr_maps2(self): + bpf_text = """ +#include <uapi/linux/ptrace.h> +#include <net/sock.h> +#include <bcc/proto.h> + +BPF_HASH(currsock, u32, struct sock *); + +int trace_entry(struct pt_regs *ctx, struct sock *sk, + struct sockaddr *uaddr, int addr_len) { + u32 pid = bpf_get_current_pid_tgid(); + currsock.update(&pid, &sk); + return 0; +}; + +int trace_exit(struct pt_regs *ctx) { + u32 pid = bpf_get_current_pid_tgid(); + struct sock **skpp = currsock.lookup(&pid); + if (skpp) { + struct sock *skp = *skpp; + return skp->__sk_common.skc_dport; + } + return 0; +} + """ + b = BPF(text=bpf_text) + b.load_func("trace_entry", BPF.KPROBE) + b.load_func("trace_exit", BPF.KPROBE) + def test_ext_ptr_maps_reverse(self): bpf_text = """ #include <uapi/linux/ptrace.h>