Commit 6781ec24 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Merge pull request #90 from undingen/lambda_expr2

Implement lambda expressions
parents b2b06576 0688008c
......@@ -66,6 +66,13 @@ public:
_doStore(node->name);
return true;
}
bool visit_lambda(AST_Lambda* node) {
for (auto* d : node->args->defaults)
d->accept(this);
return true;
}
bool visit_name(AST_Name* node) {
if (node->ctx_type == AST_TYPE::Load)
_doLoad(node->id);
......
......@@ -246,6 +246,25 @@ public:
}
}
virtual bool visit_lambda(AST_Lambda* node) {
if (node == orig_node) {
for (AST_expr* e : node->args->args)
e->accept(this);
if (node->args->vararg.size())
doWrite(node->args->vararg);
if (node->args->kwarg.size())
doWrite(node->args->kwarg);
node->body->accept(this);
} else {
for (auto* e : node->args->defaults)
e->accept(this);
(*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur);
collect(node, map);
}
return true;
}
virtual bool visit_import(AST_Import* node) {
for (int i = 0; i < node->names.size(); i++) {
AST_alias* alias = node->names[i];
......@@ -292,15 +311,13 @@ static std::vector<ScopingAnalysis::ScopeNameUsage*> sortNameUsages(ScopingAnaly
}
void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
typedef ScopeNameUsage::StrSet StrSet;
// Resolve name lookups:
for (const auto& p : *usages) {
ScopeNameUsage* usage = p.second;
for (StrSet::iterator it2 = usage->read.begin(), end2 = usage->read.end(); it2 != end2; ++it2) {
if (usage->forced_globals.count(*it2))
for (const auto& name : usage->read) {
if (usage->forced_globals.count(name))
continue;
if (usage->written.count(*it2))
if (usage->written.count(name))
continue;
std::vector<ScopeNameUsage*> intermediate_parents;
......@@ -309,15 +326,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
while (parent) {
if (parent->node->type == AST_TYPE::ClassDef) {
parent = parent->parent;
} else if (parent->forced_globals.count(*it2)) {
} else if (parent->forced_globals.count(name)) {
break;
} else if (parent->written.count(*it2)) {
usage->got_from_closure.insert(*it2);
parent->referenced_from_nested.insert(*it2);
} else if (parent->written.count(name)) {
usage->got_from_closure.insert(name);
parent->referenced_from_nested.insert(name);
for (ScopeNameUsage* iparent : intermediate_parents) {
iparent->referenced_from_nested.insert(*it2);
iparent->got_from_closure.insert(*it2);
iparent->referenced_from_nested.insert(name);
iparent->got_from_closure.insert(name);
}
break;
......@@ -340,10 +357,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
ScopeInfo* parent_info = this->scopes[(usage->parent == NULL) ? this->parent_module : usage->parent->node];
switch (node->type) {
case AST_TYPE::FunctionDef:
this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break;
case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break;
default:
......@@ -375,6 +391,7 @@ ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) {
switch (node->type) {
case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
return analyzeSubtree(node);
// this is handled in the constructor:
// case AST_TYPE::Module:
......
......@@ -351,6 +351,8 @@ private:
virtual void* visit_index(AST_Index* node) { return getType(node->value); }
virtual void* visit_lambda(AST_Lambda* node) { return typeFromClass(function_cls); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) {
switch (node->opcode) {
case AST_LangPrimitive::ISINSTANCE:
......
......@@ -51,6 +51,11 @@ SourceInfo::ArgNames::ArgNames(AST* ast) {
args = &f->args->args;
vararg = &f->args->vararg;
kwarg = &f->args->kwarg;
} else if (ast->type == AST_TYPE::Lambda) {
AST_Lambda* l = ast_cast<AST_Lambda>(ast);
args = &l->args->args;
vararg = &l->args->vararg;
kwarg = &l->args->kwarg;
} else {
RELEASE_ASSERT(0, "%d", ast->type);
}
......@@ -63,6 +68,8 @@ const std::string SourceInfo::getName() {
return ast_cast<AST_ClassDef>(ast)->name;
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->name;
case AST_TYPE::Lambda:
return "<lambda>";
case AST_TYPE::Module:
return this->parent_module->name();
default:
......@@ -70,20 +77,6 @@ const std::string SourceInfo::getName() {
}
}
const std::vector<AST_stmt*>& SourceInfo::getBody() {
assert(ast);
switch (ast->type) {
case AST_TYPE::ClassDef:
return ast_cast<AST_ClassDef>(ast)->body;
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->body;
case AST_TYPE::Module:
return ast_cast<AST_Module>(ast)->body;
default:
RELEASE_ASSERT(0, "%d", ast->type);
}
}
EffortLevel::EffortLevel initialEffort() {
if (FORCE_OPTIMIZE)
return EffortLevel::MAXIMAL;
......@@ -169,7 +162,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E
// Do the analysis now if we had deferred it earlier:
if (source->cfg == NULL) {
assert(source->ast);
source->cfg = computeCFG(source, source->getBody());
source->cfg = computeCFG(source, source->body);
source->liveness = computeLivenessInfo(source->cfg);
source->phis = computeRequiredPhis(source->arg_names, source->cfg, source->liveness,
source->scoping->getScopeInfoForNode(source->ast));
......@@ -231,7 +224,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) {
ScopingAnalysis* scoping = runScopingAnalysis(m);
SourceInfo* si = new SourceInfo(bm, scoping, m);
SourceInfo* si = new SourceInfo(bm, scoping, m, m->body);
si->cfg = computeCFG(si, m->body);
si->liveness = computeLivenessInfo(si->cfg);
si->phis = computeRequiredPhis(si->arg_names, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast));
......
......@@ -757,6 +757,21 @@ private:
return evalExpr(node->value, exc_info);
}
CompilerVariable* evalLambda(AST_Lambda* node, ExcInfo exc_info) {
assert(state != PARTIAL);
AST_Return* expr = new AST_Return();
expr->value = node->body;
std::vector<AST_stmt*> body = { expr };
CompilerVariable* func = _createFunction(node, exc_info, node->args, body);
ConcreteCompilerVariable* converted = func->makeConverted(emitter, func->getBoxType());
func->decvref(emitter);
return converted;
}
CompilerVariable* evalList(AST_List* node, ExcInfo exc_info) {
assert(state != PARTIAL);
......@@ -1052,6 +1067,9 @@ private:
case AST_TYPE::Index:
rtn = evalIndex(ast_cast<AST_Index>(node), exc_info);
break;
case AST_TYPE::Lambda:
rtn = evalLambda(ast_cast<AST_Lambda>(node), exc_info);
break;
case AST_TYPE::List:
rtn = evalList(ast_cast<AST_List>(node), exc_info);
break;
......@@ -1417,10 +1435,9 @@ private:
ConcreteCompilerVariable* converted_base = base->makeConverted(emitter, base->getBoxType());
base->decvref(emitter);
CLFunction* cl = _wrapClassDef(node);
CLFunction* cl = _wrapFunction(node, nullptr, node->body);
// TODO duplication with doFunctionDef:
// TODO duplication with _createFunction:
CompilerVariable* created_closure = NULL;
if (scope_info->takesClosure()) {
created_closure = _getFake(CREATED_CLOSURE_NAME, false);
......@@ -1495,46 +1512,29 @@ private:
converted_slice->decvref(emitter);
}
CLFunction* _wrapFunction(AST_FunctionDef* node) {
CLFunction* _wrapFunction(AST* node, AST_arguments* args, const std::vector<AST_stmt*>& body) {
// Different compilations of the parent scope of a functiondef should lead
// to the same CLFunction* being used:
static std::unordered_map<AST_FunctionDef*, CLFunction*> made;
static std::unordered_map<AST*, CLFunction*> made;
CLFunction*& cl = made[node];
if (cl == NULL) {
SourceInfo* si
= new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping, node);
si->ast = node;
cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(),
node->args->kwarg.size(), si);
}
return cl;
}
CLFunction* _wrapClassDef(AST_ClassDef* node) {
// TODO duplication with _wrapFunction
static std::unordered_map<AST_ClassDef*, CLFunction*> made;
CLFunction*& cl = made[node];
if (cl == NULL) {
SourceInfo* si
= new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping, node);
si->ast = node;
cl = new CLFunction(0, 0, 0, 0, si);
SourceInfo* source = irstate->getSourceInfo();
SourceInfo* si = new SourceInfo(source->parent_module, source->scoping, node, body);
if (args)
cl = new CLFunction(args->args.size(), args->defaults.size(), args->vararg.size(), args->kwarg.size(), si);
else
cl = new CLFunction(0, 0, 0, 0, si);
}
return cl;
}
void doFunctionDef(AST_FunctionDef* node, ExcInfo exc_info) {
if (state == PARTIAL)
return;
assert(!node->decorator_list.size());
CLFunction* cl = this->_wrapFunction(node);
CompilerVariable* _createFunction(AST* node, ExcInfo exc_info, AST_arguments* args,
const std::vector<AST_stmt*>& body) {
CLFunction* cl = this->_wrapFunction(node, args, body);
std::vector<ConcreteCompilerVariable*> defaults;
for (auto d : node->args->defaults) {
for (auto d : args->defaults) {
CompilerVariable* e = evalExpr(d, exc_info);
ConcreteCompilerVariable* converted = e->makeConverted(emitter, e->getBoxType());
e->decvref(emitter);
......@@ -1558,7 +1558,16 @@ private:
// llvm::Value *boxed = emitter.getBuilder()->CreateCall(g.funcs.boxCLFunction, embedConstantPtr(cl,
// boxCLFuncArgType));
// CompilerVariable *func = new ConcreteCompilerVariable(typeFromClass(function_cls), boxed, true);
return func;
}
void doFunctionDef(AST_FunctionDef* node, ExcInfo exc_info) {
if (state == PARTIAL)
return;
assert(!node->decorator_list.size());
CompilerVariable* func = _createFunction(node, exc_info, node->args, node->body);
_doSet(node->name, func, exc_info);
func->decvref(emitter);
}
......
......@@ -457,6 +457,16 @@ AST_keyword* read_keyword(BufferedReader* reader) {
return rtn;
}
AST_Lambda* read_lambda(BufferedReader* reader) {
AST_Lambda* rtn = new AST_Lambda();
rtn->args = ast_cast<AST_arguments>(readASTMisc(reader));
rtn->body = readASTExpr(reader);
rtn->col_offset = readColOffset(reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_List* read_list(BufferedReader* reader) {
AST_List* rtn = new AST_List();
......@@ -696,6 +706,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_ifexp(reader);
case AST_TYPE::Index:
return read_index(reader);
case AST_TYPE::Lambda:
return read_lambda(reader);
case AST_TYPE::List:
return read_list(reader);
case AST_TYPE::ListComp:
......
......@@ -567,6 +567,19 @@ void AST_keyword::accept(ASTVisitor* v) {
value->accept(v);
}
void AST_Lambda::accept(ASTVisitor* v) {
bool skip = v->visit_lambda(this);
if (skip)
return;
args->accept(v);
body->accept(v);
}
void* AST_Lambda::accept_expr(ExprVisitor* v) {
return v->visit_lambda(this);
}
void AST_LangPrimitive::accept(ASTVisitor* v) {
bool skip = v->visit_langprimitive(this);
if (skip)
......@@ -1272,6 +1285,14 @@ bool PrintVisitor::visit_invoke(AST_Invoke* node) {
return true;
}
bool PrintVisitor::visit_lambda(AST_Lambda* node) {
printf("lambda ");
node->args->accept(this);
printf(": ");
node->body->accept(this);
return true;
}
bool PrintVisitor::visit_langprimitive(AST_LangPrimitive* node) {
printf(":");
switch (node->opcode) {
......@@ -1726,6 +1747,10 @@ public:
output->push_back(node);
return false;
}
virtual bool visit_lambda(AST_Lambda* node) {
output->push_back(node);
return !expand_scopes;
}
virtual bool visit_langprimitive(AST_LangPrimitive* node) {
output->push_back(node);
return false;
......
......@@ -532,6 +532,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword;
};
class AST_Lambda : public AST_expr {
public:
AST_arguments* args;
AST_expr* body;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_Lambda() : AST_expr(AST_TYPE::Lambda) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Lambda;
};
class AST_List : public AST_expr {
public:
std::vector<AST_expr*> elts;
......@@ -918,6 +931,7 @@ public:
virtual bool visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_invoke(AST_Invoke* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_keyword(AST_keyword* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
......@@ -980,6 +994,7 @@ public:
virtual bool visit_index(AST_Index* node) { return false; }
virtual bool visit_invoke(AST_Invoke* node) { return false; }
virtual bool visit_keyword(AST_keyword* node) { return false; }
virtual bool visit_lambda(AST_Lambda* node) { return false; }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { return false; }
virtual bool visit_list(AST_List* node) { return false; }
virtual bool visit_listcomp(AST_ListComp* node) { return false; }
......@@ -1022,6 +1037,7 @@ public:
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
......@@ -1110,6 +1126,7 @@ public:
virtual bool visit_index(AST_Index* node);
virtual bool visit_invoke(AST_Invoke* node);
virtual bool visit_keyword(AST_keyword* node);
virtual bool visit_lambda(AST_Lambda* node);
virtual bool visit_langprimitive(AST_LangPrimitive* node);
virtual bool visit_list(AST_List* node);
virtual bool visit_listcomp(AST_ListComp* node);
......
......@@ -621,6 +621,27 @@ private:
return rtn;
}
AST_expr* remapLambda(AST_Lambda* node) {
if (node->args->defaults.empty()) {
return node;
}
AST_Lambda* rtn = new AST_Lambda();
rtn->lineno = node->lineno;
rtn->col_offset = node->col_offset;
rtn->args = new AST_arguments();
rtn->args->args = node->args->args;
rtn->args->vararg = node->args->vararg;
rtn->args->kwarg = node->args->kwarg;
for (auto d : node->args->defaults) {
rtn->args->defaults.push_back(remapExpr(d));
}
rtn->body = node->body;
return rtn;
}
AST_expr* remapLangPrimitive(AST_LangPrimitive* node) {
AST_LangPrimitive* rtn = new AST_LangPrimitive(node->opcode);
for (AST_expr* arg : node->args) {
......@@ -732,6 +753,9 @@ private:
case AST_TYPE::Index:
rtn = remapIndex(ast_cast<AST_Index>(node));
break;
case AST_TYPE::Lambda:
rtn = remapLambda(ast_cast<AST_Lambda>(node));
break;
case AST_TYPE::LangPrimitive:
rtn = remapLangPrimitive(ast_cast<AST_LangPrimitive>(node));
break;
......@@ -1103,7 +1127,7 @@ public:
}
virtual bool visit_return(AST_Return* node) {
if (root_type != AST_TYPE::FunctionDef) {
if (root_type != AST_TYPE::FunctionDef && root_type != AST_TYPE::Lambda) {
fprintf(stderr, "SyntaxError: 'return' outside function\n");
exit(1);
}
......
......@@ -218,13 +218,13 @@ public:
};
ArgNames arg_names;
const std::vector<AST_stmt*> body;
const std::string getName();
// AST_arguments* getArgsAST();
const std::vector<AST_stmt*>& getBody();
SourceInfo(BoxedModule* m, ScopingAnalysis* scoping, AST* ast)
: parent_module(m), scoping(scoping), ast(ast), cfg(NULL), liveness(NULL), phis(NULL), arg_names(ast) {}
SourceInfo(BoxedModule* m, ScopingAnalysis* scoping, AST* ast, const std::vector<AST_stmt*>& body)
: parent_module(m), scoping(scoping), ast(ast), cfg(NULL), liveness(NULL), phis(NULL), arg_names(ast),
body(body) {}
};
typedef std::vector<CompiledFunction*> FunctionList;
......
s = lambda x=5: x**2
print s(8), s(100), s()
for i in range(10):
print (lambda x, y: x < y)(i, 5)
t = lambda s: " ".join(s.split())
print t("test \tstr\ni\n ng")
def T(y):
return (lambda x: x < y)
print T(10)(1), T(10)(20)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment