Commit b91071cd authored by Marius Wachtler's avatar Marius Wachtler

Fix set comparisons

parent 54a9559e
......@@ -85,7 +85,7 @@ Box* dictCopy(BoxedDict* self) {
raiseExcHelper(TypeError, "descriptor 'copy' requires a 'dict' object but received a '%s'", getTypeName(self));
BoxedDict* r = new BoxedDict();
r->d.insert(self->d.begin(), self->d.end());
r->d = self->d;
return r;
}
......@@ -576,11 +576,11 @@ Box* dictEq(BoxedDict* self, Box* _rhs) {
if (self->d.size() != rhs->d.size())
return False;
for (const auto& p : *self) {
for (const auto& p : self->d) {
auto it = rhs->d.find(p.first);
if (it == rhs->d.end())
return False;
if (!nonzero(compare(p.second, it->second, AST_TYPE::Eq)))
if (!PyEq()(p.second, it->second))
return False;
}
......
......@@ -407,6 +407,9 @@ static Box* setIssubset(BoxedSet* self, Box* container) {
assert(PyAnySet_Check(container));
BoxedSet* rhs = static_cast<BoxedSet*>(container);
if (self->s.size() > rhs->s.size())
return False;
for (auto e : self->s) {
if (rhs->s.find(e) == rhs->s.end())
return False;
......@@ -421,13 +424,7 @@ static Box* setIssuperset(BoxedSet* self, Box* container) {
container = makeNewSet(set_cls, container);
}
assert(PyAnySet_Check(container));
BoxedSet* rhs = static_cast<BoxedSet*>(container);
for (auto e : rhs->s) {
if (self->s.find(e) == self->s.end())
return False;
}
return True;
return setIssubset((BoxedSet*)container, self);
}
static Box* setIsdisjoint(BoxedSet* self, Box* container) {
......@@ -473,7 +470,7 @@ Box* setCopy(BoxedSet* self) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
BoxedSet* rtn = new BoxedSet();
rtn->s.insert(self->s.begin(), self->s.end());
rtn->s = self->s;
return rtn;
}
......@@ -497,24 +494,56 @@ Box* setContains(BoxedSet* self, Box* v) {
Box* setEq(BoxedSet* self, BoxedSet* rhs) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (!PyAnySet_Check(rhs))
return NotImplemented;
return False;
if (self->s.size() != rhs->s.size())
return False;
for (auto e : self->s) {
if (!rhs->s.count(e))
return False;
}
return True;
return setIssubset(self, rhs);
}
Box* setNe(BoxedSet* self, BoxedSet* rhs) {
Box* r = setEq(self, rhs);
if (r->cls == bool_cls)
assert(r->cls == bool_cls);
return boxBool(r == False);
assert(r == NotImplemented);
return r;
}
Box* setLe(BoxedSet* self, BoxedSet* rhs) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (!PyAnySet_Check(rhs))
raiseExcHelper(TypeError, "can only compare to a set");
return setIssubset(self, rhs);
}
Box* setLt(BoxedSet* self, BoxedSet* rhs) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (!PyAnySet_Check(rhs))
raiseExcHelper(TypeError, "can only compare to a set");
if (self->s.size() >= rhs->s.size())
return False;
return setIssubset(self, rhs);
}
Box* setGe(BoxedSet* self, BoxedSet* rhs) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (!PyAnySet_Check(rhs))
raiseExcHelper(TypeError, "can only compare to a set");
return setIssuperset(self, rhs);
}
Box* setGt(BoxedSet* self, BoxedSet* rhs) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (!PyAnySet_Check(rhs))
raiseExcHelper(TypeError, "can only compare to a set");
if (self->s.size() <= rhs->s.size())
return False;
return setIssuperset(self, rhs);
}
Box* setNonzero(BoxedSet* self) {
......@@ -627,10 +656,18 @@ void setupSet() {
set_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)setContains, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__contains__", set_cls->getattr(internStringMortal("__contains__")));
set_cls->giveAttr("__eq__", new BoxedFunction(boxRTFunction((void*)setEq, UNKNOWN, 2)));
set_cls->giveAttr("__eq__", new BoxedFunction(boxRTFunction((void*)setEq, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__eq__", set_cls->getattr(internStringMortal("__eq__")));
set_cls->giveAttr("__ne__", new BoxedFunction(boxRTFunction((void*)setNe, UNKNOWN, 2)));
set_cls->giveAttr("__ne__", new BoxedFunction(boxRTFunction((void*)setNe, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__ne__", set_cls->getattr(internStringMortal("__ne__")));
set_cls->giveAttr("__le__", new BoxedFunction(boxRTFunction((void*)setLe, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__le__", set_cls->getattr(internStringMortal("__le__")));
set_cls->giveAttr("__lt__", new BoxedFunction(boxRTFunction((void*)setLt, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__lt__", set_cls->getattr(internStringMortal("__lt__")));
set_cls->giveAttr("__ge__", new BoxedFunction(boxRTFunction((void*)setGe, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__ge__", set_cls->getattr(internStringMortal("__ge__")));
set_cls->giveAttr("__gt__", new BoxedFunction(boxRTFunction((void*)setGt, BOXED_BOOL, 2)));
frozenset_cls->giveAttr("__gt__", set_cls->getattr(internStringMortal("__gt__")));
set_cls->giveAttr("__nonzero__", new BoxedFunction(boxRTFunction((void*)setNonzero, BOXED_BOOL, 1)));
frozenset_cls->giveAttr("__nonzero__", set_cls->getattr(internStringMortal("__nonzero__")));
......
......@@ -128,8 +128,12 @@ for i in xrange(10):
for s1 in set(range(5)), frozenset(range(5)):
for s2 in compare_to:
print type(s2), sorted(s2), s1.issubset(s2), s1.issuperset(s2), s1 == s2, s1 != s2, sorted(s1.difference(s2)), s1.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
print type(s2), sorted(s2), s1.issubset(s2), s1.issuperset(s2), sorted(s1.difference(s2)), s1.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
print s1 == s2, s1 != s2
try:
print s1 < s2, s1 <= s2, s1 > s2, s1 >= s2
except Exception as e:
print e
f = float('nan')
s = set([f])
print f in s, f == list(s)[0]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment