Commit 44ec4eef authored by Marius Wachtler's avatar Marius Wachtler

Implement lambda expressions

parent ddabda9a
...@@ -63,6 +63,9 @@ public: ...@@ -63,6 +63,9 @@ public:
_doStore(node->name); _doStore(node->name);
return true; return true;
} }
bool visit_lambda(AST_Lambda* node) { return true; }
bool visit_name(AST_Name* node) { bool visit_name(AST_Name* node) {
if (node->ctx_type == AST_TYPE::Load) if (node->ctx_type == AST_TYPE::Load)
_doLoad(node->id); _doLoad(node->id);
......
...@@ -234,6 +234,19 @@ public: ...@@ -234,6 +234,19 @@ public:
} }
} }
virtual bool visit_lambda(AST_Lambda* node) {
assert(node == orig_node);
for (AST_expr* e : node->args->args)
e->accept(this);
if (node->args->vararg.size())
doWrite(node->args->vararg);
if (node->args->kwarg.size())
doWrite(node->args->kwarg);
node->body->accept(this);
return true;
}
virtual bool visit_import(AST_Import* node) { virtual bool visit_import(AST_Import* node) {
for (int i = 0; i < node->names.size(); i++) { for (int i = 0; i < node->names.size(); i++) {
AST_alias* alias = node->names[i]; AST_alias* alias = node->names[i];
...@@ -328,10 +341,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -328,10 +341,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
ScopeInfo* parent_info = this->scopes[(usage->parent == NULL) ? this->parent_module : usage->parent->node]; ScopeInfo* parent_info = this->scopes[(usage->parent == NULL) ? this->parent_module : usage->parent->node];
switch (node->type) { switch (node->type) {
case AST_TYPE::FunctionDef:
this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break;
case AST_TYPE::ClassDef: case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
this->scopes[node] = new ScopeInfoBase(parent_info, usage); this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break; break;
default: default:
...@@ -363,6 +375,7 @@ ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) { ...@@ -363,6 +375,7 @@ ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) {
switch (node->type) { switch (node->type) {
case AST_TYPE::ClassDef: case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef: case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
return analyzeSubtree(node); return analyzeSubtree(node);
// this is handled in the constructor: // this is handled in the constructor:
// case AST_TYPE::Module: // case AST_TYPE::Module:
......
...@@ -340,6 +340,8 @@ private: ...@@ -340,6 +340,8 @@ private:
virtual void* visit_index(AST_Index* node) { return getType(node->value); } virtual void* visit_index(AST_Index* node) { return getType(node->value); }
virtual void* visit_lambda(AST_Lambda* node) { return typeFromClass(function_cls); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { virtual void* visit_langprimitive(AST_LangPrimitive* node) {
switch (node->opcode) { switch (node->opcode) {
case AST_LangPrimitive::ISINSTANCE: case AST_LangPrimitive::ISINSTANCE:
......
...@@ -47,6 +47,8 @@ const std::string SourceInfo::getName() { ...@@ -47,6 +47,8 @@ const std::string SourceInfo::getName() {
switch (ast->type) { switch (ast->type) {
case AST_TYPE::FunctionDef: case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->name; return ast_cast<AST_FunctionDef>(ast)->name;
case AST_TYPE::Lambda:
return "<lambda>";
case AST_TYPE::Module: case AST_TYPE::Module:
return this->parent_module->name(); return this->parent_module->name();
default: default:
...@@ -59,6 +61,8 @@ AST_arguments* SourceInfo::getArgsAST() { ...@@ -59,6 +61,8 @@ AST_arguments* SourceInfo::getArgsAST() {
switch (ast->type) { switch (ast->type) {
case AST_TYPE::FunctionDef: case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->args; return ast_cast<AST_FunctionDef>(ast)->args;
case AST_TYPE::Lambda:
return ast_cast<AST_Lambda>(ast)->args;
case AST_TYPE::Module: case AST_TYPE::Module:
return NULL; return NULL;
default: default:
...@@ -81,18 +85,6 @@ const std::vector<AST_expr*>* CLFunction::getArgNames() { ...@@ -81,18 +85,6 @@ const std::vector<AST_expr*>* CLFunction::getArgNames() {
return &source->getArgNames(); return &source->getArgNames();
} }
const std::vector<AST_stmt*>& SourceInfo::getBody() {
assert(ast);
switch (ast->type) {
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->body;
case AST_TYPE::Module:
return ast_cast<AST_Module>(ast)->body;
default:
RELEASE_ASSERT(0, "%d", ast->type);
}
}
EffortLevel::EffortLevel initialEffort() { EffortLevel::EffortLevel initialEffort() {
if (FORCE_OPTIMIZE) if (FORCE_OPTIMIZE)
return EffortLevel::MAXIMAL; return EffortLevel::MAXIMAL;
...@@ -187,7 +179,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E ...@@ -187,7 +179,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E
// Do the analysis now if we had deferred it earlier: // Do the analysis now if we had deferred it earlier:
if (source->cfg == NULL) { if (source->cfg == NULL) {
assert(source->ast); assert(source->ast);
source->cfg = computeCFG(source->ast->type, source->getBody()); source->cfg = computeCFG(source->ast->type, source->body);
source->liveness = computeLivenessInfo(source->cfg); source->liveness = computeLivenessInfo(source->cfg);
source->phis = computeRequiredPhis(args, source->cfg, source->liveness, source->phis = computeRequiredPhis(args, source->cfg, source->liveness,
source->scoping->getScopeInfoForNode(source->ast)); source->scoping->getScopeInfoForNode(source->ast));
...@@ -246,9 +238,8 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) { ...@@ -246,9 +238,8 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) {
ScopingAnalysis* scoping = runScopingAnalysis(m); ScopingAnalysis* scoping = runScopingAnalysis(m);
SourceInfo* si = new SourceInfo(bm, scoping); SourceInfo* si = new SourceInfo(bm, scoping, m, m->body);
si->cfg = computeCFG(AST_TYPE::Module, m->body); si->cfg = computeCFG(AST_TYPE::Module, m->body);
si->ast = m;
si->liveness = computeLivenessInfo(si->cfg); si->liveness = computeLivenessInfo(si->cfg);
si->phis = computeRequiredPhis(NULL, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast)); si->phis = computeRequiredPhis(NULL, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast));
......
...@@ -756,6 +756,24 @@ private: ...@@ -756,6 +756,24 @@ private:
return evalExpr(node->value, exc_info); return evalExpr(node->value, exc_info);
} }
CompilerVariable* evalLambda(AST_Lambda* node, ExcInfo exc_info) {
assert(state != PARTIAL);
AST_Return* expr = new AST_Return();
expr->value = node->body;
SourceInfo* si = new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping,
node, { expr });
CLFunction* cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(),
node->args->kwarg.size(), si);
CompilerVariable* func = makeFunction(emitter, cl, NULL);
ConcreteCompilerVariable* converted = func->makeConverted(emitter, func->getBoxType());
func->decvref(emitter);
return converted;
}
CompilerVariable* evalList(AST_List* node, ExcInfo exc_info) { CompilerVariable* evalList(AST_List* node, ExcInfo exc_info) {
assert(state != PARTIAL); assert(state != PARTIAL);
...@@ -1027,6 +1045,9 @@ private: ...@@ -1027,6 +1045,9 @@ private:
case AST_TYPE::Index: case AST_TYPE::Index:
rtn = evalIndex(ast_cast<AST_Index>(node), exc_info); rtn = evalIndex(ast_cast<AST_Index>(node), exc_info);
break; break;
case AST_TYPE::Lambda:
rtn = evalLambda(ast_cast<AST_Lambda>(node), exc_info);
break;
case AST_TYPE::List: case AST_TYPE::List:
rtn = evalList(ast_cast<AST_List>(node), exc_info); rtn = evalList(ast_cast<AST_List>(node), exc_info);
break; break;
...@@ -1476,8 +1497,8 @@ private: ...@@ -1476,8 +1497,8 @@ private:
CLFunction*& cl = made[node]; CLFunction*& cl = made[node];
if (cl == NULL) { if (cl == NULL) {
SourceInfo* si = new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping); SourceInfo* si = new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping,
si->ast = node; node, node->body);
cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(), cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(),
node->args->kwarg.size(), si); node->args->kwarg.size(), si);
} }
......
...@@ -457,6 +457,16 @@ AST_keyword* read_keyword(BufferedReader* reader) { ...@@ -457,6 +457,16 @@ AST_keyword* read_keyword(BufferedReader* reader) {
return rtn; return rtn;
} }
AST_Lambda* read_lambda(BufferedReader* reader) {
AST_Lambda* rtn = new AST_Lambda();
rtn->args = ast_cast<AST_arguments>(readASTMisc(reader));
rtn->body = readASTExpr(reader);
rtn->col_offset = readColOffset(reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_List* read_list(BufferedReader* reader) { AST_List* read_list(BufferedReader* reader) {
AST_List* rtn = new AST_List(); AST_List* rtn = new AST_List();
...@@ -696,6 +706,8 @@ AST_expr* readASTExpr(BufferedReader* reader) { ...@@ -696,6 +706,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_ifexp(reader); return read_ifexp(reader);
case AST_TYPE::Index: case AST_TYPE::Index:
return read_index(reader); return read_index(reader);
case AST_TYPE::Lambda:
return read_lambda(reader);
case AST_TYPE::List: case AST_TYPE::List:
return read_list(reader); return read_list(reader);
case AST_TYPE::ListComp: case AST_TYPE::ListComp:
......
...@@ -567,6 +567,19 @@ void AST_keyword::accept(ASTVisitor* v) { ...@@ -567,6 +567,19 @@ void AST_keyword::accept(ASTVisitor* v) {
value->accept(v); value->accept(v);
} }
void AST_Lambda::accept(ASTVisitor* v) {
bool skip = v->visit_lambda(this);
if (skip)
return;
args->accept(v);
body->accept(v);
}
void* AST_Lambda::accept_expr(ExprVisitor* v) {
return v->visit_lambda(this);
}
void AST_LangPrimitive::accept(ASTVisitor* v) { void AST_LangPrimitive::accept(ASTVisitor* v) {
bool skip = v->visit_langprimitive(this); bool skip = v->visit_langprimitive(this);
if (skip) if (skip)
...@@ -1272,6 +1285,14 @@ bool PrintVisitor::visit_invoke(AST_Invoke* node) { ...@@ -1272,6 +1285,14 @@ bool PrintVisitor::visit_invoke(AST_Invoke* node) {
return true; return true;
} }
bool PrintVisitor::visit_lambda(AST_Lambda* node) {
printf("lambda ");
node->args->accept(this);
printf(": ");
node->body->accept(this);
return true;
}
bool PrintVisitor::visit_langprimitive(AST_LangPrimitive* node) { bool PrintVisitor::visit_langprimitive(AST_LangPrimitive* node) {
printf(":"); printf(":");
switch (node->opcode) { switch (node->opcode) {
...@@ -1726,6 +1747,10 @@ public: ...@@ -1726,6 +1747,10 @@ public:
output->push_back(node); output->push_back(node);
return false; return false;
} }
virtual bool visit_lambda(AST_Lambda* node) {
output->push_back(node);
return !expand_scopes;
}
virtual bool visit_langprimitive(AST_LangPrimitive* node) { virtual bool visit_langprimitive(AST_LangPrimitive* node) {
output->push_back(node); output->push_back(node);
return false; return false;
......
...@@ -532,6 +532,19 @@ public: ...@@ -532,6 +532,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword; static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword;
}; };
class AST_Lambda : public AST_expr {
public:
AST_arguments* args;
AST_expr* body;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_Lambda() : AST_expr(AST_TYPE::Lambda) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Lambda;
};
class AST_List : public AST_expr { class AST_List : public AST_expr {
public: public:
std::vector<AST_expr*> elts; std::vector<AST_expr*> elts;
...@@ -916,6 +929,7 @@ public: ...@@ -916,6 +929,7 @@ public:
virtual bool visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_invoke(AST_Invoke* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_invoke(AST_Invoke* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_keyword(AST_keyword* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_keyword(AST_keyword* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
...@@ -978,6 +992,7 @@ public: ...@@ -978,6 +992,7 @@ public:
virtual bool visit_index(AST_Index* node) { return false; } virtual bool visit_index(AST_Index* node) { return false; }
virtual bool visit_invoke(AST_Invoke* node) { return false; } virtual bool visit_invoke(AST_Invoke* node) { return false; }
virtual bool visit_keyword(AST_keyword* node) { return false; } virtual bool visit_keyword(AST_keyword* node) { return false; }
virtual bool visit_lambda(AST_Lambda* node) { return false; }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { return false; } virtual bool visit_langprimitive(AST_LangPrimitive* node) { return false; }
virtual bool visit_list(AST_List* node) { return false; } virtual bool visit_list(AST_List* node) { return false; }
virtual bool visit_listcomp(AST_ListComp* node) { return false; } virtual bool visit_listcomp(AST_ListComp* node) { return false; }
...@@ -1020,6 +1035,7 @@ public: ...@@ -1020,6 +1035,7 @@ public:
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
...@@ -1108,6 +1124,7 @@ public: ...@@ -1108,6 +1124,7 @@ public:
virtual bool visit_index(AST_Index* node); virtual bool visit_index(AST_Index* node);
virtual bool visit_invoke(AST_Invoke* node); virtual bool visit_invoke(AST_Invoke* node);
virtual bool visit_keyword(AST_keyword* node); virtual bool visit_keyword(AST_keyword* node);
virtual bool visit_lambda(AST_Lambda* node);
virtual bool visit_langprimitive(AST_LangPrimitive* node); virtual bool visit_langprimitive(AST_LangPrimitive* node);
virtual bool visit_list(AST_List* node); virtual bool visit_list(AST_List* node);
virtual bool visit_listcomp(AST_ListComp* node); virtual bool visit_listcomp(AST_ListComp* node);
......
...@@ -565,6 +565,21 @@ private: ...@@ -565,6 +565,21 @@ private:
return rtn; return rtn;
} }
AST_expr* remapLambda(AST_Lambda* node) {
AST_Lambda* rtn = new AST_Lambda();
rtn->lineno = node->lineno;
rtn->col_offset = node->col_offset;
rtn->args = node->args;
// remap default arguments
rtn->args->defaults.clear();
for (auto& e : node->args->defaults)
rtn->args->defaults.push_back(remapExpr(e));
rtn->body = node->body;
return rtn;
}
AST_expr* remapLangPrimitive(AST_LangPrimitive* node) { AST_expr* remapLangPrimitive(AST_LangPrimitive* node) {
AST_LangPrimitive* rtn = new AST_LangPrimitive(node->opcode); AST_LangPrimitive* rtn = new AST_LangPrimitive(node->opcode);
for (AST_expr* arg : node->args) { for (AST_expr* arg : node->args) {
...@@ -676,6 +691,9 @@ private: ...@@ -676,6 +691,9 @@ private:
case AST_TYPE::Index: case AST_TYPE::Index:
rtn = remapIndex(ast_cast<AST_Index>(node)); rtn = remapIndex(ast_cast<AST_Index>(node));
break; break;
case AST_TYPE::Lambda:
rtn = remapLambda(ast_cast<AST_Lambda>(node));
break;
case AST_TYPE::LangPrimitive: case AST_TYPE::LangPrimitive:
rtn = remapLangPrimitive(ast_cast<AST_LangPrimitive>(node)); rtn = remapLangPrimitive(ast_cast<AST_LangPrimitive>(node));
break; break;
...@@ -1023,7 +1041,7 @@ public: ...@@ -1023,7 +1041,7 @@ public:
} }
virtual bool visit_return(AST_Return* node) { virtual bool visit_return(AST_Return* node) {
if (root_type != AST_TYPE::FunctionDef) { if (root_type != AST_TYPE::FunctionDef && root_type != AST_TYPE::Lambda) {
fprintf(stderr, "SyntaxError: 'return' outside function\n"); fprintf(stderr, "SyntaxError: 'return' outside function\n");
exit(1); exit(1);
} }
......
...@@ -203,14 +203,14 @@ public: ...@@ -203,14 +203,14 @@ public:
CFG* cfg; CFG* cfg;
LivenessAnalysis* liveness; LivenessAnalysis* liveness;
PhiAnalysis* phis; PhiAnalysis* phis;
const std::vector<AST_stmt*> body;
const std::string getName(); const std::string getName();
AST_arguments* getArgsAST(); AST_arguments* getArgsAST();
const std::vector<AST_expr*>& getArgNames(); const std::vector<AST_expr*>& getArgNames();
const std::vector<AST_stmt*>& getBody();
SourceInfo(BoxedModule* m, ScopingAnalysis* scoping) SourceInfo(BoxedModule* m, ScopingAnalysis* scoping, AST* ast, const std::vector<AST_stmt*>& body)
: parent_module(m), scoping(scoping), ast(NULL), cfg(NULL), liveness(NULL), phis(NULL) {} : parent_module(m), scoping(scoping), ast(ast), cfg(NULL), liveness(NULL), phis(NULL), body(body) {}
}; };
typedef std::vector<CompiledFunction*> FunctionList; typedef std::vector<CompiledFunction*> FunctionList;
......
s = lambda x: x**2
print s(8), s(100)
for i in range(10):
print (lambda x, y: x < y)(i, 5)
t = lambda s: " ".join(s.split())
print t("test \tstr\ni\n ng")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment