Commit 29b22283 authored by Mark Florisson's avatar Mark Florisson

Fix fused signature delimiter and ndim dispatch

parent 977c15c8
...@@ -211,7 +211,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -211,7 +211,7 @@ class FusedCFuncDefNode(StatListNode):
for fused_type in fused_types for fused_type in fused_types
] ]
node.specialized_signature_string = ', '.join(type_strings) node.specialized_signature_string = '|'.join(type_strings)
node.entry.pymethdef_cname = PyrexTypes.get_fused_cname( node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, node.entry.pymethdef_cname) cname, node.entry.pymethdef_cname)
...@@ -322,7 +322,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -322,7 +322,8 @@ class FusedCFuncDefNode(StatListNode):
for dtype_category, codewriter in dtypes: for dtype_category, codewriter in dtypes:
if dtype_category: if dtype_category:
cond = '{{itemsize_match}}' cond = '{{itemsize_match}} and arg.ndim == %d' % (
specialized_type.ndim,)
if dtype.is_int: if dtype.is_int:
cond += ' and {{signed_match}}' cond += ' and {{signed_match}}'
...@@ -587,7 +588,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -587,7 +588,7 @@ class FusedCFuncDefNode(StatListNode):
candidates = [] candidates = []
for sig in signatures: for sig in signatures:
match_found = False match_found = False
for src_type, dst_type in zip(sig.strip('()').split(', '), dest_sig): for src_type, dst_type in zip(sig.strip('()').split('|'), dest_sig):
if dst_type is not None: if dst_type is not None:
if src_type == dst_type: if src_type == dst_type:
match_found = True match_found = True
......
...@@ -849,4 +849,21 @@ def test_dispatch_memoryview_object(): ...@@ -849,4 +849,21 @@ def test_dispatch_memoryview_object():
cdef int[:] m3 = <object> m cdef int[:] m3 = <object> m
test_fused_memslice(m3) test_fused_memslice(m3)
cdef fused ndim_t:
double[:]
double[:, :]
double[:, :, :]
@testcase
def test_dispatch_ndim(ndim_t array):
"""
>>> test_dispatch_ndim(np.empty(5, dtype=np.double))
double[:] 1
>>> test_dispatch_ndim(np.empty((5, 5), dtype=np.double))
double[:, :] 2
>>> test_dispatch_ndim(np.empty((5, 5, 5), dtype=np.double))
double[:, :, :] 3
"""
print cython.typeof(array), np.asarray(array).ndim
include "numpy_common.pxi" include "numpy_common.pxi"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment