Commit 57c7cf6d authored by Kevin Modzelewski's avatar Kevin Modzelewski

Working on cross-BB ref handling

parent f1d4b01f
...@@ -661,6 +661,8 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc ...@@ -661,6 +661,8 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc
ConcreteCompilerType* type = getTypeAtBlockStart(types, s, block); ConcreteCompilerType* type = getTypeAtBlockStart(types, s, block);
llvm::PHINode* phi llvm::PHINode* phi
= emitter->getBuilder()->CreatePHI(type->llvmType(), block->predecessors.size(), s.s()); = emitter->getBuilder()->CreatePHI(type->llvmType(), block->predecessors.size(), s.s());
if (phi->getType() == g.llvm_value_type_ptr)
irstate->getRefcounts()->setType(phi, RefType::OWNED);
ConcreteCompilerVariable* var = new ConcreteCompilerVariable(type, phi); ConcreteCompilerVariable* var = new ConcreteCompilerVariable(type, phi);
generator->giveLocalSymbol(s, var); generator->giveLocalSymbol(s, var);
...@@ -760,6 +762,8 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc ...@@ -760,6 +762,8 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc
// printf("block %d: adding phi for %s from pred %d\n", block->idx, name.c_str(), pred->idx); // printf("block %d: adding phi for %s from pred %d\n", block->idx, name.c_str(), pred->idx);
llvm::PHINode* phi = emitter->getBuilder()->CreatePHI(cv->getType()->llvmType(), llvm::PHINode* phi = emitter->getBuilder()->CreatePHI(cv->getType()->llvmType(),
block->predecessors.size(), name.s()); block->predecessors.size(), name.s());
if (phi->getType() == g.llvm_value_type_ptr)
irstate->getRefcounts()->setType(phi, RefType::OWNED);
// emitter->getBuilder()->CreateCall(g.funcs.dump, phi); // emitter->getBuilder()->CreateCall(g.funcs.dump, phi);
ConcreteCompilerVariable* var = new ConcreteCompilerVariable(cv->getType(), phi); ConcreteCompilerVariable* var = new ConcreteCompilerVariable(cv->getType(), phi);
generator->giveLocalSymbol(name, var); generator->giveLocalSymbol(name, var);
...@@ -838,7 +842,7 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc ...@@ -838,7 +842,7 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc
// Can't always add the phi incoming value right away, since we may have to create more // Can't always add the phi incoming value right away, since we may have to create more
// basic blocks as part of type coercion. // basic blocks as part of type coercion.
// Intsead, just make a record of the phi node, value, and the location of the from-BB, // Instead, just make a record of the phi node, value, and the location of the from-BB,
// which we won't read until after all new BBs have been added. // which we won't read until after all new BBs have been added.
std::vector<std::tuple<llvm::PHINode*, llvm::Value*, llvm::BasicBlock*&>> phi_args; std::vector<std::tuple<llvm::PHINode*, llvm::Value*, llvm::BasicBlock*&>> phi_args;
...@@ -858,9 +862,14 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc ...@@ -858,9 +862,14 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc
llvm::Value* val = v->getValue(); llvm::Value* val = v->getValue();
llvm_phi->addIncoming(v->getValue(), llvm_exit_blocks[b->predecessors[j]]); llvm_phi->addIncoming(v->getValue(), llvm_exit_blocks[b->predecessors[j]]);
llvm::outs() << *v->getValue() << " is getting consumed by phi " << *llvm_phi << '\n';
irstate->getRefcounts()->setType(llvm_phi, RefType::OWNED);
irstate->getRefcounts()->refConsumed(v->getValue(), llvm_exit_blocks[b->predecessors[j]]->getTerminator());
} }
if (this_is_osr_entry) { if (this_is_osr_entry) {
assert(0 && "check refcounting");
ConcreteCompilerVariable* v = (*osr_syms)[it->first]; ConcreteCompilerVariable* v = (*osr_syms)[it->first];
assert(v); assert(v);
...@@ -869,6 +878,7 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc ...@@ -869,6 +878,7 @@ static void emitBBs(IRGenState* irstate, TypeAnalysis* types, const OSREntryDesc
} }
} }
for (auto t : phi_args) { for (auto t : phi_args) {
assert(0 && "check refcounting");
std::get<0>(t)->addIncoming(std::get<1>(t), std::get<2>(t)); std::get<0>(t)->addIncoming(std::get<1>(t), std::get<2>(t));
} }
} }
......
...@@ -197,8 +197,9 @@ private: ...@@ -197,8 +197,9 @@ private:
struct RefcountState { struct RefcountState {
RefType reftype; RefType reftype;
llvm::SmallVector<llvm::Instruction*, 2> ref_consumers; //llvm::SmallVector<llvm::Instruction*, 2> ref_consumers;
}; };
llvm::DenseMap<llvm::Instruction*, llvm::SmallVector<llvm::Value*, 4>> refs_consumed;
llvm::ValueMap<llvm::Value*, RefcountState> vars; llvm::ValueMap<llvm::Value*, RefcountState> vars;
public: public:
......
...@@ -50,16 +50,16 @@ llvm::Value* RefcountTracker::setType(llvm::Value* v, RefType reftype) { ...@@ -50,16 +50,16 @@ llvm::Value* RefcountTracker::setType(llvm::Value* v, RefType reftype) {
} }
void RefcountTracker::refConsumed(llvm::Value* v, llvm::Instruction* inst) { void RefcountTracker::refConsumed(llvm::Value* v, llvm::Instruction* inst) {
auto& var = this->vars[v]; assert(this->vars[v].reftype != RefType::UNKNOWN);
assert(var.reftype != RefType::UNKNOWN);
var.ref_consumers.push_back(inst);
// Make sure that this instruction actually references v: this->refs_consumed[inst].push_back(v);
assert(std::find(inst->op_begin(), inst->op_end(), v) != inst->op_end()); //var.ref_consumers.push_back(inst);
} }
llvm::Instruction* findIncrefPt(llvm::BasicBlock* BB) { llvm::Instruction* findIncrefPt(llvm::BasicBlock* BB) {
ASSERT(pred_begin(BB) == pred_end(BB) || pred_end(BB) == ++pred_begin(BB),
"We shouldn't be inserting anything at the beginning of blocks with multiple predecessors");
llvm::Instruction* incref_pt;// = BB->getFirstInsertionPt(); llvm::Instruction* incref_pt;// = BB->getFirstInsertionPt();
if (llvm::isa<llvm::LandingPadInst>(*BB->begin())) { if (llvm::isa<llvm::LandingPadInst>(*BB->begin())) {
// Don't split up the landingpad+extract+cxa_begin_catch // Don't split up the landingpad+extract+cxa_begin_catch
...@@ -240,6 +240,11 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -240,6 +240,11 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
} }
} }
// Don't actually insert any decrefs initially, since they require changing the control flow of the
// function. Instead just make a note of them and we will add them all at the end.
// This is a list of <val to decref, num_decrefs, instruction to insert before> pairs.
std::vector<std::tuple<llvm::Value*, int, llvm::Instruction*>> pending_decrefs;
while (!block_queue.empty()) { while (!block_queue.empty()) {
llvm::BasicBlock& BB = *block_queue.front(); llvm::BasicBlock& BB = *block_queue.front();
block_queue.pop_front(); block_queue.pop_front();
...@@ -258,6 +263,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -258,6 +263,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
assert(!states.count(&BB)); assert(!states.count(&BB));
RefState& state = states[&BB]; RefState& state = states[&BB];
// Compute the incoming refstate based on the refstate of any successor nodes
llvm::SmallVector<llvm::BasicBlock*, 4> successors; llvm::SmallVector<llvm::BasicBlock*, 4> successors;
successors.insert(successors.end(), llvm::succ_begin(&BB), llvm::succ_end(&BB)); successors.insert(successors.end(), llvm::succ_begin(&BB), llvm::succ_end(&BB));
if (successors.size()) { if (successors.size()) {
...@@ -298,7 +304,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -298,7 +304,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
//llvm::outs() << "Need to incref " << *v << " at beginning of " << SBB->getName() << "\n"; //llvm::outs() << "Need to incref " << *v << " at beginning of " << SBB->getName() << "\n";
} else if (this_refs < min_refs) { } else if (this_refs < min_refs) {
assert(refstate.reftype == RefType::OWNED); assert(refstate.reftype == RefType::OWNED);
addDecrefs(v, min_refs - this_refs, findIncrefPt(SBB)); pending_decrefs.push_back(std::make_tuple(v, min_refs - this_refs, findIncrefPt(SBB)));
} }
} }
...@@ -322,29 +328,23 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -322,29 +328,23 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
} }
} }
// A place to store any decrefs we might have to do, since those will split the basic block: // Then, iterate backwards through the instructions in this BB, updating the ref states
llvm::SmallVector<std::pair<llvm::Value*, llvm::Instruction*>, 4> last_uses;
for (auto &I : llvm::iterator_range<llvm::BasicBlock::reverse_iterator>(BB.rbegin(), BB.rend())) { for (auto &I : llvm::iterator_range<llvm::BasicBlock::reverse_iterator>(BB.rbegin(), BB.rend())) {
llvm::DenseMap<llvm::Value*, int> num_consumed_by_inst; llvm::DenseMap<llvm::Value*, int> num_consumed_by_inst;
llvm::DenseMap<llvm::Value*, int> num_times_as_op; llvm::DenseMap<llvm::Value*, int> num_times_as_op;
for (auto v : rt->refs_consumed[&I]) {
num_consumed_by_inst[v]++;
assert(rt->vars[v].reftype != RefType::UNKNOWN);
num_times_as_op[v]; // just make sure it appears in there
}
for (llvm::Value* op : I.operands()) { for (llvm::Value* op : I.operands()) {
auto it = rt->vars.find(op); auto it = rt->vars.find(op);
if (it == rt->vars.end()) if (it == rt->vars.end())
continue; continue;
int& nops = num_times_as_op[op]; num_times_as_op[op]++;
nops++;
if (nops > 1)
continue;
auto&& var_state = it->second;
for (auto consuming_inst : var_state.ref_consumers) {
if (consuming_inst == &I)
num_consumed_by_inst[op]++;
}
} }
for (auto&& p : num_times_as_op) { for (auto&& p : num_times_as_op) {
...@@ -360,7 +360,18 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -360,7 +360,18 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
if (state.refs[op] == 0) { if (state.refs[op] == 0) {
// Don't do any updates now since we are iterating over the bb // Don't do any updates now since we are iterating over the bb
llvm::outs() << "Last use of " << *op << " is at " << I << "; adding a decref after\n"; llvm::outs() << "Last use of " << *op << " is at " << I << "; adding a decref after\n";
last_uses.push_back(std::make_pair(op, &I));
if (llvm::InvokeInst* invoke = llvm::dyn_cast<llvm::InvokeInst>(&I)) {
pending_decrefs.push_back(std::make_tuple(op, 1, findIncrefPt(invoke->getNormalDest())));
pending_decrefs.push_back(std::make_tuple(op, 1, findIncrefPt(invoke->getUnwindDest())));
} else {
assert(&I != I.getParent()->getTerminator());
auto next = I.getNextNode();
//while (llvm::isa<llvm::PHINode>(next))
//next = next->getNextNode();
ASSERT(!llvm::isa<llvm::UnreachableInst>(next), "Can't add decrefs after this function...");
pending_decrefs.push_back(std::make_tuple(op, 1, next));
}
state.refs[op] = 1; state.refs[op] = 1;
} }
} }
...@@ -379,9 +390,14 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -379,9 +390,14 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
if (state.refs[inst] != starting_refs) { if (state.refs[inst] != starting_refs) {
llvm::Instruction* insertion_pt = inst->getNextNode(); llvm::Instruction* insertion_pt = inst->getNextNode();
assert(insertion_pt); assert(insertion_pt);
while (llvm::isa<llvm::PHINode>(insertion_pt)) {
insertion_pt = insertion_pt->getNextNode();
assert(insertion_pt);
}
if (state.refs[inst] < starting_refs) { if (state.refs[inst] < starting_refs) {
assert(p.second.reftype == RefType::OWNED); assert(p.second.reftype == RefType::OWNED);
addDecrefs(inst, starting_refs - state.refs[inst], insertion_pt); pending_decrefs.push_back(std::make_tuple(inst, starting_refs - state.refs[inst], insertion_pt));
} else { } else {
addIncrefs(inst, state.refs[inst] - starting_refs, insertion_pt); addIncrefs(inst, state.refs[inst] - starting_refs, insertion_pt);
} }
...@@ -390,18 +406,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -390,18 +406,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
} }
} }
for (auto& p : last_uses) { // If this is the entry block, finish dealing with the ref state rather than handing off to a predecessor
if (llvm::InvokeInst* invoke = llvm::dyn_cast<llvm::InvokeInst>(p.second)) {
addDecrefs(p.first, 1, findIncrefPt(invoke->getNormalDest()));
addDecrefs(p.first, 1, findIncrefPt(invoke->getUnwindDest()));
} else {
assert(p.second != p.second->getParent()->getTerminator());
auto next = p.second->getNextNode();
ASSERT(!llvm::isa<llvm::UnreachableInst>(next), "Can't add decrefs after this function...");
addDecrefs(p.first, 1, next);
}
}
if (&BB == &BB.getParent()->front()) { if (&BB == &BB.getParent()->front()) {
for (auto&& p : state.refs) { for (auto&& p : state.refs) {
llvm::outs() << *p.first << " " << p.second << '\n'; llvm::outs() << *p.first << " " << p.second << '\n';
...@@ -426,6 +431,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -426,6 +431,7 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
state.refs.clear(); state.refs.clear();
} }
// Look for any new blocks that are ready to be processed:
for (auto&& PBB : llvm::iterator_range<llvm::pred_iterator>(llvm::pred_begin(&BB), llvm::pred_end(&BB))) { for (auto&& PBB : llvm::iterator_range<llvm::pred_iterator>(llvm::pred_begin(&BB), llvm::pred_end(&BB))) {
bool all_succ_done = true; bool all_succ_done = true;
for (auto&& SBB : llvm::iterator_range<llvm::succ_iterator>(llvm::succ_begin(PBB), llvm::succ_end(PBB))) { for (auto&& SBB : llvm::iterator_range<llvm::succ_iterator>(llvm::succ_begin(PBB), llvm::succ_end(PBB))) {
...@@ -442,6 +448,18 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) { ...@@ -442,6 +448,18 @@ void RefcountTracker::addRefcounts(IRGenState* irstate) {
} }
} }
// TODO need to do something about loops
ASSERT(states.size() == f->size(), "We didn't process all nodes... backedges??");
// Add any decrefs that we put off earlier:
for (auto& p : pending_decrefs) {
llvm::Value* v;
int num_refs;
llvm::Instruction* insertion_pt;
std::tie(v, num_refs, insertion_pt) = p;
addDecrefs(v, num_refs, insertion_pt);
}
fprintf(stderr, "After refcounts:\n"); fprintf(stderr, "After refcounts:\n");
fprintf(stderr, "\033[35m"); fprintf(stderr, "\033[35m");
f->dump(); f->dump();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment