Commit be4a3b7c authored by Kevin Modzelewski's avatar Kevin Modzelewski

Merge pull request #369 from rntz/set-comprehensions

Implement set comprehensions, fix dict comprehension scope
parents 8052286c 6b26b471
...@@ -302,7 +302,6 @@ public: ...@@ -302,7 +302,6 @@ public:
// bool visit_classdef(AST_ClassDef *node) override { return false; } // bool visit_classdef(AST_ClassDef *node) override { return false; }
bool visit_continue(AST_Continue* node) override { return false; } bool visit_continue(AST_Continue* node) override { return false; }
bool visit_dict(AST_Dict* node) override { return false; } bool visit_dict(AST_Dict* node) override { return false; }
bool visit_dictcomp(AST_DictComp* node) override { return false; }
bool visit_excepthandler(AST_ExceptHandler* node) override { return false; } bool visit_excepthandler(AST_ExceptHandler* node) override { return false; }
bool visit_expr(AST_Expr* node) override { return false; } bool visit_expr(AST_Expr* node) override { return false; }
bool visit_for(AST_For* node) override { return false; } bool visit_for(AST_For* node) override { return false; }
...@@ -334,11 +333,8 @@ public: ...@@ -334,11 +333,8 @@ public:
bool visit_while(AST_While* node) override { return false; } bool visit_while(AST_While* node) override { return false; }
bool visit_with(AST_With* node) override { return false; } bool visit_with(AST_With* node) override { return false; }
bool visit_yield(AST_Yield* node) override { return false; } bool visit_yield(AST_Yield* node) override { return false; }
bool visit_branch(AST_Branch* node) override { return false; } bool visit_branch(AST_Branch* node) override { return false; }
bool visit_jump(AST_Jump* node) override { return false; } bool visit_jump(AST_Jump* node) override { return false; }
bool visit_delete(AST_Delete* node) override { return false; } bool visit_delete(AST_Delete* node) override { return false; }
bool visit_global(AST_Global* node) override { bool visit_global(AST_Global* node) override {
...@@ -401,7 +397,16 @@ public: ...@@ -401,7 +397,16 @@ public:
} }
} }
bool visit_generatorexp(AST_GeneratorExp* node) override { // helper methods for visit_{generatorexp,dictcomp,setcomp}
void visit_comp_values(AST_GeneratorExp* node) { node->elt->accept(this); }
void visit_comp_values(AST_SetComp* node) { node->elt->accept(this); }
void visit_comp_values(AST_DictComp* node) {
node->key->accept(this);
node->value->accept(this);
}
template <typename CompType> bool visit_comp(CompType* node) {
// NB. comprehensions evaluate their first for-subject's expression outside of the function scope they create.
if (node == orig_node) { if (node == orig_node) {
bool first = true; bool first = true;
for (AST_comprehension* c : node->generators) { for (AST_comprehension* c : node->generators) {
...@@ -413,16 +418,19 @@ public: ...@@ -413,16 +418,19 @@ public:
first = false; first = false;
} }
node->elt->accept(this); visit_comp_values(node);
} else { } else {
node->generators[0]->iter->accept(this); node->generators[0]->iter->accept(this);
(*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur, scoping); (*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur, scoping);
collect(node, map, scoping); collect(node, map, scoping);
} }
return true; return true;
} }
bool visit_generatorexp(AST_GeneratorExp* node) override { return visit_comp(node); }
bool visit_dictcomp(AST_DictComp* node) override { return visit_comp(node); }
bool visit_setcomp(AST_SetComp* node) override { return visit_comp(node); }
bool visit_lambda(AST_Lambda* node) override { bool visit_lambda(AST_Lambda* node) override {
if (node == orig_node) { if (node == orig_node) {
for (AST_expr* e : node->args->args) for (AST_expr* e : node->args->args)
...@@ -562,7 +570,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -562,7 +570,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
} }
case AST_TYPE::FunctionDef: case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda: case AST_TYPE::Lambda:
case AST_TYPE::GeneratorExp: { case AST_TYPE::GeneratorExp:
case AST_TYPE::DictComp:
case AST_TYPE::SetComp: {
ScopeInfoBase* scopeInfo ScopeInfoBase* scopeInfo
= new ScopeInfoBase(parent_info, usage, usage->node, false /* usesNameLookup */); = new ScopeInfoBase(parent_info, usage, usage->node, false /* usesNameLookup */);
this->scopes[node] = scopeInfo; this->scopes[node] = scopeInfo;
......
...@@ -221,12 +221,6 @@ private: ...@@ -221,12 +221,6 @@ private:
AST_Name* remapName(AST_Name* name) { return name; } AST_Name* remapName(AST_Name* name) { return name; }
AST_expr* applyComprehensionCall(AST_DictComp* node, AST_Name* name) {
AST_expr* key = remapExpr(node->key);
AST_expr* value = remapExpr(node->value);
return makeCall(makeLoadAttribute(name, internString("__setitem__"), true), key, value);
}
AST_expr* applyComprehensionCall(AST_ListComp* node, AST_Name* name) { AST_expr* applyComprehensionCall(AST_ListComp* node, AST_Name* name) {
AST_expr* elt = remapExpr(node->elt); AST_expr* elt = remapExpr(node->elt);
return makeCall(makeLoadAttribute(name, internString("append"), true), elt); return makeCall(makeLoadAttribute(name, internString("append"), true), elt);
...@@ -786,40 +780,37 @@ private: ...@@ -786,40 +780,37 @@ private:
} }
return rtn; return rtn;
}; }
AST_expr* remapGeneratorExp(AST_GeneratorExp* node) {
assert(node->generators.size());
AST_expr* first = remapExpr(node->generators[0]->iter);
// This is a helper function used for generators expressions and comprehensions.
//
// Generates a FunctionDef which produces scope for `node'. The function produced is empty, so you'd better fill it.
// `node' had better be a kind of node that scoping_analysis thinks can carry scope (see the switch (node->type)
// block in ScopingAnalysis::processNameUsages in analysis/scoping_analysis.cpp); e.g. a Lambda or GeneratorExp.
AST_FunctionDef* makeFunctionForScope(AST* node) {
AST_FunctionDef* func = new AST_FunctionDef(); AST_FunctionDef* func = new AST_FunctionDef();
func->lineno = node->lineno; func->lineno = node->lineno;
func->col_offset = node->col_offset; func->col_offset = node->col_offset;
InternedString func_name(nodeName(func)); InternedString func_name = nodeName(func);
func->name = func_name; func->name = func_name;
scoping_analysis->registerScopeReplacement(node, func);
func->args = new AST_arguments(); func->args = new AST_arguments();
func->args->vararg = internString(""); func->args->vararg = internString("");
func->args->kwarg = internString(""); func->args->kwarg = internString("");
scoping_analysis->registerScopeReplacement(node, func); // critical bit
return func;
}
InternedString first_generator_name = nodeName(node->generators[0]); // This is a helper function used for generator expressions and comprehensions.
func->args->args.push_back(makeName(first_generator_name, AST_TYPE::Param, node->lineno)); // TODO(rntz): use this to handle unscoped (i.e. list) comprehensions as well?
void emitComprehensionLoops(std::vector<AST_stmt*>* insert_point,
std::vector<AST_stmt*>* insert_point = &func->body; const std::vector<AST_comprehension*>& comprehensions, AST_expr* first_generator,
for (int i = 0; i < node->generators.size(); i++) { std::function<void(std::vector<AST_stmt*>*)> do_yield) {
AST_comprehension* c = node->generators[i]; for (int i = 0; i < comprehensions.size(); i++) {
AST_comprehension* c = comprehensions[i];
AST_For* loop = new AST_For(); AST_For* loop = new AST_For();
loop->target = c->target; loop->target = c->target;
loop->iter = (i == 0) ? first_generator : c->iter;
if (i == 0) {
loop->iter = makeName(first_generator_name, AST_TYPE::Load, node->lineno);
} else {
loop->iter = c->iter;
}
insert_point->push_back(loop); insert_point->push_back(loop);
insert_point = &loop->body; insert_point = &loop->body;
...@@ -836,21 +827,72 @@ private: ...@@ -836,21 +827,72 @@ private:
} }
} }
AST_Yield* y = new AST_Yield(); do_yield(insert_point);
y->value = node->elt; }
insert_point->push_back(makeExpr(y));
AST_expr* remapGeneratorExp(AST_GeneratorExp* node) {
assert(node->generators.size());
// We need to evaluate the first for-expression immediately, as the PEP dictates; so we pass it in as an
// argument to the function we create. See
// https://www.python.org/dev/peps/pep-0289/#early-binding-versus-late-binding
AST_expr* first = remapExpr(node->generators[0]->iter);
InternedString first_generator_name = nodeName(node->generators[0]);
AST_FunctionDef* func = makeFunctionForScope(node);
func->args->args.push_back(makeName(first_generator_name, AST_TYPE::Param, node->lineno));
emitComprehensionLoops(&func->body, node->generators,
makeName(first_generator_name, AST_TYPE::Load, node->lineno),
[this, node](std::vector<AST_stmt*>* insert_point) {
auto y = new AST_Yield();
y->value = node->elt;
insert_point->push_back(makeExpr(y));
});
push_back(func);
return makeCall(makeName(func->name, AST_TYPE::Load, node->lineno), first);
}
void emitComprehensionYield(AST_DictComp* node, InternedString dict_name, std::vector<AST_stmt*>* insert_point) {
// add entry to the dictionary
AST_expr* setitem
= makeLoadAttribute(makeName(dict_name, AST_TYPE::Load, node->lineno), internString("__setitem__"), true);
insert_point->push_back(makeExpr(makeCall(setitem, node->key, node->value)));
}
void emitComprehensionYield(AST_SetComp* node, InternedString set_name, std::vector<AST_stmt*>* insert_point) {
// add entry to the dictionary
AST_expr* add = makeLoadAttribute(makeName(set_name, AST_TYPE::Load, node->lineno), internString("add"), true);
insert_point->push_back(makeExpr(makeCall(add, node->elt)));
}
template <typename ResultType, typename CompType> AST_expr* remapScopedComprehension(CompType* node) {
// See comment in remapGeneratorExp re early vs. late binding.
AST_expr* first = remapExpr(node->generators[0]->iter);
InternedString first_generator_name = nodeName(node->generators[0]);
AST_FunctionDef* func = makeFunctionForScope(node);
func->args->args.push_back(makeName(first_generator_name, AST_TYPE::Param, node->lineno));
InternedString rtn_name = nodeName(node);
auto asgn = new AST_Assign();
asgn->targets.push_back(makeName(rtn_name, AST_TYPE::Store, node->lineno));
asgn->value = new ResultType();
func->body.push_back(asgn);
auto lambda =
[&](std::vector<AST_stmt*>* insert_point) { emitComprehensionYield(node, rtn_name, insert_point); };
AST_Name* first_name = makeName(first_generator_name, AST_TYPE::Load, node->lineno);
emitComprehensionLoops(&func->body, node->generators, first_name, lambda);
auto rtn = new AST_Return();
rtn->value = makeName(rtn_name, AST_TYPE::Load, node->lineno);
func->body.push_back(rtn);
push_back(func); push_back(func);
AST_Call* call = new AST_Call();
call->lineno = node->lineno;
call->col_offset = node->col_offset;
call->starargs = NULL; return makeCall(makeName(func->name, AST_TYPE::Load, node->lineno), first);
call->kwargs = NULL; }
call->func = makeName(func_name, AST_TYPE::Load, node->lineno);
call->args.push_back(first);
return call;
};
AST_expr* remapIfExp(AST_IfExp* node) { AST_expr* remapIfExp(AST_IfExp* node) {
InternedString rtn_name = nodeName(node); InternedString rtn_name = nodeName(node);
...@@ -1043,7 +1085,7 @@ private: ...@@ -1043,7 +1085,7 @@ private:
rtn = remapDict(ast_cast<AST_Dict>(node)); rtn = remapDict(ast_cast<AST_Dict>(node));
break; break;
case AST_TYPE::DictComp: case AST_TYPE::DictComp:
rtn = remapComprehension<AST_Dict>(ast_cast<AST_DictComp>(node)); rtn = remapScopedComprehension<AST_Dict>(ast_cast<AST_DictComp>(node));
break; break;
case AST_TYPE::GeneratorExp: case AST_TYPE::GeneratorExp:
rtn = remapGeneratorExp(ast_cast<AST_GeneratorExp>(node)); rtn = remapGeneratorExp(ast_cast<AST_GeneratorExp>(node));
...@@ -1079,6 +1121,9 @@ private: ...@@ -1079,6 +1121,9 @@ private:
case AST_TYPE::Set: case AST_TYPE::Set:
rtn = remapSet(ast_cast<AST_Set>(node)); rtn = remapSet(ast_cast<AST_Set>(node));
break; break;
case AST_TYPE::SetComp:
rtn = remapScopedComprehension<AST_Set>(ast_cast<AST_SetComp>(node));
break;
case AST_TYPE::Slice: case AST_TYPE::Slice:
rtn = remapSlice(ast_cast<AST_Slice>(node)); rtn = remapSlice(ast_cast<AST_Slice>(node));
break; break;
......
...@@ -13,24 +13,28 @@ def f(): ...@@ -13,24 +13,28 @@ def f():
# print i, j # print i, j
f() f()
def f2(x):
print dict2str({x: i for i in [x]})
print dict2str({i: i for i in [x]})
f2(7)
# Combine a list comprehension with a bunch of other control-flow expressions: # Combine a dict comprehension with a bunch of other control-flow expressions:
def f(x, y): def f3(x, y):
# TODO make sure to use an 'if' in a comprehension where the if contains control flow # TODO make sure to use an 'if' in a comprehension where the if contains control flow
print dict2str({y if i % 3 else y ** 2 + i: (i if i%2 else i/2) for i in (xrange(4 if x else 5) if y else xrange(3))}) print dict2str({y if i % 3 else y ** 2 + i: (i if i%2 else i/2) for i in (xrange(4 if x else 5) if y else xrange(3))})
f(0, 0) f3(0, 0)
f(0, 1) f3(0, 1)
f(1, 0) f3(1, 0)
f(1, 1) f3(1, 1)
# TODO: test on ifs # TODO: test on ifs
def f(): def f4():
print dict2str({i : j for (i, j) in sorted({1:2, 3:4, 5:6, 7:8}.items())}) print dict2str({i : j for (i, j) in sorted({1:2, 3:4, 5:6, 7:8}.items())})
f() f4()
# The expr should not get evaluated if the if-condition fails: # The expr should not get evaluated if the if-condition fails:
def f(): def f5():
def p(i): def p(i):
print i print i
return i ** 2 return i ** 2
...@@ -39,21 +43,22 @@ def f(): ...@@ -39,21 +43,22 @@ def f():
return i * 4 + i return i * 4 + i
print dict2str({k(i):p(i) for i in xrange(50) if i % 5 == 0 if i % 3 == 0}) print dict2str({k(i):p(i) for i in xrange(50) if i % 5 == 0 if i % 3 == 0})
f() f5()
def f(): def f6():
print dict2str({i: j for i in xrange(4) for j in xrange(i)}) print dict2str({i: j for i in xrange(4) for j in xrange(i)})
f() f6()
def f(): def f7():
j = 1 j = 1
# The 'if' part of this list comprehension references j; # The 'if' part of this list comprehension references j;
# the first time through it will use the j above, but later times # the first time through it will use the j above, but later times
# it may-or-may-not use the j from the inner part of the listcomp. # it may-or-may-not use the j from the inner part of the listcomp.
print dict2str({i: j for i in xrange(7) if i % 2 != j % 2 for j in xrange(i)}) print dict2str({i: j for i in xrange(7) if i % 2 != j % 2 for j in xrange(i)})
f() # XXX: why is this here? if we un-indent this line, python raises an exception
f7()
def f(): def f8():
# Checking the order of evaluation of the if conditions: # Checking the order of evaluation of the if conditions:
def c1(x): def c1(x):
...@@ -65,8 +70,30 @@ def f(): ...@@ -65,8 +70,30 @@ def f():
return x % 3 == 0 return x % 3 == 0
print dict2str({i : i for i in xrange(20) if c1(i) if c2(i)}) print dict2str({i : i for i in xrange(20) if c1(i) if c2(i)})
f() f8()
def f9():
# checking that dictcomps don't contaminate our scope like listcomps do
print dict2str({i:j for i,j in [(1,2)]})
try: print i
except NameError as e: print e
try: print j
except NameError as e: print e
print dict2str({1:2 for x in xrange(4) for y in xrange(5)})
try: print x
except NameError as e: print e
try: print y
except NameError as e: print e
f9()
def f10():
x = 'for'
y = 'eva'
print dict2str({i:j for i in [x for x in xrange(6)] for j in [y for y in [1]]})
print x, y
f10()
def control_flow_in_listcomp(): def control_flow_in_dictcomp():
print dict2str({(i ** 2 if i > 5 else i ** 2 * -1):(i if i else -1) for i in (xrange(10) if True else []) if (i % 2 == 0 or i % 3 != 0)}) print dict2str({(i ** 2 if i > 5 else i ** 2 * -1):(i if i else -1) for i in (xrange(10) if True else []) if (i % 2 == 0 or i % 3 != 0)})
control_flow_in_listcomp() control_flow_in_dictcomp()
# this file is adapted from dictcomp.py
def set2str(s):
# set isn't guaranteed to keep things in sorted order (although in CPython it does AFAICT)
# https://docs.python.org/2/library/stdtypes.html#types-set
return '{%s}' % ', '.join(repr(x) for x in sorted(s))
print set2str({i for i in xrange(4)})
print set2str({(i,j) for i in xrange(4) for j in xrange(4)})
def f():
print set2str({i: j for i in range(4) for j in range(4)})
f()
def f2(x):
print set2str({(x, i) for i in [x]})
print set2str({i for i in [x]})
f2(7)
# Combine a set comprehension with a bunch of other control-flow expressions:
def f3(x, y):
print set2str({(y if i % 3 else y ** 2 + i, i if i%2 else i/2) for i in (xrange(4 if x else 5) if y else xrange(3))})
f3(0, 0)
f3(0, 1)
f3(1, 0)
f3(1, 1)
def f4():
print set2str({(i, j) for i, j in sorted({1:2, 3:4, 5:6, 7:8}.items())})
f4()
# The expr should not get evaluated if the if-condition fails:
def f5():
def p(i):
print i
return i ** 2
def k(i):
print i
return i * 4 + i
print set2str({(k(i), p(i)) for i in xrange(50) if i % 5 == 0 if i % 3 == 0})
f5()
def f6():
print set2str({(i, j) for i in xrange(4) for j in xrange(i)})
f6()
def f8():
# Checking the order of evaluation of the if conditions:
def c1(x):
print "c1", x
return x % 2 == 0
def c2(x):
print "c2", x
return x % 3 == 0
print set2str({i for i in xrange(20) if c1(i) if c2(i)})
f8()
def f9():
# checking that setcomps don't contaminate our scope like listcomps do
print set2str({i for i in [1]})
try: print i
except NameError as e: print e
print set2str({1 for x in xrange(4) for y in xrange(5)})
try: print x
except NameError as e: print e
try: print y
except NameError as e: print e
f9()
def f10():
x = 'for'
y = 'eva'
print set2str({(i,j) for i in [x for x in xrange(6)] for j in [y for y in xrange(3)]})
print x, y
f10()
def control_flow_in_setcomp():
print set2str({(i ** 2 if i > 5 else i ** 2 * -1, i if i else -1) for i in (xrange(10) if True else []) if (i % 2 == 0 or i % 3 != 0)})
control_flow_in_setcomp()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment