Commit 39c966d2 authored by Mark Florisson's avatar Mark Florisson

Better fused buffer runtime dispatch + dispatch restructuring + PyxCodeWriter

parent a3230e4a
...@@ -14,6 +14,7 @@ import re ...@@ -14,6 +14,7 @@ import re
import sys import sys
from string import Template from string import Template
import operator import operator
import textwrap
import Naming import Naming
import Options import Options
...@@ -376,7 +377,7 @@ class UtilityCode(UtilityCodeBase): ...@@ -376,7 +377,7 @@ class UtilityCode(UtilityCodeBase):
self.cleanup(writer, output.module_pos) self.cleanup(writer, output.module_pos)
def sub_tempita(s, context, file, name): def sub_tempita(s, context, file=None, name=None):
"Run tempita on string s with given context." "Run tempita on string s with given context."
if not s: if not s:
return None return None
...@@ -1940,6 +1941,63 @@ class PyrexCodeWriter(object): ...@@ -1940,6 +1941,63 @@ class PyrexCodeWriter(object):
def dedent(self): def dedent(self):
self.level -= 1 self.level -= 1
class PyxCodeWriter(object):
"""
Can be used for writing out some Cython code.
"""
def __init__(self, buffer=None, indent_level=0, context=None):
self.buffer = buffer or StringIOTree()
self.level = indent_level
self.context = context
self.encoding = 'ascii'
def indent(self, levels=1):
self.level += levels
def dedent(self, levels=1):
self.level -= levels
def indenter(self, line):
"""
with pyx_code.indenter("for i in range(10):"):
pyx_code.putln("print i")
"""
self.putln(line)
return self
def getvalue(self):
return unicode(self.buffer.getvalue(), self.encoding)
def putln(self, line, context=None):
context = context or self.context
if context:
line = sub_tempita(line, context)
self._putln(line)
def _putln(self, line):
self.buffer.write("%s%s\n" % (self.level * " ", line))
def put_chunk(self, chunk, context=None):
context = context or self.context
if context:
chunk = sub_tempita(chunk, context)
chunk = textwrap.dedent(chunk)
for line in chunk.splitlines():
self._putln(line)
def insertion_point(self):
return PyxCodeWriter(self.buffer.insertion_point(), self.level,
self.context)
def named_insertion_point(self, name):
setattr(self, name, self.insertion_point())
__enter__ = indent
def __exit__(self, exc_value, exc_type, exc_tb):
self.dedent()
class ClosureTempAllocator(object): class ClosureTempAllocator(object):
def __init__(self, klass): def __init__(self, klass):
......
This diff is collapsed.
...@@ -1498,6 +1498,7 @@ if VALUE is not None: ...@@ -1498,6 +1498,7 @@ if VALUE is not None:
# Create PyCFunction nodes for each specialization # Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func) node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func) node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True) True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env)) pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
......
...@@ -362,11 +362,11 @@ class Scope(object): ...@@ -362,11 +362,11 @@ class Scope(object):
# Return the module-level scope containing this scope. # Return the module-level scope containing this scope.
return self.outer_scope.builtin_scope() return self.outer_scope.builtin_scope()
def declare(self, name, cname, type, pos, visibility, shadow = 0): def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0):
# Create new entry, and add to dictionary if # Create new entry, and add to dictionary if
# name is not None. Reports a warning if already # name is not None. Reports a warning if already
# declared. # declared.
if type.is_buffer and not isinstance(self, LocalScope): if type.is_buffer and not isinstance(self, LocalScope) and not is_type:
error(pos, 'Buffer types only allowed as function local variables') error(pos, 'Buffer types only allowed as function local variables')
if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname): if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname):
# See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names # See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names
...@@ -417,7 +417,7 @@ class Scope(object): ...@@ -417,7 +417,7 @@ class Scope(object):
# Add an entry for a type definition. # Add an entry for a type definition.
if not cname: if not cname:
cname = name cname = name
entry = self.declare(name, cname, type, pos, visibility, shadow) entry = self.declare(name, cname, type, pos, visibility, shadow, True)
entry.is_type = 1 entry.is_type = 1
entry.api = api entry.api = api
if defining: if defining:
......
...@@ -231,6 +231,12 @@ class TreeFragment(object): ...@@ -231,6 +231,12 @@ class TreeFragment(object):
substitutions = nodes, substitutions = nodes,
temps = self.temps + temps, pos = pos) temps = self.temps + temps, pos = pos)
class SetPosTransform(VisitorTransform):
def __init__(self, pos):
super(SetPosTransform, self).__init__()
self.pos = pos
def visit_Node(self, node):
node.pos = self.pos
self.visitchildren(node)
return node
\ No newline at end of file
...@@ -167,3 +167,11 @@ class CythonUtilityCode(Code.UtilityCodeBase): ...@@ -167,3 +167,11 @@ class CythonUtilityCode(Code.UtilityCodeBase):
dep.declare_in_scope(dest_scope) dep.declare_in_scope(dest_scope)
return original_scope return original_scope
def declare_declarations_in_scope(declaration_string, env, private_type=True,
*args, **kwargs):
"""
Declare some declarations given as Cython code in declaration_string
in scope env.
"""
CythonUtilityCode(declaration_string, *args, **kwargs).declare_in_scope(env)
...@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw) ...@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw)
int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD && int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD &&
!((__pyx_FusedFunctionObject *) func)->__signatures__); !((__pyx_FusedFunctionObject *) func)->__signatures__);
//PyObject_Print(args, stdout, Py_PRINT_RAW);
if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) { if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) {
Py_ssize_t argc; Py_ssize_t argc;
PyObject *new_args; PyObject *new_args;
...@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw) ...@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw)
} }
if (binding_func->__signatures__) { if (binding_func->__signatures__) {
PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args, PyObject *tup = PyTuple_Pack(4, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw); kw == NULL ? Py_None : kw,
binding_func->func.defaults_tuple);
if (!tup) if (!tup)
goto __pyx_err; goto __pyx_err;
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
cimport numpy as np cimport numpy as np
cimport cython cimport cython
from libc.stdlib cimport malloc
def little_endian(): def little_endian():
cdef int endian_detector = 1 cdef int endian_detector = 1
return (<char*>&endian_detector)[0] != 0 return (<char*>&endian_detector)[0] != 0
...@@ -503,19 +505,28 @@ def test_point_record(): ...@@ -503,19 +505,28 @@ def test_point_record():
test[i].y = -i test[i].y = -i
print repr(test).replace('<', '!').replace('>', '!') print repr(test).replace('<', '!').replace('>', '!')
def test_fused_ndarray_dtype(np.ndarray[cython.floating, ndim=1] a): # Test fused np.ndarray dtypes and runtime dispatch
def test_fused_ndarray_floating_dtype(np.ndarray[cython.floating, ndim=1] a):
""" """
>>> import cython >>> import cython
>>> sorted(test_fused_ndarray_dtype.__signatures__) >>> sorted(test_fused_ndarray_floating_dtype.__signatures__)
['double', 'float'] ['double', 'float']
>>> test_fused_ndarray_dtype[cython.double](np.arange(10, dtype=np.float64))
>>> test_fused_ndarray_floating_dtype[cython.double](np.arange(10, dtype=np.float64))
ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0 ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0
>>> test_fused_ndarray_dtype[cython.float](np.arange(10, dtype=np.float32)) >>> test_fused_ndarray_floating_dtype(np.arange(10, dtype=np.float64))
ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0
>>> test_fused_ndarray_floating_dtype[cython.float](np.arange(10, dtype=np.float32))
ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0
>>> test_fused_ndarray_floating_dtype(np.arange(10, dtype=np.float32))
ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0 ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0
""" """
cdef np.ndarray[cython.floating, ndim=1] b = a cdef np.ndarray[cython.floating, ndim=1] b = a
print cython.typeof(a), cython.typeof(b), a[5], b[6] print cython.typeof(a), cython.typeof(b), a[5], b[6]
double_array = np.linspace(0, 1, 100) double_array = np.linspace(0, 1, 100)
int32_array = np.arange(100, dtype=np.int32) int32_array = np.arange(100, dtype=np.int32)
...@@ -568,4 +579,231 @@ def test_fused_cpdef_buffers(): ...@@ -568,4 +579,231 @@ def test_fused_cpdef_buffers():
cdef np.ndarray[np.int32_t] typed_array = int32_array cdef np.ndarray[np.int32_t] typed_array = int32_array
_fused_cpdef_buffers(typed_array) _fused_cpdef_buffers(typed_array)
def test_fused_ndarray_integral_dtype(np.ndarray[cython.integral, ndim=1] a):
"""
>>> import cython
>>> sorted(test_fused_ndarray_integral_dtype.__signatures__)
['int', 'long', 'short']
>>> test_fused_ndarray_integral_dtype[cython.int](np.arange(10, dtype=np.dtype('i')))
ndarray[int,ndim=1] ndarray[int,ndim=1] 5 6
>>> test_fused_ndarray_integral_dtype(np.arange(10, dtype=np.dtype('i')))
ndarray[int,ndim=1] ndarray[int,ndim=1] 5 6
>>> test_fused_ndarray_integral_dtype[cython.long](np.arange(10, dtype=np.long))
ndarray[long,ndim=1] ndarray[long,ndim=1] 5 6
>>> test_fused_ndarray_integral_dtype(np.arange(10, dtype=np.long))
ndarray[long,ndim=1] ndarray[long,ndim=1] 5 6
"""
cdef np.ndarray[cython.integral, ndim=1] b = a
print cython.typeof(a), cython.typeof(b), a[5], b[6]
cdef fused fused_dtype:
float complex
double complex
object
def test_fused_ndarray_other_dtypes(np.ndarray[fused_dtype, ndim=1] a):
"""
>>> import cython
>>> sorted(test_fused_ndarray_other_dtypes.__signatures__)
['double complex', 'float complex', 'object']
>>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.complex64))
ndarray[float complex,ndim=1] ndarray[float complex,ndim=1] (5+0j) (6+0j)
>>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.complex128))
ndarray[double complex,ndim=1] ndarray[double complex,ndim=1] (5+0j) (6+0j)
>>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.object))
ndarray[Python object,ndim=1] ndarray[Python object,ndim=1] 5 6
"""
cdef np.ndarray[fused_dtype, ndim=1] b = a
print cython.typeof(a), cython.typeof(b), a[5], b[6]
# Test fusing the array types together and runtime dispatch
cdef struct Foo:
int a
float b
cdef fused fused_FooArray:
np.ndarray[Foo, ndim=1]
cdef fused fused_ndarray:
np.ndarray[float, ndim=1]
np.ndarray[double, ndim=1]
np.ndarray[Foo, ndim=1]
def get_Foo_array():
cdef Foo[:] result = <Foo[:10]> malloc(sizeof(Foo) * 10)
result[5].b = 9.0
return np.asarray(result)
def test_fused_ndarray(fused_ndarray a):
"""
>>> import cython
>>> sorted(test_fused_ndarray.__signatures__)
['ndarray[Foo,ndim=1]', 'ndarray[double,ndim=1]', 'ndarray[float,ndim=1]']
>>> test_fused_ndarray(get_Foo_array())
ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
9.0
>>> test_fused_ndarray(np.arange(10, dtype=np.float64))
ndarray[double,ndim=1] ndarray[double,ndim=1]
5.0
>>> test_fused_ndarray(np.arange(10, dtype=np.float32))
ndarray[float,ndim=1] ndarray[float,ndim=1]
5.0
"""
cdef fused_ndarray b = a
print cython.typeof(a), cython.typeof(b)
if fused_ndarray in fused_FooArray:
print b[5].b
else:
print b[5]
cpdef test_fused_cpdef_ndarray(fused_ndarray a):
"""
>>> import cython
>>> sorted(test_fused_cpdef_ndarray.__signatures__)
['ndarray[Foo,ndim=1]', 'ndarray[double,ndim=1]', 'ndarray[float,ndim=1]']
>>> test_fused_cpdef_ndarray(get_Foo_array())
ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
9.0
>>> test_fused_cpdef_ndarray(np.arange(10, dtype=np.float64))
ndarray[double,ndim=1] ndarray[double,ndim=1]
5.0
>>> test_fused_cpdef_ndarray(np.arange(10, dtype=np.float32))
ndarray[float,ndim=1] ndarray[float,ndim=1]
5.0
"""
cdef fused_ndarray b = a
print cython.typeof(a), cython.typeof(b)
if fused_ndarray in fused_FooArray:
print b[5].b
else:
print b[5]
def test_fused_cpdef_ndarray_cdef_call():
"""
>>> test_fused_cpdef_ndarray_cdef_call()
ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
9.0
"""
cdef np.ndarray[Foo, ndim=1] foo_array = get_Foo_array()
test_fused_cpdef_ndarray(foo_array)
cdef fused int_type:
np.int32_t
np.int64_t
float64_array = np.arange(10, dtype=np.float64)
float32_array = np.arange(10, dtype=np.float32)
int32_array = np.arange(10, dtype=np.int32)
int64_array = np.arange(10, dtype=np.int64)
def test_dispatch_non_clashing_declarations_repeating_types(np.ndarray[cython.floating] a1,
np.ndarray[int_type] a2,
np.ndarray[cython.floating] a3,
np.ndarray[int_type] a4):
"""
>>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int32_array, float64_array, int32_array)
1.0 2 3.0 4
>>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int64_array, float64_array, int64_array)
1.0 2 3.0 4
>>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int32_array, float64_array, int64_array)
Traceback (most recent call last):
...
TypeError: No matching signature found
"""
print a1[1], a2[2], a3[3], a4[4]
ctypedef np.int32_t typedeffed_type
cdef fused typedeffed_fused_type:
typedeffed_type
int
long
def test_dispatch_typedef(np.ndarray[typedeffed_fused_type] a):
"""
>>> test_dispatch_typedef(int32_array)
5
"""
print a[5]
cdef extern from "types.h":
ctypedef unsigned char actually_long_t
cdef fused confusing_fused_typedef:
actually_long_t
unsigned char
signed char
def test_dispatch_external_typedef(np.ndarray[confusing_fused_typedef] a):
"""
>>> test_dispatch_external_typedef(np.arange(10, dtype=np.long))
5
"""
print a[5]
# test fused memoryview slices
cdef fused memslice_fused_dtype:
float
double
int
long
float complex
double complex
object
def test_fused_memslice_other_dtypes(memslice_fused_dtype[:] a):
"""
>>> import cython
>>> sorted(test_fused_memslice_other_dtypes.__signatures__)
['double', 'double complex', 'float', 'float complex', 'int', 'long', 'object']
>>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.complex64))
float complex[:] float complex[:] (5+0j) (6+0j)
>>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.complex128))
double complex[:] double complex[:] (5+0j) (6+0j)
>>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.float32))
float[:] float[:] 5.0 6.0
>>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.dtype('i')))
int[:] int[:] 5 6
>>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.object))
object[:] object[:] 5 6
"""
cdef memslice_fused_dtype[:] b = a
print cython.typeof(a), cython.typeof(b), a[5], b[6]
cdef fused memslice_fused:
float[:]
double[:]
int[:]
long[:]
float complex[:]
double complex[:]
object[:]
def test_fused_memslice_fused(memslice_fused a):
"""
>>> import cython
>>> sorted(test_fused_memslice_fused.__signatures__)
['double complex[:]', 'double[:]', 'float complex[:]', 'float[:]', 'int[:]', 'long[:]', 'object[:]']
>>> test_fused_memslice_fused(np.arange(10, dtype=np.complex64))
float complex[:] float complex[:] (5+0j) (6+0j)
>>> test_fused_memslice_fused(np.arange(10, dtype=np.complex128))
double complex[:] double complex[:] (5+0j) (6+0j)
>>> test_fused_memslice_fused(np.arange(10, dtype=np.float32))
float[:] float[:] 5.0 6.0
>>> test_fused_memslice_fused(np.arange(10, dtype=np.dtype('i')))
int[:] int[:] 5 6
>>> test_fused_memslice_fused(np.arange(10, dtype=np.object))
object[:] object[:] 5 6
"""
cdef memslice_fused b = a
print cython.typeof(a), cython.typeof(b), a[5], b[6]
include "numpy_common.pxi" include "numpy_common.pxi"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment