From 628ffd74faccb7dc783666c4b5135f3e008638e8 Mon Sep 17 00:00:00 2001
From: Stefan Behnel <stefan_ml@behnel.de>
Date: Fri, 13 Oct 2017 23:36:13 +0200
Subject: [PATCH] Fix the first argument special method signatures of "__eq__",
 "__lt__" and the other richcmp methods to be of "self" type, in accordance
 with to the CPython specs. See
 https://docs.python.org/3/reference/datamodel.html#emulating-container-types
 Closes #1935.

---
 Cython/Compiler/TypeSlots.py           |   2 +-
 docs/src/userguide/special_methods.rst |  30 ++---
 tests/run/ext_auto_richcmp.py          | 171 +++++++++++++++++++++----
 3 files changed, 161 insertions(+), 42 deletions(-)

diff --git a/Cython/Compiler/TypeSlots.py b/Cython/Compiler/TypeSlots.py
index afafb53d8..ae0257ff0 100644
--- a/Cython/Compiler/TypeSlots.py
+++ b/Cython/Compiler/TypeSlots.py
@@ -576,7 +576,7 @@ def get_special_method_signature(name):
     if slot:
         return slot.signature
     elif name in richcmp_special_methods:
-        return binaryfunc
+        return ibinaryfunc
     else:
         return None
 
diff --git a/docs/src/userguide/special_methods.rst b/docs/src/userguide/special_methods.rst
index 84c5950c3..3882cb84e 100644
--- a/docs/src/userguide/special_methods.rst
+++ b/docs/src/userguide/special_methods.rst
@@ -209,21 +209,21 @@ Rich comparison operators
 
 https://docs.python.org/3/reference/datamodel.html#basic-customization
 
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __richcmp__           |x, y, int op                           | object      | Rich comparison (no direct Python equivalent)       |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __eq__                |x, y                                   | object      | x == y                                              |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __ne__                |x, y                                   | object      | x != y  (falls back to ``__eq__`` if not available) |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __lt__                |x, y                                   | object      | x < y                                               |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __gt__                |x, y                                   | object      | x > y                                               |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __le__                |x, y                                   | object      | x <= y                                              |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
-| __ge__                |x, y                                   | object      | x >= y                                              |
-+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __richcmp__           |x, y, int op                           | object      | Rich comparison (no direct Python equivalent)          |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __eq__                |self, y                                | object      | self == y                                              |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __ne__                |self, y                                | object      | self != y  (falls back to ``__eq__`` if not available) |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __lt__                |self, y                                | object      | self < y                                               |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __gt__                |self, y                                | object      | self > y                                               |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __le__                |self, y                                | object      | self <= y                                              |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
+| __ge__                |self, y                                | object      | self >= y                                              |
++-----------------------+---------------------------------------+-------------+--------------------------------------------------------+
 
 Arithmetic operators
 ^^^^^^^^^^^^^^^^^^^^
diff --git a/tests/run/ext_auto_richcmp.py b/tests/run/ext_auto_richcmp.py
index ef42856f9..3d8f87659 100644
--- a/tests/run/ext_auto_richcmp.py
+++ b/tests/run/ext_auto_richcmp.py
@@ -9,7 +9,7 @@ IS_PY2 = sys.version_info[0] == 2
 
 @cython.cclass
 class X(object):
-    x = cython.declare(cython.int, visibility="public")
+    x = cython.declare(cython.int)
 
     def __init__(self, x):
         self.x = x
@@ -18,6 +18,12 @@ class X(object):
         return "<%d>" % self.x
 
 
+@cython.cfunc
+@cython.locals(x=X)
+def x_of(x):
+    return x.x
+
+
 @cython.cclass
 class ClassEq(X):
     """
@@ -74,9 +80,12 @@ class ClassEq(X):
     TypeError...
     """
     def __eq__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x == other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassEq), type(self)
+        if isinstance(other, X):
+            return self.x == x_of(other)
+        elif isinstance(other, int):
+            return self.x < other
         return NotImplemented
 
 
@@ -134,9 +143,12 @@ class ClassEqNe(ClassEq):
     TypeError...
     """
     def __ne__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x != other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassEqNe), type(self)
+        if isinstance(other, X):
+            return self.x != x_of(other)
+        elif isinstance(other, int):
+            return self.x < other
         return NotImplemented
 
 
@@ -208,11 +220,34 @@ class ClassEqNeGe(ClassEqNe):
     ... else: a > b
     Traceback (most recent call last):
     TypeError...
-    """
+
+    >>> 2 <= a
+    False
+    >>> a >= 2
+    False
+    >>> 1 <= a
+    True
+    >>> a >= 1
+    True
+    >>> a >= 2
+    False
+
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: 'x' <= a
+    Traceback (most recent call last):
+    TypeError...
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: a >= 'x'
+    Traceback (most recent call last):
+    TypeError...
+   """
     def __ge__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x >= other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassEqNeGe), type(self)
+        if isinstance(other, X):
+            return self.x >= x_of(other)
+        elif isinstance(other, int):
+            return self.x >= other
         return NotImplemented
 
 
@@ -274,11 +309,34 @@ class ClassLe(X):
     True
     >>> b >= c
     True
+
+    >>> 2 >= a
+    True
+    >>> a <= 2
+    True
+    >>> 1 >= a
+    True
+    >>> a <= 1
+    True
+    >>> a <= 0
+    False
+
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: 'x' >= a
+    Traceback (most recent call last):
+    TypeError...
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: a <= 'x'
+    Traceback (most recent call last):
+    TypeError...
     """
     def __le__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x <= other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassLe), type(self)
+        if isinstance(other, X):
+            return self.x <= x_of(other)
+        elif isinstance(other, int):
+            return self.x <= other
         return NotImplemented
 
 
@@ -320,11 +378,37 @@ class ClassLt(X):
     [<1>, <1>, <2>]
     >>> sorted([b, a, c])
     [<1>, <1>, <2>]
+
+    >>> 2 > a
+    True
+    >>> a < 2
+    True
+    >>> 1 > a
+    False
+    >>> a < 1
+    False
+
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: 1 < a
+    Traceback (most recent call last):
+    TypeError...
+
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: 'x' > a
+    Traceback (most recent call last):
+    TypeError...
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: a < 'x'
+    Traceback (most recent call last):
+    TypeError...
     """
     def __lt__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x < other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassLt), type(self)
+        if isinstance(other, X):
+            return self.x < x_of(other)
+        elif isinstance(other, int):
+            return self.x < other
         return NotImplemented
 
 
@@ -368,9 +452,12 @@ class ClassLtGtInherited(X):
     [<1>, <1>, <2>]
     """
     def __gt__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x > other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassLtGtInherited), type(self)
+        if isinstance(other, X):
+            return self.x > x_of(other)
+        elif isinstance(other, int):
+            return self.x > other
         return NotImplemented
 
 
@@ -412,17 +499,49 @@ class ClassLtGt(X):
     [<1>, <1>, <2>]
     >>> sorted([b, a, c])
     [<1>, <1>, <2>]
+
+    >>> 2 > a
+    True
+    >>> 2 < a
+    False
+    >>> a < 2
+    True
+    >>> a > 2
+    False
+
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: 'x' > a
+    Traceback (most recent call last):
+    TypeError...
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: 'x' < a
+    Traceback (most recent call last):
+    TypeError...
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: a < 'x'
+    Traceback (most recent call last):
+    TypeError...
+    >>> if IS_PY2: raise TypeError  # doctest: +ELLIPSIS
+    ... else: a > 'x'
+    Traceback (most recent call last):
+    TypeError...
     """
     def __lt__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x < other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassLtGt), type(self)
+        if isinstance(other, X):
+            return self.x < x_of(other)
+        elif isinstance(other, int):
+            return self.x < other
         return NotImplemented
 
     def __gt__(self, other):
-        if isinstance(self, X):
-            if isinstance(other, X):
-                return self.x > other.x
+        assert 1 <= self.x <= 2
+        assert isinstance(self, ClassLtGt), type(self)
+        if isinstance(other, X):
+            return self.x > x_of(other)
+        elif isinstance(other, int):
+            return self.x > other
         return NotImplemented
 
 
-- 
2.30.9