Commit 82c676a1 authored by Marius Wachtler's avatar Marius Wachtler

bjit: don't directly do a OSR from the bjit

We can't directly do OSR from the bjit frame because it will cause issues with exception handling.
Reason is that the bjit and the OSRed code share the same python frame and the way invokes are implemented in the
bjit. During unwinding we will see the OSR frame and will remove it and continue to unwind but the try catch
block inside ASTInterpreter::execJITedBlock will rethrow the exception which causes another frame deinit,
which is wrong because it already got removed.
Instead we return back to the interpreter loop with special value (osr_dummy_value) which will trigger the OSR from there.
parent 31aba47e
......@@ -642,6 +642,8 @@ void Assembler::incl(Indirect mem) {
assert(src_idx >= 0 && src_idx < 8);
bool needssib = (src_idx == 0b100);
if (rex)
emitRex(rex);
emitByte(0xff);
......@@ -649,8 +651,12 @@ void Assembler::incl(Indirect mem) {
assert(-0x80 <= mem.offset && mem.offset < 0x80);
if (mem.offset == 0) {
emitModRM(0b00, 0, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
} else {
emitModRM(0b01, 0, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
emitByte(mem.offset);
}
}
......@@ -737,13 +743,13 @@ void Assembler::cmp(Register reg, Immediate imm) {
emitArith(imm, reg, OPCODE_CMP);
}
void Assembler::cmp(Indirect mem, Immediate imm) {
void Assembler::cmp(Indirect mem, Immediate imm, MovType type) {
int64_t val = imm.val;
assert(fitsInto<int32_t>(val));
int src_idx = mem.base.regnum;
int rex = REX_W;
assert(type == MovType::Q || type == MovType::L);
int rex = type == MovType::Q ? REX_W : 0;
if (src_idx >= 8) {
rex |= REX_B;
src_idx -= 8;
......@@ -751,17 +757,26 @@ void Assembler::cmp(Indirect mem, Immediate imm) {
assert(src_idx >= 0 && src_idx < 8);
emitRex(rex);
bool needssib = (src_idx == 0b100);
if (rex)
emitRex(rex);
emitByte(0x81);
if (mem.offset == 0) {
emitModRM(0b00, 7, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
} else if (-0x80 <= mem.offset && mem.offset < 0x80) {
emitModRM(0b01, 7, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
emitByte(mem.offset);
} else {
assert(fitsInto<int32_t>(mem.offset));
emitModRM(0b10, 7, src_idx);
if (needssib)
emitSIB(0b00, 0b100, src_idx);
emitInt(mem.offset, 4);
}
......
......@@ -168,9 +168,11 @@ public:
void cmp(Register reg1, Register reg2);
void cmp(Register reg, Immediate imm);
void cmp(Indirect mem, Immediate imm);
void cmp(Indirect mem, Immediate imm, MovType type = MovType::Q);
void cmpl(Indirect mem, Immediate imm) { return cmp(mem, imm, MovType::L); }
void cmp(Indirect mem, Register reg);
void lea(Indirect mem, Register reg);
void test(Register reg1, Register reg2);
......
......@@ -137,6 +137,7 @@ private:
// instructions
CFGBlock* next_block, *current_block;
FrameInfo frame_info;
unsigned edgecount;
SourceInfo* source_info;
ScopeInfo* scope_info;
......@@ -145,7 +146,6 @@ private:
ExcInfo last_exception;
BoxedClosure* created_closure;
BoxedGenerator* generator;
unsigned edgecount;
BoxedModule* parent_module;
std::unique_ptr<JitFragmentWriter> jit;
......@@ -229,6 +229,7 @@ void ASTInterpreter::setGlobals(Box* globals) {
ASTInterpreter::ASTInterpreter(FunctionMetadata* md, Box** vregs)
: current_block(0),
frame_info(ExcInfo(NULL, NULL, NULL)),
edgecount(0),
source_info(md->source.get()),
scope_info(0),
phis(NULL),
......@@ -236,7 +237,6 @@ ASTInterpreter::ASTInterpreter(FunctionMetadata* md, Box** vregs)
last_exception(NULL, NULL, NULL),
created_closure(0),
generator(0),
edgecount(0),
parent_module(source_info->parent_module),
should_jit(false) {
......@@ -320,8 +320,7 @@ Box* ASTInterpreter::execJITedBlock(CFGBlock* b) {
UNAVOIDABLE_STAT_TIMER(t0, "us_timer_in_baseline_jitted_code");
std::pair<CFGBlock*, Box*> rtn = b->entry_code(this, b, vregs);
next_block = rtn.first;
if (!next_block)
return rtn.second;
return rtn.second;
} catch (ExcInfo e) {
AST_stmt* stmt = getCurrentStatement();
if (stmt->type != AST_TYPE::Invoke)
......@@ -384,6 +383,15 @@ Box* ASTInterpreter::executeInner(ASTInterpreter& interpreter, CFGBlock* start_b
Box* rtn = interpreter.execJITedBlock(b);
if (interpreter.next_block)
continue;
// check if we returned from the baseline JIT because we should do a OSR.
if (unlikely(rtn == (Box*)ASTInterpreterJitInterface::osr_dummy_value)) {
AST_Jump* cur_stmt = (AST_Jump*)interpreter.getCurrentStatement();
RELEASE_ASSERT(cur_stmt->type == AST_TYPE::Jump, "");
// WARNING: do not put a try catch + rethrow block around this code here.
// it will confuse our unwinder!
rtn = interpreter.doOSR(cur_stmt);
}
return rtn;
}
}
......@@ -607,7 +615,7 @@ Value ASTInterpreter::visit_jump(AST_Jump* node) {
}
if (jit) {
if (backedge)
if (backedge && ENABLE_OSR && !FORCE_INTERPRETER)
jit->emitOSRPoint(node);
jit->emitJump(node->target);
finishJITing(node->target);
......@@ -1614,6 +1622,11 @@ int ASTInterpreterJitInterface::getCurrentInstOffset() {
return offsetof(ASTInterpreter, frame_info.stmt);
}
int ASTInterpreterJitInterface::getEdgeCountOffset() {
static_assert(sizeof(ASTInterpreter::edgecount) == 4, "caller assumes that");
return offsetof(ASTInterpreter, edgecount);
}
int ASTInterpreterJitInterface::getGeneratorOffset() {
return offsetof(ASTInterpreter, generator);
}
......@@ -1655,14 +1668,6 @@ Box* ASTInterpreterJitInterface::derefHelper(void* _interpreter, InternedString
return val;
}
Box* ASTInterpreterJitInterface::doOSRHelper(void* _interpreter, AST_Jump* node) {
ASTInterpreter* interpreter = (ASTInterpreter*)_interpreter;
++interpreter->edgecount;
if (interpreter->edgecount >= OSR_THRESHOLD_BASELINE)
return interpreter->doOSR(node);
return NULL;
}
Box* ASTInterpreterJitInterface::landingpadHelper(void* _interpreter) {
ASTInterpreter* interpreter = (ASTInterpreter*)_interpreter;
ExcInfo& last_exception = interpreter->last_exception;
......
......@@ -35,15 +35,18 @@ struct LineInfo;
extern const void* interpreter_instr_addr;
struct ASTInterpreterJitInterface {
// Special value which when returned from the bjit will trigger a OSR.
static constexpr uint64_t osr_dummy_value = -1;
static int getBoxedLocalsOffset();
static int getCurrentBlockOffset();
static int getCurrentInstOffset();
static int getEdgeCountOffset();
static int getGeneratorOffset();
static int getGlobalsOffset();
static void delNameHelper(void* _interpreter, InternedString name);
static Box* derefHelper(void* interp, InternedString s);
static Box* doOSRHelper(void* interp, AST_Jump* node);
static Box* landingpadHelper(void* interp);
static void pendingCallsCheckHelper();
static Box* setExcInfoHelper(void* interp, Box* type, Box* value, Box* traceback);
......
......@@ -469,9 +469,7 @@ void JitFragmentWriter::emitJump(CFGBlock* b) {
}
void JitFragmentWriter::emitOSRPoint(AST_Jump* node) {
RewriterVar* node_var = imm(node);
RewriterVar* result = createNewVar();
addAction([=]() { _emitOSRPoint(result, node_var); }, { result, node_var, getInterp() }, ActionType::NORMAL);
addAction([=]() { _emitOSRPoint(); }, { getInterp() }, ActionType::NORMAL);
}
void JitFragmentWriter::emitPendingCallsCheck() {
......@@ -822,24 +820,30 @@ void JitFragmentWriter::_emitJump(CFGBlock* b, RewriterVar* block_next, int& siz
block_next->bumpUse();
}
void JitFragmentWriter::_emitOSRPoint(RewriterVar* result, RewriterVar* node_var) {
RewriterVar::SmallVector args;
args.push_back(getInterp());
args.push_back(node_var);
_call(result, false, (void*)ASTInterpreterJitInterface::doOSRHelper, args, RewriterVar::SmallVector());
auto result_reg = result->getInReg(assembler::RDX);
result->bumpUse();
assembler->test(result_reg, result_reg);
void JitFragmentWriter::_emitOSRPoint() {
// We can't directly do OSR from the bjit frame because it will cause issues with exception handling.
// Reason is that the bjit and the OSRed code share the same python frame and the way invokes are implemented in the
// bjit. During unwinding we will see the OSR frame and will remove it and continue to unwind but the try catch
// block inside ASTInterpreter::execJITedBlock will rethrow the exception which causes another frame deinit,
// which is wrong because it already got removed.
// Instead we return back to the interpreter loop with special value (osr_dummy_value) which will trigger the OSR.
// this generates code for:
// if (++interpreter.edgecount < OSR_THRESHOLD_BASELINE)
// return std::make_pair((CFGBlock*)0, ASTInterpreterJitInterface::osr_dummy_value);
assembler::Register interp_reg = getInterp()->getInReg(); // will always be R12
assembler::Indirect edgecount = assembler::Indirect(interp_reg, ASTInterpreterJitInterface::getEdgeCountOffset());
assembler->incl(edgecount); // 32bit inc
assembler->cmpl(edgecount, assembler::Immediate(OSR_THRESHOLD_BASELINE)); // 32bit cmp
{
assembler::ForwardJump je(*assembler, assembler::COND_EQUAL);
assembler->clear_reg(assembler::RAX);
assembler::ForwardJump jl(*assembler, assembler::COND_BELOW);
assembler->clear_reg(assembler::RAX); // = next block to execute
assembler->mov(assembler::Immediate(ASTInterpreterJitInterface::osr_dummy_value), assembler::RDX);
assembler->add(assembler::Immediate(JitCodeBlock::sp_adjustment), assembler::RSP);
assembler->pop(assembler::R12);
assembler->pop(assembler::R14);
assembler->retq();
}
interp->bumpUse();
assertConsistent();
}
......
......@@ -294,7 +294,7 @@ private:
void _emitGetLocal(RewriterVar* val_var, const char* name);
void _emitJump(CFGBlock* b, RewriterVar* block_next, int& size_of_exit_to_interp);
void _emitOSRPoint(RewriterVar* result, RewriterVar* node_var);
void _emitOSRPoint();
void _emitPPCall(RewriterVar* result, void* func_addr, llvm::ArrayRef<RewriterVar*> args, int num_slots,
int slot_size, AST* ast_node);
void _emitRecordType(RewriterVar* type_recorder_var, RewriterVar* obj_cls_var);
......
# this specific test used to crash
def f(x):
if x:
raise Exception
def osr_f():
for i in range(10000):
f(False)
f(True)
try:
osr_f()
except Exception:
print "exc"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment