Commit e95f4bef authored by Marius Wachtler's avatar Marius Wachtler

Implement __cmp__ alias tp_compare

parent 3ebd2551
...@@ -848,9 +848,7 @@ PyTypeObject PyBuffer_Type = { ...@@ -848,9 +848,7 @@ PyTypeObject PyBuffer_Type = {
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
// Pyston change: (cmpfunc)buffer_compare, /* tp_compare */
//(cmpfunc)buffer_compare, /* tp_compare */
NULL, /* tp_compare */
(reprfunc)buffer_repr, /* tp_repr */ (reprfunc)buffer_repr, /* tp_repr */
0, /* tp_as_number */ 0, /* tp_as_number */
&buffer_as_sequence, /* tp_as_sequence */ &buffer_as_sequence, /* tp_as_sequence */
......
...@@ -699,8 +699,7 @@ _PyWeakref_ProxyType = { ...@@ -699,8 +699,7 @@ _PyWeakref_ProxyType = {
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
// Pyston change: proxy_compare, /* tp_compare */
0, //proxy_compare, /* tp_compare */
(reprfunc)proxy_repr, /* tp_repr */ (reprfunc)proxy_repr, /* tp_repr */
&proxy_as_number, /* tp_as_number */ &proxy_as_number, /* tp_as_number */
&proxy_as_sequence, /* tp_as_sequence */ &proxy_as_sequence, /* tp_as_sequence */
...@@ -737,7 +736,7 @@ _PyWeakref_CallableProxyType = { ...@@ -737,7 +736,7 @@ _PyWeakref_CallableProxyType = {
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
0, //proxy_compare, /* tp_compare */ proxy_compare, /* tp_compare */
(unaryfunc)proxy_repr, /* tp_repr */ (unaryfunc)proxy_repr, /* tp_repr */
&proxy_as_number, /* tp_as_number */ &proxy_as_number, /* tp_as_number */
&proxy_as_sequence, /* tp_as_sequence */ &proxy_as_sequence, /* tp_as_sequence */
......
...@@ -2934,8 +2934,6 @@ extern "C" int PyType_Ready(PyTypeObject* cls) noexcept { ...@@ -2934,8 +2934,6 @@ extern "C" int PyType_Ready(PyTypeObject* cls) noexcept {
gc::registerNonheapRootObject(cls); gc::registerNonheapRootObject(cls);
// unhandled fields: // unhandled fields:
RELEASE_ASSERT(cls->tp_compare == NULL, "");
int ALLOWABLE_FLAGS = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_CHECKTYPES int ALLOWABLE_FLAGS = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_CHECKTYPES
| Py_TPFLAGS_HAVE_NEWBUFFER; | Py_TPFLAGS_HAVE_NEWBUFFER;
ALLOWABLE_FLAGS |= Py_TPFLAGS_INT_SUBCLASS | Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS ALLOWABLE_FLAGS |= Py_TPFLAGS_INT_SUBCLASS | Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS
......
...@@ -177,20 +177,29 @@ std::string getInplaceOpName(int op_type) { ...@@ -177,20 +177,29 @@ std::string getInplaceOpName(int op_type) {
// Maybe better name is "swapped" -- it's what the runtime will try if the normal op // Maybe better name is "swapped" -- it's what the runtime will try if the normal op
// name fails, it will switch the order of the lhs/rhs and call the reverse op. // name fails, it will switch the order of the lhs/rhs and call the reverse op.
// Calling it "reverse" because that's what I'm assuming the 'r' stands for in ex __radd__ // Calling it "reverse" because that's what I'm assuming the 'r' stands for in ex __radd__
std::string getReverseOpName(int op_type) { int getReverseCmpOp(int op_type, bool& success) {
success = true;
if (op_type == AST_TYPE::Lt) if (op_type == AST_TYPE::Lt)
return getOpName(AST_TYPE::Gt); return AST_TYPE::Gt;
if (op_type == AST_TYPE::LtE) if (op_type == AST_TYPE::LtE)
return getOpName(AST_TYPE::GtE); return AST_TYPE::GtE;
if (op_type == AST_TYPE::Gt) if (op_type == AST_TYPE::Gt)
return getOpName(AST_TYPE::Lt); return AST_TYPE::Lt;
if (op_type == AST_TYPE::GtE) if (op_type == AST_TYPE::GtE)
return getOpName(AST_TYPE::LtE); return AST_TYPE::LtE;
if (op_type == AST_TYPE::NotEq) if (op_type == AST_TYPE::NotEq)
return getOpName(AST_TYPE::NotEq); return AST_TYPE::NotEq;
if (op_type == AST_TYPE::Eq) if (op_type == AST_TYPE::Eq)
return getOpName(AST_TYPE::Eq); return AST_TYPE::Eq;
success = false;
return op_type;
}
std::string getReverseOpName(int op_type) {
bool reversed = false;
op_type = getReverseCmpOp(op_type, reversed);
if (reversed)
return getOpName(op_type);
const std::string& normal_name = getOpName(op_type); const std::string& normal_name = getOpName(op_type);
return "__r" + normal_name.substr(2); return "__r" + normal_name.substr(2);
} }
......
...@@ -1402,6 +1402,7 @@ template <class T, class R> void findNodes(const R& roots, std::vector<T*>& outp ...@@ -1402,6 +1402,7 @@ template <class T, class R> void findNodes(const R& roots, std::vector<T*>& outp
llvm::StringRef getOpSymbol(int op_type); llvm::StringRef getOpSymbol(int op_type);
const std::string& getOpName(int op_type); const std::string& getOpName(int op_type);
int getReverseCmpOp(int op_type, bool& success);
std::string getReverseOpName(int op_type); std::string getReverseOpName(int op_type);
std::string getInplaceOpName(int op_type); std::string getInplaceOpName(int op_type);
std::string getInplaceOpSymbol(int op_type); std::string getInplaceOpSymbol(int op_type);
......
...@@ -467,6 +467,140 @@ Box* instanceDelitem(Box* _inst, Box* key) { ...@@ -467,6 +467,140 @@ Box* instanceDelitem(Box* _inst, Box* key) {
return runtimeCall(delitem_func, ArgPassSpec(1), key, NULL, NULL, NULL, NULL); return runtimeCall(delitem_func, ArgPassSpec(1), key, NULL, NULL, NULL, NULL);
} }
/* Try a 3-way comparison, returning an int; v is an instance. Return:
-2 for an exception;
-1 if v < w;
0 if v == w;
1 if v > w;
2 if this particular 3-way comparison is not implemented or undefined.
*/
static int half_cmp(PyObject* v, PyObject* w) noexcept {
// static PyObject* cmp_obj;
PyObject* args;
PyObject* cmp_func;
PyObject* result;
long l;
assert(PyInstance_Check(v));
// Pyston change:
#if 0
if (cmp_obj == NULL) {
cmp_obj = PyString_InternFromString("__cmp__");
if (cmp_obj == NULL)
return -2;
}
cmp_func = PyObject_GetAttr(v, cmp_obj);
if (cmp_func == NULL) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return -2;
PyErr_Clear();
return 2;
}
#else
try {
cmp_func = _instanceGetattribute(v, boxStrConstant("__cmp__"), false);
if (!cmp_func)
return 2;
} catch (ExcInfo e) {
setCAPIException(e);
return -2;
}
#endif
args = PyTuple_Pack(1, w);
if (args == NULL) {
Py_DECREF(cmp_func);
return -2;
}
result = PyEval_CallObject(cmp_func, args);
Py_DECREF(args);
Py_DECREF(cmp_func);
if (result == NULL)
return -2;
if (result == Py_NotImplemented) {
Py_DECREF(result);
return 2;
}
l = PyInt_AsLong(result);
Py_DECREF(result);
if (l == -1 && PyErr_Occurred()) {
PyErr_SetString(PyExc_TypeError, "comparison did not return an int");
return -2;
}
return l < 0 ? -1 : l > 0 ? 1 : 0;
}
/* Try a 3-way comparison, returning an int; either v or w is an instance.
We first try a coercion. Return:
-2 for an exception;
-1 if v < w;
0 if v == w;
1 if v > w;
2 if this particular 3-way comparison is not implemented or undefined.
*/
static int instance_compare(PyObject* v, PyObject* w) noexcept {
int c;
c = PyNumber_CoerceEx(&v, &w);
if (c < 0)
return -2;
if (c == 0) {
/* If neither is now an instance, use regular comparison */
if (!PyInstance_Check(v) && !PyInstance_Check(w)) {
c = PyObject_Compare(v, w);
Py_DECREF(v);
Py_DECREF(w);
if (PyErr_Occurred())
return -2;
return c < 0 ? -1 : c > 0 ? 1 : 0;
}
} else {
/* The coercion didn't do anything.
Treat this the same as returning v and w unchanged. */
Py_INCREF(v);
Py_INCREF(w);
}
if (PyInstance_Check(v)) {
c = half_cmp(v, w);
if (c <= 1) {
Py_DECREF(v);
Py_DECREF(w);
return c;
}
}
if (PyInstance_Check(w)) {
c = half_cmp(w, v);
if (c <= 1) {
Py_DECREF(v);
Py_DECREF(w);
if (c >= -1)
c = -c;
return c;
}
}
Py_DECREF(v);
Py_DECREF(w);
return 2;
}
Box* instanceCompare(Box* _inst, Box* other) {
int rtn = instance_compare(_inst, other);
if (rtn == 2)
return NotImplemented;
if (rtn == -2)
throwCAPIException();
return boxInt(rtn);
}
Box* instanceContains(Box* _inst, Box* key) { Box* instanceContains(Box* _inst, Box* key) {
RELEASE_ASSERT(_inst->cls == instance_cls, ""); RELEASE_ASSERT(_inst->cls == instance_cls, "");
BoxedInstance* inst = static_cast<BoxedInstance*>(_inst); BoxedInstance* inst = static_cast<BoxedInstance*>(_inst);
...@@ -695,6 +829,7 @@ void setupClassobj() { ...@@ -695,6 +829,7 @@ void setupClassobj() {
instance_cls->giveAttr("__getitem__", new BoxedFunction(boxRTFunction((void*)instanceGetitem, UNKNOWN, 2))); instance_cls->giveAttr("__getitem__", new BoxedFunction(boxRTFunction((void*)instanceGetitem, UNKNOWN, 2)));
instance_cls->giveAttr("__setitem__", new BoxedFunction(boxRTFunction((void*)instanceSetitem, UNKNOWN, 3))); instance_cls->giveAttr("__setitem__", new BoxedFunction(boxRTFunction((void*)instanceSetitem, UNKNOWN, 3)));
instance_cls->giveAttr("__delitem__", new BoxedFunction(boxRTFunction((void*)instanceDelitem, UNKNOWN, 2))); instance_cls->giveAttr("__delitem__", new BoxedFunction(boxRTFunction((void*)instanceDelitem, UNKNOWN, 2)));
instance_cls->giveAttr("__cmp__", new BoxedFunction(boxRTFunction((void*)instanceCompare, UNKNOWN, 2)));
instance_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)instanceContains, UNKNOWN, 2))); instance_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)instanceContains, UNKNOWN, 2)));
instance_cls->giveAttr("__hash__", new BoxedFunction(boxRTFunction((void*)instanceHash, UNKNOWN, 1))); instance_cls->giveAttr("__hash__", new BoxedFunction(boxRTFunction((void*)instanceHash, UNKNOWN, 1)));
instance_cls->giveAttr("__iter__", new BoxedFunction(boxRTFunction((void*)instanceIter, UNKNOWN, 1))); instance_cls->giveAttr("__iter__", new BoxedFunction(boxRTFunction((void*)instanceIter, UNKNOWN, 1)));
......
...@@ -3620,6 +3620,28 @@ extern "C" Box* augbinop(Box* lhs, Box* rhs, int op_type) { ...@@ -3620,6 +3620,28 @@ extern "C" Box* augbinop(Box* lhs, Box* rhs, int op_type) {
return rtn; return rtn;
} }
static bool convert3wayCompareResultToBool(Box* v, int op_type) {
long result = PyInt_AsLong(v);
if (result == -1 && PyErr_Occurred())
throwCAPIException();
switch (op_type) {
case AST_TYPE::Eq:
return result == 0;
case AST_TYPE::NotEq:
return result != 0;
case AST_TYPE::Lt:
return result < 0;
case AST_TYPE::Gt:
return result > 0;
case AST_TYPE::LtE:
return result < 0 || result == 0;
case AST_TYPE::GtE:
return result > 0 || result == 0;
default:
RELEASE_ASSERT(0, "op type %d not implemented", op_type);
};
}
Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrite_args) { Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrite_args) {
if (op_type == AST_TYPE::Is || op_type == AST_TYPE::IsNot) { if (op_type == AST_TYPE::Is || op_type == AST_TYPE::IsNot) {
bool neg = (op_type == AST_TYPE::IsNot); bool neg = (op_type == AST_TYPE::IsNot);
...@@ -3704,6 +3726,18 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit ...@@ -3704,6 +3726,18 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit
if (rrtn != NULL && rrtn != NotImplemented) if (rrtn != NULL && rrtn != NotImplemented)
return rrtn; return rrtn;
std::string cmp_name = "__cmp__";
lrtn = callattrInternal1(lhs, &cmp_name, CLASS_ONLY, NULL, ArgPassSpec(1), rhs);
if (lrtn && lrtn != NotImplemented) {
return boxBool(convert3wayCompareResultToBool(lrtn, op_type));
}
rrtn = callattrInternal1(rhs, &cmp_name, CLASS_ONLY, NULL, ArgPassSpec(1), lhs);
if (rrtn && rrtn != NotImplemented) {
bool success = false;
int reversed_op = getReverseCmpOp(op_type, success);
assert(success);
return boxBool(convert3wayCompareResultToBool(rrtn, reversed_op));
}
if (op_type == AST_TYPE::Eq) if (op_type == AST_TYPE::Eq)
return boxBool(lhs == rhs); return boxBool(lhs == rhs);
......
...@@ -391,6 +391,17 @@ CREATE_UN(s_hex, PyString_FromString("hex")); ...@@ -391,6 +391,17 @@ CREATE_UN(s_hex, PyString_FromString("hex"));
#undef CREATE_BIN #undef CREATE_BIN
static int
slots_tester_compare(PyObject* x, PyObject* y)
{
printf("inside slots_tester_compare\n");
if (x < y)
return -1;
else if (x == y)
return 0;
return 1;
}
static PyNumberMethods slots_tester_as_number = { static PyNumberMethods slots_tester_as_number = {
(binaryfunc)s_add, /* nb_add */ (binaryfunc)s_add, /* nb_add */
(binaryfunc)s_subtract, /* nb_subtract */ (binaryfunc)s_subtract, /* nb_subtract */
...@@ -442,7 +453,7 @@ static PyTypeObject slots_tester_num = { ...@@ -442,7 +453,7 @@ static PyTypeObject slots_tester_num = {
0, /* tp_print */ 0, /* tp_print */
0, /* tp_getattr */ 0, /* tp_getattr */
0, /* tp_setattr */ 0, /* tp_setattr */
0, /* tp_compare */ slots_tester_compare, /* tp_compare */
0, /* tp_repr */ 0, /* tp_repr */
&slots_tester_as_number, /* tp_as_number */ &slots_tester_as_number, /* tp_as_number */
0, /* tp_as_sequence */ 0, /* tp_as_sequence */
......
# expected: fail
# we don't support tp_compare yet
s = 'Hello world'
t = buffer(s, 6, 5)
s2 = "Goodbye world"
t2 = buffer(s2, 8, 5)
print t == t2
s = 'Hello world' s = 'Hello world'
t = buffer(s, 6, 5) t = buffer(s, 6, 5)
print t print t
s2 = "Goodbye world"
t2 = buffer(s2, 8, 5)
print t2
print t == t2
...@@ -38,6 +38,8 @@ for i in xrange(3): ...@@ -38,6 +38,8 @@ for i in xrange(3):
print float(t) print float(t)
print hex(t) print hex(t)
print oct(t) print oct(t)
print slots_test.SlotsTesterNum(0) == slots_test.SlotsTesterNum(0)
print slots_test.SlotsTesterNum(0) == slots_test.SlotsTesterNum(1)
for i in slots_test.SlotsTesterSeq(6): for i in slots_test.SlotsTesterSeq(6):
print i print i
......
...@@ -59,6 +59,9 @@ class C(object): ...@@ -59,6 +59,9 @@ class C(object):
def __ge__(self, rhs): def __ge__(self, rhs):
print "ge" print "ge"
return False return False
def __cmp__(self, rhs):
print "cmp"
assert False
for i in xrange(2): for i in xrange(2):
print C("") > 2 print C("") > 2
...@@ -113,3 +116,34 @@ d = {} ...@@ -113,3 +116,34 @@ d = {}
for i in xrange(20): for i in xrange(20):
d[NonboolEq(i % 10)] = i d[NonboolEq(i % 10)] = i
print len(d), sorted(d.values()) print len(d), sorted(d.values())
class C(object):
def __init__(self, n):
self.n = n
def __eq__(self, rhs):
print "eq"
if isinstance(rhs, C):
return self.n == rhs.n
return self.n == int(rhs)
def __cmp__(self, rhs):
print "cmp"
v = 0
if isinstance(rhs, C):
v = rhs.n
else:
v = int(rhs)
if self.n < v:
return -2L
elif self.n > v:
return 2L
return 0L
for lhs in (C(0), C(1), 0, 1):
for rhs in (C(0), C(1), 0, 1):
print lhs < rhs, lhs == rhs, lhs != rhs, lhs > rhs, lhs <= rhs, lhs >= rhs
del C.__eq__
for lhs in (C(0), C(1), 0, 1):
for rhs in (C(0), C(1), 0, 1):
print lhs < rhs, lhs == rhs, lhs != rhs, lhs > rhs, lhs <= rhs, lhs >= rhs
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment