Commit ae1e15a0 authored by Vinzenz Feenstra's avatar Vinzenz Feenstra

Implementation of dict comprehension

Signed-off-by: default avatarVinzenz Feenstra <evilissimo@gmail.com>
parent 96cc5efa
...@@ -158,6 +158,7 @@ public: ...@@ -158,6 +158,7 @@ public:
// virtual bool visit_classdef(AST_ClassDef *node) { return false; } // virtual bool visit_classdef(AST_ClassDef *node) { return false; }
virtual bool visit_continue(AST_Continue* node) { return false; } virtual bool visit_continue(AST_Continue* node) { return false; }
virtual bool visit_dict(AST_Dict* node) { return false; } virtual bool visit_dict(AST_Dict* node) { return false; }
virtual bool visit_dictcomp(AST_DictComp* node) { return false; }
virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; } virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; }
virtual bool visit_expr(AST_Expr* node) { return false; } virtual bool visit_expr(AST_Expr* node) { return false; }
virtual bool visit_for(AST_For* node) { return false; } virtual bool visit_for(AST_For* node) { return false; }
......
...@@ -328,6 +328,17 @@ AST_Dict* read_dict(BufferedReader* reader) { ...@@ -328,6 +328,17 @@ AST_Dict* read_dict(BufferedReader* reader) {
return rtn; return rtn;
} }
AST_DictComp* read_dictcomp(BufferedReader* reader) {
AST_DictComp* rtn = new AST_DictComp();
rtn->col_offset = readColOffset(reader);
readMiscVector(rtn->generators, reader);
rtn->key = readASTExpr(reader);
rtn->lineno = reader->readULL();
rtn->value = readASTExpr(reader);
return rtn;
}
AST_ExceptHandler* read_excepthandler(BufferedReader* reader) { AST_ExceptHandler* read_excepthandler(BufferedReader* reader) {
AST_ExceptHandler* rtn = new AST_ExceptHandler(); AST_ExceptHandler* rtn = new AST_ExceptHandler();
...@@ -679,6 +690,8 @@ AST_expr* readASTExpr(BufferedReader* reader) { ...@@ -679,6 +690,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_compare(reader); return read_compare(reader);
case AST_TYPE::Dict: case AST_TYPE::Dict:
return read_dict(reader); return read_dict(reader);
case AST_TYPE::DictComp:
return read_dictcomp(reader);
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
return read_ifexp(reader); return read_ifexp(reader);
case AST_TYPE::Index: case AST_TYPE::Index:
......
...@@ -399,6 +399,23 @@ void* AST_Dict::accept_expr(ExprVisitor* v) { ...@@ -399,6 +399,23 @@ void* AST_Dict::accept_expr(ExprVisitor* v) {
return v->visit_dict(this); return v->visit_dict(this);
} }
void AST_DictComp::accept(ASTVisitor* v) {
bool skip = v->visit_dictcomp(this);
if (skip)
return;
for (auto c : generators) {
c->accept(v);
}
value->accept(v);
key->accept(v);
}
void* AST_DictComp::accept_expr(ExprVisitor* v) {
return v->visit_dictcomp(this);
}
void AST_ExceptHandler::accept(ASTVisitor* v) { void AST_ExceptHandler::accept(ASTVisitor* v) {
bool skip = v->visit_excepthandler(this); bool skip = v->visit_excepthandler(this);
if (skip) if (skip)
...@@ -1094,6 +1111,19 @@ bool PrintVisitor::visit_dict(AST_Dict* node) { ...@@ -1094,6 +1111,19 @@ bool PrintVisitor::visit_dict(AST_Dict* node) {
return true; return true;
} }
bool PrintVisitor::visit_dictcomp(AST_DictComp* node) {
printf("{");
node->key->accept(this);
printf(":");
node->value->accept(this);
for (auto c : node->generators) {
printf(" ");
c->accept(this);
}
printf("}");
return true;
}
bool PrintVisitor::visit_excepthandler(AST_ExceptHandler* node) { bool PrintVisitor::visit_excepthandler(AST_ExceptHandler* node) {
printf("except"); printf("except");
if (node->type) { if (node->type) {
...@@ -1627,6 +1657,10 @@ public: ...@@ -1627,6 +1657,10 @@ public:
output->push_back(node); output->push_back(node);
return false; return false;
} }
virtual bool visit_dictcomp(AST_DictComp* node) {
output->push_back(node);
return false;
}
virtual bool visit_excepthandler(AST_ExceptHandler* node) { virtual bool visit_excepthandler(AST_ExceptHandler* node) {
output->push_back(node); output->push_back(node);
return false; return false;
......
...@@ -364,6 +364,19 @@ public: ...@@ -364,6 +364,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Dict; static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Dict;
}; };
class AST_DictComp : public AST_expr {
public:
std::vector<AST_comprehension*> generators;
AST_expr* key, *value;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_DictComp() : AST_expr(AST_TYPE::DictComp) {}
const static AST_TYPE::AST_TYPE TYPE = AST_TYPE::DictComp;
};
class AST_Delete : public AST_stmt { class AST_Delete : public AST_stmt {
public: public:
std::vector<AST_expr*> targets; std::vector<AST_expr*> targets;
...@@ -875,6 +888,7 @@ public: ...@@ -875,6 +888,7 @@ public:
virtual bool visit_continue(AST_Continue* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_continue(AST_Continue* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_delete(AST_Delete* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_delete(AST_Delete* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_excepthandler(AST_ExceptHandler* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_excepthandler(AST_ExceptHandler* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_expr(AST_Expr* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_expr(AST_Expr* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_for(AST_For* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_for(AST_For* node) { RELEASE_ASSERT(0, ""); }
...@@ -935,6 +949,7 @@ public: ...@@ -935,6 +949,7 @@ public:
virtual bool visit_continue(AST_Continue* node) { return false; } virtual bool visit_continue(AST_Continue* node) { return false; }
virtual bool visit_delete(AST_Delete* node) { return false; } virtual bool visit_delete(AST_Delete* node) { return false; }
virtual bool visit_dict(AST_Dict* node) { return false; } virtual bool visit_dict(AST_Dict* node) { return false; }
virtual bool visit_dictcomp(AST_DictComp* node) { return false; }
virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; } virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; }
virtual bool visit_expr(AST_Expr* node) { return false; } virtual bool visit_expr(AST_Expr* node) { return false; }
virtual bool visit_for(AST_For* node) { return false; } virtual bool visit_for(AST_For* node) { return false; }
...@@ -985,6 +1000,7 @@ public: ...@@ -985,6 +1000,7 @@ public:
virtual void* visit_clsattribute(AST_ClsAttribute* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_clsattribute(AST_ClsAttribute* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_compare(AST_Compare* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_compare(AST_Compare* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
...@@ -1061,6 +1077,7 @@ public: ...@@ -1061,6 +1077,7 @@ public:
virtual bool visit_continue(AST_Continue* node); virtual bool visit_continue(AST_Continue* node);
virtual bool visit_delete(AST_Delete* node); virtual bool visit_delete(AST_Delete* node);
virtual bool visit_dict(AST_Dict* node); virtual bool visit_dict(AST_Dict* node);
virtual bool visit_dictcomp(AST_DictComp* node);
virtual bool visit_excepthandler(AST_ExceptHandler* node); virtual bool visit_excepthandler(AST_ExceptHandler* node);
virtual bool visit_expr(AST_Expr* node); virtual bool visit_expr(AST_Expr* node);
virtual bool visit_for(AST_For* node); virtual bool visit_for(AST_For* node);
......
...@@ -181,6 +181,18 @@ private: ...@@ -181,6 +181,18 @@ private:
return call; return call;
} }
AST_Call* makeCall(AST_expr* func, AST_expr* arg0, AST_expr* arg1) {
AST_Call* call = new AST_Call();
call->args.push_back(arg0);
call->args.push_back(arg1);
call->starargs = NULL;
call->kwargs = NULL;
call->func = func;
call->col_offset = func->col_offset;
call->lineno = func->lineno;
return call;
}
AST_Name* makeName(const std::string& id, AST_TYPE::AST_TYPE ctx_type, int lineno = -1, int col_offset = -1) { AST_Name* makeName(const std::string& id, AST_TYPE::AST_TYPE ctx_type, int lineno = -1, int col_offset = -1) {
AST_Name* name = new AST_Name(); AST_Name* name = new AST_Name();
name->id = id; name->id = id;
...@@ -369,6 +381,137 @@ private: ...@@ -369,6 +381,137 @@ private:
return rtn; return rtn;
}; };
AST_expr* remapDictComp(AST_DictComp* node) {
std::string rtn_name = nodeName(node);
push_back(makeAssign(rtn_name, new AST_Dict()));
std::vector<CFGBlock*> exit_blocks;
// Where the current level should jump to after finishing its iteration.
// For the outermost comprehension, this is NULL, and it doesn't jump anywhere;
// for the inner comprehensions, they should jump to the next-outer comprehension
// when they are done iterating.
CFGBlock* finished_block = NULL;
for (int i = 0, n = node->generators.size(); i < n; i++) {
AST_comprehension* c = node->generators[i];
bool is_innermost = (i == n - 1);
AST_expr* remapped_iter = remapExpr(c->iter);
AST_expr* iter_attr = makeLoadAttribute(remapped_iter, "__iter__", true);
AST_expr* iter_call = makeCall(iter_attr);
std::string iter_name = nodeName(node, "iter", i);
AST_stmt* iter_assign = makeAssign(iter_name, iter_call);
push_back(iter_assign);
// TODO bad to save these like this?
AST_expr* hasnext_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "__hasnext__", true);
AST_expr* next_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "next", true);
AST_Jump* j;
CFGBlock* test_block = cfg->addBlock();
test_block->info = "dictcomp_test";
// printf("Test block for comp %d is %d\n", i, test_block->idx);
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block);
push_back(j);
curblock = test_block;
AST_expr* test_call = remapExpr(makeCall(hasnext_attr));
CFGBlock* body_block = cfg->addBlock();
body_block->info = "dictcomp_body";
CFGBlock* exit_block = cfg->addDeferredBlock();
exit_block->info = "dictcomp_exit";
exit_blocks.push_back(exit_block);
// printf("Body block for comp %d is %d\n", i, body_block->idx);
AST_Branch* br = new AST_Branch();
br->col_offset = node->col_offset;
br->lineno = node->lineno;
br->test = test_call;
br->iftrue = body_block;
br->iffalse = exit_block;
curblock->connectTo(body_block);
curblock->connectTo(exit_block);
push_back(br);
curblock = body_block;
push_back(makeAssign(c->target, makeCall(next_attr)));
for (AST_expr* if_condition : c->ifs) {
AST_expr* remapped = remapExpr(if_condition);
AST_Branch* br = new AST_Branch();
br->test = remapped;
push_back(br);
// Put this below the entire body?
CFGBlock* body_tramp = cfg->addBlock();
body_tramp->info = "dictcomp_if_trampoline";
// printf("body_tramp for %d is %d\n", i, body_tramp->idx);
CFGBlock* body_continue = cfg->addBlock();
body_continue->info = "dictcomp_if_continue";
// printf("body_continue for %d is %d\n", i, body_continue->idx);
br->iffalse = body_tramp;
curblock->connectTo(body_tramp);
br->iftrue = body_continue;
curblock->connectTo(body_continue);
curblock = body_tramp;
j = new AST_Jump();
j->target = test_block;
push_back(j);
curblock->connectTo(test_block, true);
curblock = body_continue;
}
CFGBlock* body_end = curblock;
assert((finished_block != NULL) == (i != 0));
if (finished_block) {
curblock = exit_block;
j = new AST_Jump();
j->target = finished_block;
curblock->connectTo(finished_block, true);
push_back(j);
}
finished_block = test_block;
curblock = body_end;
if (is_innermost) {
AST_expr* key = remapExpr(node->key);
AST_expr* value = remapExpr(node->value);
push_back(
makeExpr(makeCall(makeLoadAttribute(makeName(rtn_name, AST_TYPE::Load), "__setitem__", true), key, value)));
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block, true);
push_back(j);
assert(exit_blocks.size());
curblock = exit_blocks[0];
} else {
// continue onto the next comprehension and add to this body
}
}
// Wait until the end to place the end blocks, so that
// we get a nice nesting structure, that looks similar to what
// you'd get with a nested for loop:
for (int i = exit_blocks.size() - 1; i >= 0; i--) {
cfg->placeBlock(exit_blocks[i]);
// printf("Exit block for comp %d is %d\n", i, exit_blocks[i]->idx);
}
return makeName(rtn_name, AST_TYPE::Load);
};
AST_expr* remapIfExp(AST_IfExp* node) { AST_expr* remapIfExp(AST_IfExp* node) {
std::string rtn_name = nodeName(node); std::string rtn_name = nodeName(node);
...@@ -651,6 +794,9 @@ private: ...@@ -651,6 +794,9 @@ private:
case AST_TYPE::Dict: case AST_TYPE::Dict:
rtn = remapDict(ast_cast<AST_Dict>(node)); rtn = remapDict(ast_cast<AST_Dict>(node));
break; break;
case AST_TYPE::DictComp:
rtn = remapDictComp(ast_cast<AST_DictComp>(node));
break;
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
rtn = remapIfExp(ast_cast<AST_IfExp>(node)); rtn = remapIfExp(ast_cast<AST_IfExp>(node));
break; break;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment