Commit 6a1d65dd authored by Kevin Modzelewski's avatar Kevin Modzelewski

Implement simple 'a in b' support

parent ad568495
...@@ -284,10 +284,12 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -284,10 +284,12 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
CompilerType *left = getType(node->left); CompilerType *left = getType(node->left);
CompilerType *right = getType(node->comparators[0]); CompilerType *right = getType(node->comparators[0]);
if (node->ops[0] == AST_TYPE::Is || node->ops[0] == AST_TYPE::IsNot) { AST_TYPE::AST_TYPE op_type = node->ops[0];
if (op_type == AST_TYPE::Is || op_type == AST_TYPE::IsNot || op_type == AST_TYPE::In || op_type == AST_TYPE::NotIn) {
assert(node->ops.size() == 1 && "I don't think this should happen"); assert(node->ops.size() == 1 && "I don't think this should happen");
return BOOL; return BOOL;
} }
std::string name = getOpName(node->ops[0]); std::string name = getOpName(node->ops[0]);
CompilerType *attr_type = left->getattrType(&name, true); CompilerType *attr_type = left->getattrType(&name, true);
......
...@@ -67,6 +67,8 @@ std::string getOpSymbol(int op_type) { ...@@ -67,6 +67,8 @@ std::string getOpSymbol(int op_type) {
return "not"; return "not";
case AST_TYPE::NotEq: case AST_TYPE::NotEq:
return "!="; return "!=";
case AST_TYPE::NotIn:
return "not in";
case AST_TYPE::Pow: case AST_TYPE::Pow:
return "**"; return "**";
case AST_TYPE::RShift: case AST_TYPE::RShift:
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include "core/ast.h"
#include "core/common.h" #include "core/common.h"
#include "core/stats.h" #include "core/stats.h"
#include "core/types.h" #include "core/types.h"
...@@ -276,6 +277,20 @@ Box* listSort1(BoxedList* self) { ...@@ -276,6 +277,20 @@ Box* listSort1(BoxedList* self) {
return None; return None;
} }
Box* listContains(BoxedList* self, Box *elt) {
int size = self->size;
for (int i = 0; i < size; i++) {
Box* e = self->elts->elts[i];
Box* cmp = compareInternal(e, elt, AST_TYPE::Eq, NULL);
bool b = nonzero(cmp);
if (b)
return True;
}
return False;
}
BoxedClass *list_iterator_cls = NULL; BoxedClass *list_iterator_cls = NULL;
extern "C" void listIteratorGCHandler(GCVisitor *v, void* p) { extern "C" void listIteratorGCHandler(GCVisitor *v, void* p) {
boxGCHandler(v, p); boxGCHandler(v, p);
...@@ -347,6 +362,7 @@ void setupList() { ...@@ -347,6 +362,7 @@ void setupList() {
list_cls->giveAttr("__add__", new BoxedFunction(boxRTFunction((void*)listAdd, NULL, 2, false))); list_cls->giveAttr("__add__", new BoxedFunction(boxRTFunction((void*)listAdd, NULL, 2, false)));
list_cls->giveAttr("sort", new BoxedFunction(boxRTFunction((void*)listSort1, NULL, 1, false))); list_cls->giveAttr("sort", new BoxedFunction(boxRTFunction((void*)listSort1, NULL, 1, false)));
list_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)listContains, BOXED_BOOL, 2, false)));
CLFunction *new_ = boxRTFunction((void*)listNew1, NULL, 1, false); CLFunction *new_ = boxRTFunction((void*)listNew1, NULL, 1, false);
addRTFunction(new_, (void*)listNew2, NULL, 2, false); addRTFunction(new_, (void*)listNew2, NULL, 2, false);
......
...@@ -1758,6 +1758,30 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs *rewrit ...@@ -1758,6 +1758,30 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs *rewrit
return boxBool((lhs == rhs) ^ neg); return boxBool((lhs == rhs) ^ neg);
} }
if (op_type == AST_TYPE::In || op_type == AST_TYPE::NotIn) {
// TODO do rewrite
static const std::string str_contains("__contains__");
Box* contained = callattrInternal1(rhs, &str_contains, CLASS_ONLY, NULL, 1, lhs);
if (contained == NULL) {
static const std::string str_iter("__iter__");
Box* iter = callattrInternal0(rhs, &str_iter, CLASS_ONLY, NULL, 0);
if (iter)
ASSERT(isUserDefined(rhs->cls), "%s should probably have a __contains__", getTypeName(rhs)->c_str());
RELEASE_ASSERT(iter == NULL, "need to try iterating");
Box* getitem = getattr_internal(rhs, "__getitem__", false, false, NULL, NULL);
if (getitem)
ASSERT(isUserDefined(rhs->cls), "%s should probably have a __contains__", getTypeName(rhs)->c_str());
RELEASE_ASSERT(getitem == NULL, "need to try old iteration protocol");
}
bool b = nonzero(contained);
if (op_type == AST_TYPE::NotIn)
return boxBool(!b);
return boxBool(b);
}
// Can do the guard checks after the Is/IsNot handling, since that is // Can do the guard checks after the Is/IsNot handling, since that is
// irrespective of the object classes // irrespective of the object classes
if (rewrite_args) { if (rewrite_args) {
......
...@@ -6,7 +6,5 @@ def f(n): ...@@ -6,7 +6,5 @@ def f(n):
print "f(%d)" % n print "f(%d)" % n
return n return n
f(1) <= f(2) < f(3) # f(3) shouldn't get called:
f(1) <= f(2) < f(1) < f(3)
for i in xrange(1, 4):
print i in range(6), i not in range(5)
class C(object):
def __init__(self, x):
self.x = x
def __eq__(self, y):
print "__eq__", y
return self.x
def __contains__(self, y):
print "__contains__", y
return self.x
print 1 in C("hello") # "a in b" expressions get coerced to boolean
print 2 in C("")
print 1 in [C("hello")] # True
print 2 in [C("")] # False
for i in xrange(1, 4):
print i in range(6), i not in range(5)
# expected: fail
# - exceptions
class D(object):
def __getitem__(self, idx):
print "getitem", idx
if idx >= 20:
raise IndexError()
return idx
print 10 in D()
print 1000 in D()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment