Commit c120e21f authored by Kevin Modzelewski's avatar Kevin Modzelewski

Allow closures into/through genexps

The issue was that if we transformed the AST nodes corresponding to scopes,
we wouldn't be able to match the initial analysis with the subsequent queries
to the transformed AST nodes.

We had run into that before, but worked around it by just modifying the AST
nodes in place.  For generator expressions that wasn't a possibility,
so now we explicitly registers when we replace scope-related AST nodes.
parent e38400db
......@@ -212,7 +212,6 @@ public:
virtual bool visit_for(AST_For* node) { return false; }
// virtual bool visit_functiondef(AST_FunctionDef *node) { return false; }
// virtual bool visit_global(AST_Global *node) { return false; }
virtual bool visit_generatorexp(AST_GeneratorExp* node) { return false; }
virtual bool visit_if(AST_If* node) { return false; }
virtual bool visit_ifexp(AST_IfExp* node) { return false; }
virtual bool visit_index(AST_Index* node) { return false; }
......@@ -301,6 +300,26 @@ public:
}
}
virtual bool visit_generatorexp(AST_GeneratorExp* node) {
if (node == orig_node) {
bool first = true;
for (AST_comprehension* c : node->generators) {
if (!first)
c->iter->accept(this);
c->target->accept(this);
first = false;
}
node->elt->accept(this);
} else {
node->generators[0]->iter->accept(this);
(*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur);
collect(node, map);
}
return true;
}
virtual bool visit_lambda(AST_Lambda* node) {
if (node == orig_node) {
for (AST_expr* e : node->args->args)
......@@ -426,9 +445,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda: {
ScopeInfoBase* scopInfo = new ScopeInfoBase(parent_info, usage);
scopInfo->setTakesGenerator(containsYield(node));
this->scopes[node] = scopInfo;
ScopeInfoBase* scopeInfo = new ScopeInfoBase(parent_info, usage);
scopeInfo->setTakesGenerator(containsYield(node));
this->scopes[node] = scopeInfo;
break;
}
case AST_TYPE::GeneratorExp: {
ScopeInfoBase* scopeInfo = new ScopeInfoBase(parent_info, usage);
scopeInfo->setTakesGenerator(true);
this->scopes[node] = scopeInfo;
break;
}
default:
......@@ -447,30 +472,36 @@ ScopeInfo* ScopingAnalysis::analyzeSubtree(AST* node) {
ScopeInfo* rtn = scopes[node];
assert(rtn);
return rtn;
}
rtn->setTakesGenerator(containsYield(node));
void ScopingAnalysis::registerScopeReplacement(AST* original_node, AST* new_node) {
assert(scope_replacements.count(original_node) == 0);
assert(scope_replacements.count(new_node) == 0);
assert(scopes.count(new_node) == 0);
return rtn;
#ifndef NDEBUG
// NULL this out just to make sure it doesn't get accessed:
scopes[new_node] = NULL;
#endif
scope_replacements[new_node] = original_node;
}
ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) {
assert(node);
ScopeInfo* rtn = scopes[node];
if (rtn)
return rtn;
auto it = scope_replacements.find(node);
if (it != scope_replacements.end())
node = it->second;
switch (node->type) {
case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
return analyzeSubtree(node);
// this is handled in the constructor:
// case AST_TYPE::Module:
// return new ModuleScopeInfo();
default:
RELEASE_ASSERT(0, "%d", node->type);
auto rtn = scopes.find(node);
if (rtn != scopes.end()) {
assert(rtn->second);
return rtn->second;
}
return analyzeSubtree(node);
}
ScopingAnalysis::ScopingAnalysis(AST_Module* m) : parent_module(m) {
......
......@@ -57,10 +57,21 @@ private:
std::unordered_map<AST*, ScopeInfo*> scopes;
AST_Module* parent_module;
std::unordered_map<AST*, AST*> scope_replacements;
ScopeInfo* analyzeSubtree(AST* node);
void processNameUsages(NameUsageMap* usages);
public:
// The scope-analysis is done before any CFG-ization is done,
// but many of the queries will be done post-CFG-ization.
// The CFG process can replace scope AST nodes with others (ex:
// generator expressions with generator functions), so we need to
// have a way of mapping the original analysis with the new queries.
// This is a hook for the CFG process to register when it has replaced
// a scope-node with a different node.
void registerScopeReplacement(AST* original_node, AST* new_node);
ScopingAnalysis(AST_Module* m);
ScopeInfo* getScopeInfoForNode(AST* node);
};
......
......@@ -979,10 +979,10 @@ CompiledFunction* doCompile(SourceInfo* source, const OSREntryDescriptor* entry_
std::vector<llvm::Type*> llvm_arg_types;
if (entry_descriptor == NULL) {
if (source->scoping->getScopeInfoForNode(source->ast)->takesClosure())
if (source->getScopeInfo()->takesClosure())
llvm_arg_types.push_back(g.llvm_closure_type_ptr);
if (source->scoping->getScopeInfoForNode(source->ast)->takesGenerator())
if (source->getScopeInfo()->takesGenerator())
llvm_arg_types.push_back(g.llvm_generator_type_ptr);
for (int i = 0; i < nargs; i++) {
......@@ -1022,7 +1022,7 @@ CompiledFunction* doCompile(SourceInfo* source, const OSREntryDescriptor* entry_
if (ENABLE_SPECULATION && effort >= EffortLevel::MODERATE)
speculation_level = TypeAnalysis::SOME;
TypeAnalysis* types = doTypeAnalysis(source->cfg, source->arg_names, spec->arg_types, effort, speculation_level,
source->scoping->getScopeInfoForNode(source->ast));
source->getScopeInfo());
_t2.split();
......@@ -1060,9 +1060,8 @@ CompiledFunction* doCompile(SourceInfo* source, const OSREntryDescriptor* entry_
assert(deopt_full_blocks.size() || deopt_partial_blocks.size());
irgen_us += _t2.split();
TypeAnalysis* deopt_types
= doTypeAnalysis(source->cfg, source->arg_names, spec->arg_types, effort, TypeAnalysis::NONE,
source->scoping->getScopeInfoForNode(source->ast));
TypeAnalysis* deopt_types = doTypeAnalysis(source->cfg, source->arg_names, spec->arg_types, effort,
TypeAnalysis::NONE, source->getScopeInfo());
_t2.split();
emitBBs(&irstate, "deopt", deopt_guards, guards, deopt_types, NULL, deopt_full_blocks, deopt_partial_blocks);
......
......@@ -77,6 +77,10 @@ const std::string SourceInfo::getName() {
}
}
ScopeInfo* SourceInfo::getScopeInfo() {
return scoping->getScopeInfoForNode(ast);
}
EffortLevel::EffortLevel initialEffort() {
if (FORCE_OPTIMIZE)
return EffortLevel::MAXIMAL;
......@@ -177,8 +181,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E
assert(source->ast);
source->cfg = computeCFG(source, source->body);
source->liveness = computeLivenessInfo(source->cfg);
source->phis = computeRequiredPhis(source->arg_names, source->cfg, source->liveness,
source->scoping->getScopeInfoForNode(source->ast));
source->phis = computeRequiredPhis(source->arg_names, source->cfg, source->liveness, source->getScopeInfo());
}
CompiledFunction* cf = doCompile(source, entry, effort, spec, name);
......@@ -244,7 +247,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) {
SourceInfo* si = new SourceInfo(bm, scoping, m, m->body);
si->cfg = computeCFG(si, m->body);
si->liveness = computeLivenessInfo(si->cfg);
si->phis = computeRequiredPhis(si->arg_names, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast));
si->phis = computeRequiredPhis(si->arg_names, si->cfg, si->liveness, si->getScopeInfo());
CLFunction* cl_f = new CLFunction(0, 0, false, false, si);
......
......@@ -66,8 +66,12 @@ llvm::Value* IRGenState::getScratchSpace(int min_bytes) {
}
ScopeInfo* IRGenState::getScopeInfo() {
SourceInfo* source = getSourceInfo();
return source->scoping->getScopeInfoForNode(source->ast);
return getSourceInfo()->getScopeInfo();
}
ScopeInfo* IRGenState::getScopeInfoForNode(AST* node) {
auto source = getSourceInfo();
return source->scoping->getScopeInfoForNode(node);
}
GuardList::ExprTypeGuard::ExprTypeGuard(CFGBlock* cfg_block, llvm::BranchInst* branch, AST_expr* ast_node,
......@@ -1416,7 +1420,7 @@ private:
return;
assert(node->type == AST_TYPE::ClassDef);
ScopeInfo* scope_info = irstate->getSourceInfo()->scoping->getScopeInfoForNode(node);
ScopeInfo* scope_info = irstate->getScopeInfoForNode(node);
assert(scope_info);
std::vector<CompilerVariable*> bases;
......@@ -1612,11 +1616,11 @@ private:
if (irstate->getSourceInfo()->ast->type == AST_TYPE::Module)
takes_closure = false;
else {
takes_closure = irstate->getSourceInfo()->scoping->getScopeInfoForNode(node)->takesClosure();
takes_closure = irstate->getScopeInfoForNode(node)->takesClosure();
}
// TODO: this lines disables the optimization mentioned above...
bool is_generator = irstate->getSourceInfo()->scoping->getScopeInfoForNode(node)->takesGenerator();
bool is_generator = irstate->getScopeInfoForNode(node)->takesGenerator();
if (takes_closure) {
if (irstate->getScopeInfo()->createsClosure()) {
......@@ -1910,6 +1914,8 @@ private:
} else if (var->getType() == FLOAT) {
// val = emitter.getBuilder()->CreateBitCast(val, g.llvm_value_type_ptr);
ptr = emitter.getBuilder()->CreateBitCast(ptr, g.double_->getPointerTo());
} else if (var->getType() == GENERATOR) {
ptr = emitter.getBuilder()->CreateBitCast(ptr, g.llvm_generator_type_ptr->getPointerTo());
} else if (var->getType() == UNDEF) {
// TODO if there are any undef variables, we're in 'unreachable' territory.
// Do we even need to generate any of this code?
......
......@@ -83,6 +83,7 @@ public:
SourceInfo* getSourceInfo() { return source_info; }
ScopeInfo* getScopeInfo();
ScopeInfo* getScopeInfoForNode(AST* node);
llvm::MDNode* getFuncDbgInfo() { return func_dbg_info; }
};
......
......@@ -263,6 +263,13 @@ const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr) {
}
}
void dumpLLVM(llvm::Value* v) {
v->dump();
}
void dumpLLVM(llvm::Instruction* v) {
v->dump();
}
Box* interpretFunction(llvm::Function* f, int nargs, Box* closure, Box* generator, Box* arg1, Box* arg2, Box* arg3,
Box** args) {
assert(f);
......
......@@ -67,6 +67,7 @@ private:
AST_TYPE::AST_TYPE root_type;
CFG* cfg;
CFGBlock* curblock;
ScopingAnalysis* scoping_analysis;
struct LoopInfo {
CFGBlock* continue_dest, *break_dest;
......@@ -648,6 +649,8 @@ private:
std::string func_name(nodeName(func));
func->name = func_name;
scoping_analysis->registerScopeReplacement(node, func);
func->args = new AST_arguments();
func->args->vararg = "";
func->args->kwarg = "";
......@@ -947,7 +950,8 @@ private:
}
public:
CFGVisitor(AST_TYPE::AST_TYPE root_type, CFG* cfg) : root_type(root_type), cfg(cfg) {
CFGVisitor(AST_TYPE::AST_TYPE root_type, ScopingAnalysis* scoping_analysis, CFG* cfg)
: root_type(root_type), cfg(cfg), scoping_analysis(scoping_analysis) {
curblock = cfg->addBlock();
curblock->info = "entry";
}
......@@ -1955,7 +1959,10 @@ void CFG::print() {
CFG* computeCFG(SourceInfo* source, std::vector<AST_stmt*> body) {
CFG* rtn = new CFG();
CFGVisitor visitor(source->ast->type, rtn);
ScopingAnalysis* scoping_analysis = source->scoping;
CFGVisitor visitor(source->ast->type, scoping_analysis, rtn);
if (source->ast->type == AST_TYPE::ClassDef) {
// A classdef always starts with "__module__ = __name__"
......@@ -1987,8 +1994,6 @@ CFG* computeCFG(SourceInfo* source, std::vector<AST_stmt*> body) {
// The functions we create for classdefs are supposed to return a dictionary of their locals.
// This is the place that we add all of that:
if (source->ast->type == AST_TYPE::ClassDef) {
ScopeInfo* scope_info = source->scoping->getScopeInfoForNode(source->ast);
AST_LangPrimitive* locals = new AST_LangPrimitive(AST_LangPrimitive::LOCALS);
AST_Return* rtn = new AST_Return();
......
......@@ -180,6 +180,7 @@ public:
};
class BoxedModule;
class ScopeInfo;
class SourceInfo {
public:
BoxedModule* parent_module;
......@@ -189,6 +190,8 @@ public:
LivenessAnalysis* liveness;
PhiAnalysis* phis;
ScopeInfo* getScopeInfo();
struct ArgNames {
const std::vector<AST_expr*>* args;
const std::string* vararg, *kwarg;
......
......@@ -150,6 +150,12 @@ extern "C" Box* pow_i64_i64(i64 lhs, i64 rhs) {
if (rhs < 0)
return boxFloat(pow_float_float(lhs, rhs));
if (rhs == 0) {
if (lhs < 0)
return boxInt(-1);
return boxInt(1);
}
assert(rhs > 0);
while (rhs) {
if (rhs & 1) {
......
# Test to make sure that generators create and receive closures as appropriate.
def f(E, N, M):
print list((i**E for i in xrange(N) for j in xrange(M)))
f(4, 3, 2)
def f2(x):
yield list((list(i for i in xrange(y)) for y in xrange(x)))
print list(f2(4))
def f3(z):
print list((lambda x: x**y)(z) for y in xrange(10))
f3(4)
# Generator-closure handling also needs to handle when the closures
# are at the module scope:
n = 5
g1 = (i for i in xrange(n))
......@@ -9,3 +9,7 @@ for i in xrange(1, 12):
print i | j
print i & j
print i ^ j
print 1 ** 0
print 0 ** 0
print -1 ** 0
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment