Commit 91e36eeb authored by Kevin Modzelewski's avatar Kevin Modzelewski

Fix attrwrapper comparisons

And add some extra checking to make sure we don't make these kinds of mistakes more
parent 38ab9381
......@@ -667,7 +667,7 @@ Box* dictFromkeys(Box* cls, Box* iterable, Box* default_value) {
return rtn;
}
Box* dictEq(BoxedDict* self, Box* _rhs) {
static Box* dictEq(BoxedDict* self, Box* _rhs) {
if (!PyDict_Check(self))
raiseExcHelper(TypeError, "descriptor '__eq__' requires a 'dict' object but received a '%s'",
getTypeName(self));
......@@ -698,7 +698,7 @@ Box* dictEq(BoxedDict* self, Box* _rhs) {
Py_RETURN_TRUE;
}
Box* dictNe(BoxedDict* self, Box* _rhs) {
static Box* dictNe(BoxedDict* self, Box* _rhs) {
Box* eq = dictEq(self, _rhs);
if (eq == NotImplemented)
return eq;
......@@ -993,6 +993,11 @@ static int dict_compare(BoxedDict* a, BoxedDict* b) noexcept {
int res = 0;
Box* adiff, *bdiff, *aval, *bval;
if (a->cls == attrwrapper_cls)
return dict_compare(autoDecref(attrwrapperToDict(a)), b);
if (b->cls == attrwrapper_cls)
return dict_compare(a, autoDecref(attrwrapperToDict(b)));
/* Compare lengths first */
if (a->d.size() < b->d.size())
return -1; /* a is shorter */
......@@ -1038,12 +1043,27 @@ Finished:
static PyObject* dict_richcompare(PyObject* v, PyObject* w, int op) noexcept {
Box* res;
if (v->cls == attrwrapper_cls)
return dict_richcompare(autoDecref(attrwrapperToDict(v)), w, op);
if (w->cls == attrwrapper_cls)
return dict_richcompare(v, autoDecref(attrwrapperToDict(w)), op);
if (!PyDict_Check(v) || !PyDict_Check(w)) {
res = incref(Py_NotImplemented);
} else if (op == Py_EQ) {
res = dictEq((BoxedDict*)v, (BoxedDict*)w);
try {
res = dictEq((BoxedDict*)v, (BoxedDict*)w);
} catch (ExcInfo e) {
setCAPIException(e);
return NULL;
}
} else if (op == Py_NE) {
res = dictNe((BoxedDict*)v, (BoxedDict*)w);
try {
res = dictNe((BoxedDict*)v, (BoxedDict*)w);
} catch (ExcInfo e) {
setCAPIException(e);
return NULL;
}
} else {
/* Py3K warning if comparison isn't == or != */
if (PyErr_WarnPy3k("dict inequality comparisons not supported "
......@@ -1156,8 +1176,6 @@ void setupDict() {
new BoxedFunction(BoxedCode::create((void*)dictInit, NONE, 1, true, true, "dict.__init__")));
dict_cls->giveAttr("__repr__", new BoxedFunction(BoxedCode::create((void*)dictRepr, STR, 1, "dict.__repr__")));
dict_cls->giveAttr("__eq__", new BoxedFunction(BoxedCode::create((void*)dictEq, UNKNOWN, 2, "dict.__eq__")));
dict_cls->giveAttr("__ne__", new BoxedFunction(BoxedCode::create((void*)dictNe, UNKNOWN, 2, "dict.__ne__")));
dict_cls->giveAttr("__hash__", incref(Py_None));
dict_cls->giveAttr("__iter__", new BoxedFunction(BoxedCode::create(
(void*)dictIterKeys, typeFromClass(dictiterkey_cls), 1, "dict.__iter__")));
......
......@@ -636,6 +636,26 @@ void BoxedClass::freeze() {
assert(!is_constant);
assert(tp_name); // otherwise debugging will be very hard
// Check Python's "rule of three" for our builtin classes:
#ifndef NDEBUG
auto eq_str = getStaticString("__eq__");
auto ne_str = getStaticString("__ne__");
auto le_str = getStaticString("__le__");
auto lt_str = getStaticString("__lt__");
auto ge_str = getStaticString("__ge__");
auto gt_str = getStaticString("__gt__");
auto hash_str = getStaticString("__hash__");
if (this->hasattr(eq_str)) {
assert(this->hasattr(ne_str));
assert(this->hasattr(le_str));
assert(this->hasattr(lt_str));
assert(this->hasattr(ge_str));
assert(this->hasattr(gt_str));
assert(this->hasattr(hash_str));
}
#endif
auto doc_str = getStaticString("__doc__");
if (!hasattr(doc_str))
giveAttr(incref(doc_str), boxString(tp_name));
......
......@@ -2819,21 +2819,16 @@ public:
return new AttrWrapperIter(self);
}
static Box* eq(Box* _self, Box* _other) {
RELEASE_ASSERT(_self->cls == attrwrapper_cls, "");
AttrWrapper* self = static_cast<AttrWrapper*>(_self);
static PyObject* richcompare(PyObject* v, PyObject* w, int op) noexcept {
RELEASE_ASSERT(v->cls == attrwrapper_cls, "");
AttrWrapper* self = static_cast<AttrWrapper*>(v);
// In order to not have to reimplement dict cmp: just create a real dict for now and us it.
BoxedDict* dict = (BoxedDict*)AttrWrapper::copy(_self);
BoxedDict* dict = (BoxedDict*)AttrWrapper::copy(v);
AUTO_DECREF(dict);
assert(dict->cls == dict_cls);
static BoxedString* eq_str = getStaticString("__eq__");
return callattrInternal<CXX, NOT_REWRITABLE>(dict, eq_str, LookupScope::CLASS_ONLY, NULL, ArgPassSpec(1),
_other, NULL, NULL, NULL, NULL);
return dict_cls->tp_richcompare(dict, w, op);
}
static Box* ne(Box* _self, Box* _other) { return incref(eq(_self, _other) == Py_True ? Py_False : Py_True); }
friend class AttrWrapperIter;
};
......@@ -4605,10 +4600,6 @@ void setupRuntime() {
"attrwrapper.__contains__")));
attrwrapper_cls->giveAttr("has_key", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::hasKey, BOXED_BOOL, 2,
"attrwrapper.has_key")));
attrwrapper_cls->giveAttr(
"__eq__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::eq, UNKNOWN, 2, "attrwrapper.__eq__")));
attrwrapper_cls->giveAttr(
"__ne__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::ne, UNKNOWN, 2, "attrwrapper.__ne__")));
attrwrapper_cls->giveAttr(
"keys", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::keys, LIST, 1, "attrwrapper.keys")));
attrwrapper_cls->giveAttr(
......@@ -4631,7 +4622,14 @@ void setupRuntime() {
"__iter__", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::iter, UNKNOWN, 1, "attrwrapper.__iter__")));
attrwrapper_cls->giveAttr("update", new BoxedFunction(BoxedCode::create((void*)AttrWrapper::update, NONE, 1, true,
true, "attrwrapper.update")));
assert(dict_cls->tp_richcompare);
attrwrapper_cls->tp_richcompare = dict_cls->tp_richcompare;
assert(dict_cls->tp_compare);
attrwrapper_cls->tp_compare = dict_cls->tp_compare;
add_operators(attrwrapper_cls);
attrwrapper_cls->freeze();
assert(attrwrapper_cls->tp_richcompare == dict_cls->tp_richcompare);
attrwrapper_cls->tp_iter = AttrWrapper::iter;
attrwrapper_cls->tp_as_mapping->mp_subscript = (binaryfunc)AttrWrapper::getitem<CAPI>;
attrwrapper_cls->tp_as_mapping->mp_ass_subscript = (objobjargproc)AttrWrapper::ass_sub;
......
......@@ -8,3 +8,23 @@ api_test.test_attrwrapper_parse(globals())
def f():
pass
api_test.test_attrwrapper_parse(f.__dict__)
f.a = 1
d = {'a': 1}
d2 = f.__dict__
assert d2 == d2
assert d == d2
assert d2 == d
assert not (d < d2)
assert not (d2 < d)
assert not (d > d2)
assert not (d2 > d)
assert d <= d2
assert d2 <= d
assert d >= d2
assert d2 >= d
assert not (d2 != d2)
assert not (d2 != d)
assert not (d != d2)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment