Commit 91e36eeb authored by Kevin Modzelewski's avatar Kevin Modzelewski

Fix attrwrapper comparisons

And add some extra checking to make sure we don't make these kinds of mistakes more
parent 38ab9381
...@@ -667,7 +667,7 @@ Box* dictFromkeys(Box* cls, Box* iterable, Box* default_value) { ...@@ -667,7 +667,7 @@ Box* dictFromkeys(Box* cls, Box* iterable, Box* default_value) {
return rtn; return rtn;
} }
Box* dictEq(BoxedDict* self, Box* _rhs) { static Box* dictEq(BoxedDict* self, Box* _rhs) {
if (!PyDict_Check(self)) if (!PyDict_Check(self))
raiseExcHelper(TypeError, "descriptor '__eq__' requires a 'dict' object but received a '%s'", raiseExcHelper(TypeError, "descriptor '__eq__' requires a 'dict' object but received a '%s'",
getTypeName(self)); getTypeName(self));
...@@ -698,7 +698,7 @@ Box* dictEq(BoxedDict* self, Box* _rhs) { ...@@ -698,7 +698,7 @@ Box* dictEq(BoxedDict* self, Box* _rhs) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
} }
Box* dictNe(BoxedDict* self, Box* _rhs) { static Box* dictNe(BoxedDict* self, Box* _rhs) {
Box* eq = dictEq(self, _rhs); Box* eq = dictEq(self, _rhs);
if (eq == NotImplemented) if (eq == NotImplemented)
return eq; return eq;
...@@ -993,6 +993,11 @@ static int dict_compare(BoxedDict* a, BoxedDict* b) noexcept { ...@@ -993,6 +993,11 @@ static int dict_compare(BoxedDict* a, BoxedDict* b) noexcept {
int res = 0; int res = 0;
Box* adiff, *bdiff, *aval, *bval; Box* adiff, *bdiff, *aval, *bval;
if (a->cls == attrwrapper_cls)
return dict_compare(autoDecref(attrwrapperToDict(a)), b);
if (b->cls == attrwrapper_cls)
return dict_compare(a, autoDecref(attrwrapperToDict(b)));
/* Compare lengths first */ /* Compare lengths first */
if (a->d.size() < b->d.size()) if (a->d.size() < b->d.size())
return -1; /* a is shorter */ return -1; /* a is shorter */
...@@ -1038,12 +1043,27 @@ Finished: ...@@ -1038,12 +1043,27 @@ Finished:
static PyObject* dict_richcompare(PyObject* v, PyObject* w, int op) noexcept { static PyObject* dict_richcompare(PyObject* v, PyObject* w, int op) noexcept {
Box* res; Box* res;
if (v->cls == attrwrapper_cls)
return dict_richcompare(autoDecref(attrwrapperToDict(v)), w, op);
if (w->cls == attrwrapper_cls)
return dict_richcompare(v, autoDecref(attrwrapperToDict(w)), op);
if (!PyDict_Check(v) || !PyDict_Check(w)) { if (!PyDict_Check(v) || !PyDict_Check(w)) {
res = incref(Py_NotImplemented); res = incref(Py_NotImplemented);
} else if (op == Py_EQ) { } else if (op == Py_EQ) {
try {
res = dictEq((BoxedDict*)v, (BoxedDict*)w); res = dictEq((BoxedDict*)v, (BoxedDict*)w);
} catch (ExcInfo e) {
setCAPIException(e);
return NULL;
}
} else if (op == Py_NE) { } else if (op == Py_NE) {
try {
res = dictNe((BoxedDict*)v, (BoxedDict*)w); res = dictNe((BoxedDict*)v, (BoxedDict*)w);
} catch (ExcInfo e) {
setCAPIException(e);
return NULL;
}
} else { } else {
/* Py3K warning if comparison isn't == or != */ /* Py3K warning if comparison isn't == or != */
if (PyErr_WarnPy3k("dict inequality comparisons not supported " if (PyErr_WarnPy3k("dict inequality comparisons not supported "
...@@ -1156,8 +1176,6 @@ void setupDict() { ...@@ -1156,8 +1176,6 @@ void setupDict() {
new BoxedFunction(BoxedCode::create((void*)dictInit, NONE, 1, true, true, "dict.__init__"))); new BoxedFunction(BoxedCode::create((void*)dictInit, NONE, 1, true, true, "dict.__init__")));
dict_cls->giveAttr("__repr__", new BoxedFunction(BoxedCode::create((void*)dictRepr, STR, 1, "dict.__repr__"))); dict_cls->giveAttr("__repr__", new BoxedFunction(BoxedCode::create((void*)dictRepr, STR, 1, "dict.__repr__")));
dict_cls->giveAttr("__eq__", new BoxedFunction(BoxedCode::create((void*)dictEq, UNKNOWN, 2, "dict.__eq__")));
dict_cls->giveAttr("__ne__", new BoxedFunction(BoxedCode::create((void*)dictNe, UNKNOWN, 2, "dict.__ne__")));
dict_cls->giveAttr("__hash__", incref(Py_None)); dict_cls->giveAttr("__hash__", incref(Py_None));
dict_cls->giveAttr("__iter__", new BoxedFunction(BoxedCode::create( dict_cls->giveAttr("__iter__", new BoxedFunction(BoxedCode::create(
(void*)dictIterKeys, typeFromClass(dictiterkey_cls), 1, "dict.__iter__"))); (void*)dictIterKeys, typeFromClass(dictiterkey_cls), 1, "dict.__iter__")));
......
...@@ -636,6 +636,26 @@ void BoxedClass::freeze() { ...@@ -636,6 +636,26 @@ void BoxedClass::freeze() {
assert(!is_constant); assert(!is_constant);
assert(tp_name); // otherwise debugging will be very hard assert(tp_name); // otherwise debugging will be very hard
// Check Python's "rule of three" for our builtin classes:
#ifndef NDEBUG
auto eq_str = getStaticString("__eq__");
auto ne_str = getStaticString("__ne__");
auto le_str = getStaticString("__le__");
auto lt_str = getStaticString("__lt__");
auto ge_str = getStaticString("__ge__");
auto gt_str = getStaticString("__gt__");
auto hash_str = getStaticString("__hash__");
if (this->hasattr(eq_str)) {
assert(this->hasattr(ne_str));
assert(this->hasattr(le_str));
assert(this->hasattr(lt_str));
assert(this->hasattr(ge_str));
assert(this->hasattr(gt_str));
assert(this->hasattr(hash_str));
}
#endif
auto doc_str = getStaticString("__doc__"); auto doc_str = getStaticString("__doc__");
if (!hasattr(doc_str)) if (!hasattr(doc_str))
giveAttr(incref(doc_str), boxString(tp_name)); giveAttr(incref(doc_str), boxString(tp_name));
......
...@@ -2819,21 +2819,16 @@ public: ...@@ -2819,21 +2819,16 @@ public:
return new AttrWrapperIter(self); return new AttrWrapperIter(self);
} }
static Box* eq(Box* _self, Box* _other) { static PyObject* richcompare(PyObject* v, PyObject* w, int op) noexcept {
RELEASE_ASSERT(_self->cls == attrwrapper_cls, ""); RELEASE_ASSERT(v->cls == attrwrapper_cls, "");
AttrWrapper* self = static_cast<AttrWrapper*>(_self); AttrWrapper* self = static_cast<AttrWrapper*>(v);
// In order to not have to reimplement dict cmp: just create a real dict for now and us it. // In order to not have to reimplement dict cmp: just create a real dict for now and us it.
BoxedDict* dict = (BoxedDict*)AttrWrapper::copy(_self); BoxedDict* dict = (BoxedDict*)AttrWrapper::copy(v);
AUTO_DECREF(dict); AUTO_DECREF(dict);
assert(dict->cls == dict_cls); return dict_cls->tp_richcompare(dict, w, op);
static BoxedString* eq_str = getStaticString("__eq__");
return callattrInternal<CXX, NOT_REWRITABLE>(dict, eq_str, LookupScope::CLASS_ONLY, NULL, ArgPassSpec(1),
_other, NULL, NULL, NULL, NULL);
} }
static Box* ne(Box* _self, Box* _other) { return incref(eq(_self, _other) == Py_True ? Py_False : Py_True); }
friend class AttrWrapperIter; friend class AttrWrapperIter;
}; };
...@@ -4605,10 +4600,6 @@ void setupRuntime() { ...@@ -4605,10 +4600,6 @@ void setupRuntime() {
"attrwrapper.__contains__"))); "attrwrapper.__contains__")));
attrwrapper_cls->giveAttr("has_key", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::hasKey, BOXED_BOOL, 2, attrwrapper_cls->giveAttr("has_key", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::hasKey, BOXED_BOOL, 2,
"attrwrapper.has_key"))); "attrwrapper.has_key")));
attrwrapper_cls->giveAttr(
"__eq__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::eq, UNKNOWN, 2, "attrwrapper.__eq__")));
attrwrapper_cls->giveAttr(
"__ne__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::ne, UNKNOWN, 2, "attrwrapper.__ne__")));
attrwrapper_cls->giveAttr( attrwrapper_cls->giveAttr(
"keys", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::keys, LIST, 1, "attrwrapper.keys"))); "keys", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::keys, LIST, 1, "attrwrapper.keys")));
attrwrapper_cls->giveAttr( attrwrapper_cls->giveAttr(
...@@ -4631,7 +4622,14 @@ void setupRuntime() { ...@@ -4631,7 +4622,14 @@ void setupRuntime() {
"__iter__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::iter, UNKNOWN, 1, "attrwrapper.__iter__"))); "__iter__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::iter, UNKNOWN, 1, "attrwrapper.__iter__")));
attrwrapper_cls->giveAttr("update", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::update, NONE, 1, true, attrwrapper_cls->giveAttr("update", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::update, NONE, 1, true,
true, "attrwrapper.update"))); true, "attrwrapper.update")));
assert(dict_cls->tp_richcompare);
attrwrapper_cls->tp_richcompare = dict_cls->tp_richcompare;
assert(dict_cls->tp_compare);
attrwrapper_cls->tp_compare = dict_cls->tp_compare;
add_operators(attrwrapper_cls);
attrwrapper_cls->freeze(); attrwrapper_cls->freeze();
assert(attrwrapper_cls->tp_richcompare == dict_cls->tp_richcompare);
attrwrapper_cls->tp_iter = AttrWrapper::iter; attrwrapper_cls->tp_iter = AttrWrapper::iter;
attrwrapper_cls->tp_as_mapping->mp_subscript = (binaryfunc)AttrWrapper::getitem<CAPI>; attrwrapper_cls->tp_as_mapping->mp_subscript = (binaryfunc)AttrWrapper::getitem<CAPI>;
attrwrapper_cls->tp_as_mapping->mp_ass_subscript = (objobjargproc)AttrWrapper::ass_sub; attrwrapper_cls->tp_as_mapping->mp_ass_subscript = (objobjargproc)AttrWrapper::ass_sub;
......
...@@ -8,3 +8,23 @@ api_test.test_attrwrapper_parse(globals()) ...@@ -8,3 +8,23 @@ api_test.test_attrwrapper_parse(globals())
def f(): def f():
pass pass
api_test.test_attrwrapper_parse(f.__dict__) api_test.test_attrwrapper_parse(f.__dict__)
f.a = 1
d = {'a': 1}
d2 = f.__dict__
assert d2 == d2
assert d == d2
assert d2 == d
assert not (d < d2)
assert not (d2 < d)
assert not (d > d2)
assert not (d2 > d)
assert d <= d2
assert d2 <= d
assert d >= d2
assert d2 >= d
assert not (d2 != d2)
assert not (d2 != d)
assert not (d != d2)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment