Commit c04ca703 authored by Stefan Behnel's avatar Stefan Behnel

Repair accidentally broken support for comparing non-trivial objects (e.g....

Repair accidentally broken support for comparing non-trivial objects (e.g. NumPy arrays) to integer constants.
Closes #2444.
parent 0d01f077
...@@ -3260,9 +3260,6 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3260,9 +3260,6 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return node return node
if node.type.is_pyobject: if node.type.is_pyobject:
if operator in ('Eq', 'Ne'):
ret_type = PyrexTypes.c_bint_type
else:
ret_type = PyrexTypes.py_object_type ret_type = PyrexTypes.py_object_type
elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'): elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'):
ret_type = PyrexTypes.c_bint_type ret_type = PyrexTypes.c_bint_type
......
...@@ -713,7 +713,7 @@ static CYTHON_INLINE {{c_ret_type}} __Pyx_PyInt_{{'' if ret_type.is_pyobject els ...@@ -713,7 +713,7 @@ static CYTHON_INLINE {{c_ret_type}} __Pyx_PyInt_{{'' if ret_type.is_pyobject els
{{py: c_op = {'Eq': '==', 'Ne': '!='}[op] }} {{py: c_op = {'Eq': '==', 'Ne': '!='}[op] }}
{{py: {{py:
return_compare = ( return_compare = (
(lambda a,b,c_op: "if ({a} {c_op} {b}) {return_true}; else {return_false};".format( (lambda a,b,c_op, return_true=return_true, return_false=return_false: "if ({a} {c_op} {b}) {return_true}; else {return_false};".format(
a=a, b=b, c_op=c_op, return_true=return_true, return_false=return_false)) a=a, b=b, c_op=c_op, return_true=return_true, return_false=return_false))
if ret_type.is_pyobject else if ret_type.is_pyobject else
(lambda a,b,c_op: "return ({a} {c_op} {b});".format(a=a, b=b, c_op=c_op)) (lambda a,b,c_op: "return ({a} {c_op} {b});".format(a=a, b=b, c_op=c_op))
......
...@@ -898,4 +898,33 @@ def test_copy_buffer(np.ndarray[double, ndim=1] a): ...@@ -898,4 +898,33 @@ def test_copy_buffer(np.ndarray[double, ndim=1] a):
return a return a
@testcase
def test_broadcast_comparison(np.ndarray[double, ndim=1] a):
"""
>>> a = np.ones(10, dtype=np.double)
>>> a0, obj0, a1, obj1 = test_broadcast_comparison(a)
>>> np.all(a0 == (a == 0)) or a0
True
>>> np.all(a1 == (a == 1)) or a1
True
>>> np.all(obj0 == (a == 0)) or obj0
True
>>> np.all(obj1 == (a == 1)) or obj1
True
>>> a = np.zeros(10, dtype=np.double)
>>> a0, obj0, a1, obj1 = test_broadcast_comparison(a)
>>> np.all(a0 == (a == 0)) or a0
True
>>> np.all(a1 == (a == 1)) or a1
True
>>> np.all(obj0 == (a == 0)) or obj0
True
>>> np.all(obj1 == (a == 1)) or obj1
True
"""
cdef object obj = a
return a == 0, obj == 0, a == 1, obj == 1
include "numpy_common.pxi" include "numpy_common.pxi"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment