Commit 17f1f320 authored by Kevin Modzelewski's avatar Kevin Modzelewski

More bjit work. could use some cleanup / identification of common patterns

parent aa4caa4f
...@@ -464,6 +464,8 @@ void ASTInterpreter::doStore(AST_Name* node, STOLEN(Value) value) { ...@@ -464,6 +464,8 @@ void ASTInterpreter::doStore(AST_Name* node, STOLEN(Value) value) {
jit->emitSetGlobal(globals, name.getBox(), value); jit->emitSetGlobal(globals, name.getBox(), value);
setGlobal(globals, name.getBox(), value.o); setGlobal(globals, name.getBox(), value.o);
Py_DECREF(value.o); Py_DECREF(value.o);
if (jit)
value.var->decref();
} else if (vst == ScopeInfo::VarScopeType::NAME) { } else if (vst == ScopeInfo::VarScopeType::NAME) {
assert(0 && "check refcounting"); assert(0 && "check refcounting");
if (jit) if (jit)
...@@ -570,6 +572,10 @@ Value ASTInterpreter::visit_binop(AST_BinOp* node) { ...@@ -570,6 +572,10 @@ Value ASTInterpreter::visit_binop(AST_BinOp* node) {
Value r = doBinOp(node, left, right, node->op_type, BinExpType::BinOp); Value r = doBinOp(node, left, right, node->op_type, BinExpType::BinOp);
Py_DECREF(left.o); Py_DECREF(left.o);
Py_DECREF(right.o); Py_DECREF(right.o);
if (jit) {
left.var->xdecref();
right.var->xdecref();
}
return r; return r;
} }
...@@ -624,8 +630,10 @@ Value ASTInterpreter::visit_branch(AST_Branch* node) { ...@@ -624,8 +630,10 @@ Value ASTInterpreter::visit_branch(AST_Branch* node) {
Value v = visit_expr(node->test); Value v = visit_expr(node->test);
ASSERT(v.o == True || v.o == False, "Should have called NONZERO before this branch"); ASSERT(v.o == True || v.o == False, "Should have called NONZERO before this branch");
if (jit) if (jit) {
jit->emitSideExit(v, v.o, v.o == True ? node->iffalse : node->iftrue); jit->emitSideExit(v, v.o, v.o == True ? node->iffalse : node->iftrue);
v.var->decref();
}
if (v.o == True) if (v.o == True)
next_block = node->iftrue; next_block = node->iftrue;
...@@ -849,6 +857,11 @@ Value ASTInterpreter::visit_augBinOp(AST_AugBinOp* node) { ...@@ -849,6 +857,11 @@ Value ASTInterpreter::visit_augBinOp(AST_AugBinOp* node) {
Value r = doBinOp(node, left, right, node->op_type, BinExpType::AugBinOp); Value r = doBinOp(node, left, right, node->op_type, BinExpType::AugBinOp);
Py_DECREF(left.o); Py_DECREF(left.o);
Py_DECREF(right.o); Py_DECREF(right.o);
if (jit) {
left.var->decref();
right.var->decref();
}
return r; return r;
} }
...@@ -917,6 +930,8 @@ Value ASTInterpreter::visit_langPrimitive(AST_LangPrimitive* node) { ...@@ -917,6 +930,8 @@ Value ASTInterpreter::visit_langPrimitive(AST_LangPrimitive* node) {
Value obj = visit_expr(node->args[0]); Value obj = visit_expr(node->args[0]);
v = Value(boxBool(nonzero(obj.o)), jit ? jit->emitNonzero(obj) : NULL); v = Value(boxBool(nonzero(obj.o)), jit ? jit->emitNonzero(obj) : NULL);
Py_DECREF(obj.o); Py_DECREF(obj.o);
if (jit)
obj.var->decref();
} else if (node->opcode == AST_LangPrimitive::SET_EXC_INFO) { } else if (node->opcode == AST_LangPrimitive::SET_EXC_INFO) {
assert(node->args.size() == 3); assert(node->args.size() == 3);
...@@ -1114,9 +1129,14 @@ Value ASTInterpreter::visit_makeFunction(AST_MakeFunction* mkfn) { ...@@ -1114,9 +1129,14 @@ Value ASTInterpreter::visit_makeFunction(AST_MakeFunction* mkfn) {
Value func = createFunction(node, args, node->body); Value func = createFunction(node, args, node->body);
for (int i = decorators.size() - 1; i >= 0; i--) { for (int i = decorators.size() - 1; i >= 0; i--) {
if (jit)
func.var = jit->emitRuntimeCall(NULL, decorators[i], ArgPassSpec(1), { func }, NULL);
func.o = runtimeCall(autoDecref(decorators[i].o), ArgPassSpec(1), autoDecref(func.o), 0, 0, 0, 0); func.o = runtimeCall(autoDecref(decorators[i].o), ArgPassSpec(1), autoDecref(func.o), 0, 0, 0, 0);
if (jit) {
auto prev_func_var = func.var;
func.var = jit->emitRuntimeCall(NULL, decorators[i], ArgPassSpec(1), { func }, NULL);
decorators[i].var->decref();
prev_func_var->decref();
}
} }
return func; return func;
} }
...@@ -1290,6 +1310,12 @@ Value ASTInterpreter::visit_print(AST_Print* node) { ...@@ -1290,6 +1310,12 @@ Value ASTInterpreter::visit_print(AST_Print* node) {
else else
printHelper(getSysStdout(), autoXDecref(var.o), node->nl); printHelper(getSysStdout(), autoXDecref(var.o), node->nl);
if (jit) {
if (node->dest)
dest.var->decref();
var.var->decref();
}
return Value(); return Value();
} }
...@@ -1313,6 +1339,10 @@ Value ASTInterpreter::visit_compare(AST_Compare* node) { ...@@ -1313,6 +1339,10 @@ Value ASTInterpreter::visit_compare(AST_Compare* node) {
Value r = doBinOp(node, left, right, node->ops[0], BinExpType::Compare); Value r = doBinOp(node, left, right, node->ops[0], BinExpType::Compare);
Py_DECREF(left.o); Py_DECREF(left.o);
Py_DECREF(right.o); Py_DECREF(right.o);
if (jit) {
left.var->decref();
right.var->decref();
}
return r; return r;
} }
...@@ -1446,6 +1476,12 @@ Value ASTInterpreter::visit_call(AST_Call* node) { ...@@ -1446,6 +1476,12 @@ Value ASTInterpreter::visit_call(AST_Call* node) {
for (auto e : args) for (auto e : args)
Py_DECREF(e); Py_DECREF(e);
if (jit) {
func.var->decref();
for (auto e : args_vars)
e->decref();
}
return v; return v;
} }
...@@ -1617,7 +1653,6 @@ Value ASTInterpreter::visit_list(AST_List* node) { ...@@ -1617,7 +1653,6 @@ Value ASTInterpreter::visit_list(AST_List* node) {
} }
Value ASTInterpreter::visit_tuple(AST_Tuple* node) { Value ASTInterpreter::visit_tuple(AST_Tuple* node) {
return getNone();
llvm::SmallVector<RewriterVar*, 8> items; llvm::SmallVector<RewriterVar*, 8> items;
BoxedTuple* rtn = BoxedTuple::create(node->elts.size()); BoxedTuple* rtn = BoxedTuple::create(node->elts.size());
......
...@@ -263,16 +263,23 @@ RewriterVar* JitFragmentWriter::emitCreateSlice(RewriterVar* start, RewriterVar* ...@@ -263,16 +263,23 @@ RewriterVar* JitFragmentWriter::emitCreateSlice(RewriterVar* start, RewriterVar*
RewriterVar* JitFragmentWriter::emitCreateTuple(const llvm::ArrayRef<RewriterVar*> values) { RewriterVar* JitFragmentWriter::emitCreateTuple(const llvm::ArrayRef<RewriterVar*> values) {
auto num = values.size(); auto num = values.size();
if (num == 0) RewriterVar* r;
return imm(EmptyTuple); if (num == 0) {
else if (num == 1) r = imm(EmptyTuple);
return call(false, (void*)BoxedTuple::create1, values[0]); r->incref();
} else if (num == 1)
r = call(false, (void*)BoxedTuple::create1, values[0]);
else if (num == 2) else if (num == 2)
return call(false, (void*)BoxedTuple::create2, values[0], values[1]); r = call(false, (void*)BoxedTuple::create2, values[0], values[1]);
else if (num == 3) else if (num == 3)
return call(false, (void*)BoxedTuple::create3, values[0], values[1], values[2]); r = call(false, (void*)BoxedTuple::create3, values[0], values[1], values[2]);
else else
return call(false, (void*)createTupleHelper, imm(num), allocArgs(values)); r = call(false, (void*)createTupleHelper, imm(num), allocArgs(values));
for (auto v : values)
v->decref();
return r;
} }
RewriterVar* JitFragmentWriter::emitDeref(InternedString s) { RewriterVar* JitFragmentWriter::emitDeref(InternedString s) {
...@@ -316,8 +323,11 @@ RewriterVar* JitFragmentWriter::emitGetClsAttr(RewriterVar* obj, BoxedString* s) ...@@ -316,8 +323,11 @@ RewriterVar* JitFragmentWriter::emitGetClsAttr(RewriterVar* obj, BoxedString* s)
} }
RewriterVar* JitFragmentWriter::emitGetGlobal(Box* global, BoxedString* s) { RewriterVar* JitFragmentWriter::emitGetGlobal(Box* global, BoxedString* s) {
if (s->s() == "None") if (s->s() == "None") {
return imm(None); RewriterVar* r = imm(None);
r->incref();
return r;
}
return emitPPCall((void*)getGlobal, { imm(global), imm(s) }, 2, 512); return emitPPCall((void*)getGlobal, { imm(global), imm(s) }, 2, 512);
} }
...@@ -490,6 +500,11 @@ void JitFragmentWriter::emitRaise3(RewriterVar* arg0, RewriterVar* arg1, Rewrite ...@@ -490,6 +500,11 @@ void JitFragmentWriter::emitRaise3(RewriterVar* arg0, RewriterVar* arg1, Rewrite
} }
void JitFragmentWriter::emitReturn(RewriterVar* v) { void JitFragmentWriter::emitReturn(RewriterVar* v) {
for (auto v : local_syms) {
if (v.second)
v.second->decref();
}
addAction([=]() { _emitReturn(v); }, { v }, ActionType::NORMAL); addAction([=]() { _emitReturn(v); }, { v }, ActionType::NORMAL);
} }
...@@ -545,6 +560,7 @@ void JitFragmentWriter::emitSetLocal(InternedString s, int vreg, bool set_closur ...@@ -545,6 +560,7 @@ void JitFragmentWriter::emitSetLocal(InternedString s, int vreg, bool set_closur
} }
void JitFragmentWriter::emitSideExit(RewriterVar* v, Box* cmp_value, CFGBlock* next_block) { void JitFragmentWriter::emitSideExit(RewriterVar* v, Box* cmp_value, CFGBlock* next_block) {
assert(0 && "need to decref any local syms");
RewriterVar* var = imm(cmp_value); RewriterVar* var = imm(cmp_value);
RewriterVar* next_block_var = imm(next_block); RewriterVar* next_block_var = imm(next_block);
addAction([=]() { _emitSideExit(v, var, next_block, next_block_var); }, { v, var, next_block_var }, addAction([=]() { _emitSideExit(v, var, next_block, next_block_var); }, { v, var, next_block_var },
...@@ -796,6 +812,11 @@ void JitFragmentWriter::_emitGetLocal(RewriterVar* val_var, const char* name) { ...@@ -796,6 +812,11 @@ void JitFragmentWriter::_emitGetLocal(RewriterVar* val_var, const char* name) {
} }
void JitFragmentWriter::_emitJump(CFGBlock* b, RewriterVar* block_next, ExitInfo& exit_info) { void JitFragmentWriter::_emitJump(CFGBlock* b, RewriterVar* block_next, ExitInfo& exit_info) {
for (auto v : local_syms) {
if (v.second)
v.second->decref(); // xdecref?
}
assert(exit_info.num_bytes == 0); assert(exit_info.num_bytes == 0);
assert(exit_info.exit_start == NULL); assert(exit_info.exit_start == NULL);
if (b->code) { if (b->code) {
......
...@@ -4653,7 +4653,7 @@ Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteAr ...@@ -4653,7 +4653,7 @@ Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteAr
Box* irtn = NULL; Box* irtn = NULL;
if (inplace) { if (inplace) {
BoxedString* iop_name = getInplaceOpName(op_type); DecrefHandle<BoxedString> iop_name(getInplaceOpName(op_type));
if (rewrite_args) { if (rewrite_args) {
CallattrRewriteArgs srewrite_args(rewrite_args->rewriter, rewrite_args->lhs, rewrite_args->destination); CallattrRewriteArgs srewrite_args(rewrite_args->rewriter, rewrite_args->lhs, rewrite_args->destination);
srewrite_args.arg1 = rewrite_args->rhs; srewrite_args.arg1 = rewrite_args->rhs;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment