Commit 45788284 authored by Marius Wachtler's avatar Marius Wachtler

Merge pull request #1121 from undingen/bjit_exc_fix2

bjit: don't directly do a OSR from the bjit
parents 31aba47e 82c676a1
......@@ -642,6 +642,8 @@ void Assembler::incl(Indirect mem) {
assert(src_idx >= 0 && src_idx < 8);
bool needssib = (src_idx == 0b100);
if (rex)
emitRex(rex);
emitByte(0xff);
......@@ -649,8 +651,12 @@ void Assembler::incl(Indirect mem) {
assert(-0x80 <= mem.offset && mem.offset < 0x80);
if (mem.offset == 0) {
emitModRM(0b00, 0, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
} else {
emitModRM(0b01, 0, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
emitByte(mem.offset);
}
}
......@@ -737,13 +743,13 @@ void Assembler::cmp(Register reg, Immediate imm) {
emitArith(imm, reg, OPCODE_CMP);
}
void Assembler::cmp(Indirect mem, Immediate imm) {
void Assembler::cmp(Indirect mem, Immediate imm, MovType type) {
int64_t val = imm.val;
assert(fitsInto<int32_t>(val));
int src_idx = mem.base.regnum;
int rex = REX_W;
assert(type == MovType::Q || type == MovType::L);
int rex = type == MovType::Q ? REX_W : 0;
if (src_idx >= 8) {
rex |= REX_B;
src_idx -= 8;
......@@ -751,17 +757,26 @@ void Assembler::cmp(Indirect mem, Immediate imm) {
assert(src_idx >= 0 && src_idx < 8);
emitRex(rex);
bool needssib = (src_idx == 0b100);
if (rex)
emitRex(rex);
emitByte(0x81);
if (mem.offset == 0) {
emitModRM(0b00, 7, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
} else if (-0x80 <= mem.offset && mem.offset < 0x80) {
emitModRM(0b01, 7, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
emitByte(mem.offset);
} else {
assert(fitsInto<int32_t>(mem.offset));
emitModRM(0b10, 7, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
emitInt(mem.offset, 4);
}
......
......@@ -168,9 +168,11 @@ public:
void cmp(Register reg1, Register reg2);
void cmp(Register reg, Immediate imm);
void cmp(Indirect mem, Immediate imm);
void cmp(Indirect mem, Immediate imm, MovType type = MovType::Q);
void cmpl(Indirect mem, Immediate imm) { return cmp(mem, imm, MovType::L); }
void cmp(Indirect mem, Register reg);
void lea(Indirect mem, Register reg);
void test(Register reg1, Register reg2);
......
......@@ -137,6 +137,7 @@ private:
// instructions
CFGBlock* next_block, *current_block;
FrameInfo frame_info;
unsigned edgecount;
SourceInfo* source_info;
ScopeInfo* scope_info;
......@@ -145,7 +146,6 @@ private:
ExcInfo last_exception;
BoxedClosure* created_closure;
BoxedGenerator* generator;
unsigned edgecount;
BoxedModule* parent_module;
std::unique_ptr<JitFragmentWriter> jit;
......@@ -229,6 +229,7 @@ void ASTInterpreter::setGlobals(Box* globals) {
ASTInterpreter::ASTInterpreter(FunctionMetadata* md, Box** vregs)
: current_block(0),
frame_info(ExcInfo(NULL, NULL, NULL)),
edgecount(0),
source_info(md->source.get()),
scope_info(0),
phis(NULL),
......@@ -236,7 +237,6 @@ ASTInterpreter::ASTInterpreter(FunctionMetadata* md, Box** vregs)
last_exception(NULL, NULL, NULL),
created_closure(0),
generator(0),
edgecount(0),
parent_module(source_info->parent_module),
should_jit(false) {
......@@ -320,8 +320,7 @@ Box* ASTInterpreter::execJITedBlock(CFGBlock* b) {
UNAVOIDABLE_STAT_TIMER(t0, "us_timer_in_baseline_jitted_code");
std::pair<CFGBlock*, Box*> rtn = b->entry_code(this, b, vregs);
next_block = rtn.first;
if (!next_block)
return rtn.second;
return rtn.second;
} catch (ExcInfo e) {
AST_stmt* stmt = getCurrentStatement();
if (stmt->type != AST_TYPE::Invoke)
......@@ -384,6 +383,15 @@ Box* ASTInterpreter::executeInner(ASTInterpreter& interpreter, CFGBlock* start_b
Box* rtn = interpreter.execJITedBlock(b);
if (interpreter.next_block)
continue;
// check if we returned from the baseline JIT because we should do a OSR.
if (unlikely(rtn == (Box*)ASTInterpreterJitInterface::osr_dummy_value)) {
AST_Jump* cur_stmt = (AST_Jump*)interpreter.getCurrentStatement();
RELEASE_ASSERT(cur_stmt->type == AST_TYPE::Jump, "");
// WARNING: do not put a try catch + rethrow block around this code here.
// it will confuse our unwinder!
rtn = interpreter.doOSR(cur_stmt);
}
return rtn;
}
}
......@@ -607,7 +615,7 @@ Value ASTInterpreter::visit_jump(AST_Jump* node) {
}
if (jit) {
if (backedge)
if (backedge && ENABLE_OSR && !FORCE_INTERPRETER)
jit->emitOSRPoint(node);
jit->emitJump(node->target);
finishJITing(node->target);
......@@ -1614,6 +1622,11 @@ int ASTInterpreterJitInterface::getCurrentInstOffset() {
return offsetof(ASTInterpreter, frame_info.stmt);
}
int ASTInterpreterJitInterface::getEdgeCountOffset() {
static_assert(sizeof(ASTInterpreter::edgecount) == 4, "caller assumes that");
return offsetof(ASTInterpreter, edgecount);
}
int ASTInterpreterJitInterface::getGeneratorOffset() {
return offsetof(ASTInterpreter, generator);
}
......@@ -1655,14 +1668,6 @@ Box* ASTInterpreterJitInterface::derefHelper(void* _interpreter, InternedString
return val;
}
Box* ASTInterpreterJitInterface::doOSRHelper(void* _interpreter, AST_Jump* node) {
ASTInterpreter* interpreter = (ASTInterpreter*)_interpreter;
++interpreter->edgecount;
if (interpreter->edgecount >= OSR_THRESHOLD_BASELINE)
return interpreter->doOSR(node);
return NULL;
}
Box* ASTInterpreterJitInterface::landingpadHelper(void* _interpreter) {
ASTInterpreter* interpreter = (ASTInterpreter*)_interpreter;
ExcInfo& last_exception = interpreter->last_exception;
......
......@@ -35,15 +35,18 @@ struct LineInfo;
extern const void* interpreter_instr_addr;
struct ASTInterpreterJitInterface {
// Special value which when returned from the bjit will trigger a OSR.
static constexpr uint64_t osr_dummy_value = -1;
static int getBoxedLocalsOffset();
static int getCurrentBlockOffset();
static int getCurrentInstOffset();
static int getEdgeCountOffset();
static int getGeneratorOffset();
static int getGlobalsOffset();
static void delNameHelper(void* _interpreter, InternedString name);
static Box* derefHelper(void* interp, InternedString s);
static Box* doOSRHelper(void* interp, AST_Jump* node);
static Box* landingpadHelper(void* interp);
static void pendingCallsCheckHelper();
static Box* setExcInfoHelper(void* interp, Box* type, Box* value, Box* traceback);
......
......@@ -469,9 +469,7 @@ void JitFragmentWriter::emitJump(CFGBlock* b) {
}
void JitFragmentWriter::emitOSRPoint(AST_Jump* node) {
RewriterVar* node_var = imm(node);
RewriterVar* result = createNewVar();
addAction([=]() { _emitOSRPoint(result, node_var); }, { result, node_var, getInterp() }, ActionType::NORMAL);
addAction([=]() { _emitOSRPoint(); }, { getInterp() }, ActionType::NORMAL);
}
void JitFragmentWriter::emitPendingCallsCheck() {
......@@ -822,24 +820,30 @@ void JitFragmentWriter::_emitJump(CFGBlock* b, RewriterVar* block_next, int& siz
block_next->bumpUse();
}
void JitFragmentWriter::_emitOSRPoint(RewriterVar* result, RewriterVar* node_var) {
RewriterVar::SmallVector args;
args.push_back(getInterp());
args.push_back(node_var);
_call(result, false, (void*)ASTInterpreterJitInterface::doOSRHelper, args, RewriterVar::SmallVector());
auto result_reg = result->getInReg(assembler::RDX);
result->bumpUse();
assembler->test(result_reg, result_reg);
void JitFragmentWriter::_emitOSRPoint() {
// We can't directly do OSR from the bjit frame because it will cause issues with exception handling.
// Reason is that the bjit and the OSRed code share the same python frame and the way invokes are implemented in the
// bjit. During unwinding we will see the OSR frame and will remove it and continue to unwind but the try catch
// block inside ASTInterpreter::execJITedBlock will rethrow the exception which causes another frame deinit,
// which is wrong because it already got removed.
// Instead we return back to the interpreter loop with special value (osr_dummy_value) which will trigger the OSR.
// this generates code for:
// if (++interpreter.edgecount < OSR_THRESHOLD_BASELINE)
// return std::make_pair((CFGBlock*)0, ASTInterpreterJitInterface::osr_dummy_value);
assembler::Register interp_reg = getInterp()->getInReg(); // will always be R12
assembler::Indirect edgecount = assembler::Indirect(interp_reg, ASTInterpreterJitInterface::getEdgeCountOffset());
assembler->incl(edgecount); // 32bit inc
assembler->cmpl(edgecount, assembler::Immediate(OSR_THRESHOLD_BASELINE)); // 32bit cmp
{
assembler::ForwardJump je(*assembler, assembler::COND_EQUAL);
assembler->clear_reg(assembler::RAX);
assembler::ForwardJump jl(*assembler, assembler::COND_BELOW);
assembler->clear_reg(assembler::RAX); // = next block to execute
assembler->mov(assembler::Immediate(ASTInterpreterJitInterface::osr_dummy_value), assembler::RDX);
assembler->add(assembler::Immediate(JitCodeBlock::sp_adjustment), assembler::RSP);
assembler->pop(assembler::R12);
assembler->pop(assembler::R14);
assembler->retq();
}
interp->bumpUse();
assertConsistent();
}
......
......@@ -294,7 +294,7 @@ private:
void _emitGetLocal(RewriterVar* val_var, const char* name);
void _emitJump(CFGBlock* b, RewriterVar* block_next, int& size_of_exit_to_interp);
void _emitOSRPoint(RewriterVar* result, RewriterVar* node_var);
void _emitOSRPoint();
void _emitPPCall(RewriterVar* result, void* func_addr, llvm::ArrayRef<RewriterVar*> args, int num_slots,
int slot_size, AST* ast_node);
void _emitRecordType(RewriterVar* type_recorder_var, RewriterVar* obj_cls_var);
......
# this specific test used to crash
def f(x):
if x:
raise Exception
def osr_f():
for i in range(10000):
f(False)
f(True)
try:
osr_f()
except Exception:
print "exc"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment