Commit 8141a942 authored by Mark's avatar Mark

Merge pull request #66 from markflorisson88/fusedmerge

Fused Types
parents 5008e863 751dd58f
from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode, DefNode
from ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, \ from ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, \
ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode
from PyrexTypes import py_object_type from PyrexTypes import py_object_type
...@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform): ...@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform):
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
if not node.doc: if not node.doc or (isinstance(node, DefNode) and node.fused_py_func):
return node return node
if not self.cdef_docstrings: if not self.cdef_docstrings:
if isinstance(node, CFuncDefNode) and not node.py_func: if isinstance(node, CFuncDefNode) and not node.py_func:
......
...@@ -263,7 +263,6 @@ class UtilityCode(UtilityCodeBase): ...@@ -263,7 +263,6 @@ class UtilityCode(UtilityCodeBase):
def get_tree(self): def get_tree(self):
pass pass
def specialize(self, pyrex_type=None, tempita=False, **data): def specialize(self, pyrex_type=None, tempita=False, **data):
# Dicts aren't hashable... # Dicts aren't hashable...
if pyrex_type is not None: if pyrex_type is not None:
......
...@@ -18,6 +18,13 @@ class CythonScope(ModuleScope): ...@@ -18,6 +18,13 @@ class CythonScope(ModuleScope):
# The Main.Context object # The Main.Context object
self.context = context self.context = context
for fused_type in (cy_integral_type, cy_floating_type, cy_numeric_type):
entry = self.declare_typedef(fused_type.name,
fused_type,
None,
cname='<error>')
entry.in_cinclude = True
def lookup_type(self, name): def lookup_type(self, name):
# This function should go away when types are all first-level objects. # This function should go away when types are all first-level objects.
type = parse_basic_type(name) type = parse_basic_type(name)
...@@ -114,6 +121,7 @@ class CythonScope(ModuleScope): ...@@ -114,6 +121,7 @@ class CythonScope(ModuleScope):
view_utility_scope = MemoryView.view_utility_code.declare_in_scope( view_utility_scope = MemoryView.view_utility_code.declare_in_scope(
viewscope, cython_scope=self) viewscope, cython_scope=self)
# MemoryView.memview_fromslice_utility_code.from_scope = view_utility_scope # MemoryView.memview_fromslice_utility_code.from_scope = view_utility_scope
# MemoryView.memview_fromslice_utility_code.declare_in_scope(viewscope) # MemoryView.memview_fromslice_utility_code.declare_in_scope(viewscope)
...@@ -124,7 +132,6 @@ def create_cython_scope(context): ...@@ -124,7 +132,6 @@ def create_cython_scope(context):
# it across different contexts) # it across different contexts)
return CythonScope(context) return CythonScope(context)
# Load test utilities for the cython scope # Load test utilities for the cython scope
def load_testscope_utility(cy_util_name, **kwargs): def load_testscope_utility(cy_util_name, **kwargs):
......
This diff is collapsed.
...@@ -130,9 +130,6 @@ def get_buf_flags(specs): ...@@ -130,9 +130,6 @@ def get_buf_flags(specs):
return memview_strided_access return memview_strided_access
def use_cython_array(env):
env.use_utility_code(cython_array_utility_code)
def src_conforms_to_dst(src, dst): def src_conforms_to_dst(src, dst):
''' '''
returns True if src conforms to dst, False otherwise. returns True if src conforms to dst, False otherwise.
...@@ -171,15 +168,18 @@ def valid_memslice_dtype(dtype): ...@@ -171,15 +168,18 @@ def valid_memslice_dtype(dtype):
return ( return (
dtype.is_error or dtype.is_error or
# Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
dtype.is_numeric or dtype.is_numeric or
dtype.is_struct or dtype.is_struct or
dtype.is_pyobject or dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type)) (dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
) )
def validate_memslice_dtype(pos, dtype): def validate_memslice_dtype(pos, dtype):
if not valid_memslice_dtype(dtype): if not valid_memslice_dtype(dtype):
error(pos, "Invalid base type for memoryview slice") error(pos, "Invalid base type for memoryview slice: %s" % dtype)
class MemoryViewSliceBufferEntry(Buffer.BufferEntry): class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
...@@ -936,6 +936,10 @@ def load_memview_c_utility(util_code_name, context=None, **kwargs): ...@@ -936,6 +936,10 @@ def load_memview_c_utility(util_code_name, context=None, **kwargs):
return UtilityCode.load(util_code_name, "MemoryView_C.c", return UtilityCode.load(util_code_name, "MemoryView_C.c",
context=context, **kwargs) context=context, **kwargs)
def use_cython_array_utility_code(env):
env.global_scope().context.cython_scope.lookup('array_cwrapper').used = True
env.use_utility_code(cython_array_utility_code)
context = { context = {
'memview_struct_name': memview_objstruct_cname, 'memview_struct_name': memview_objstruct_cname,
'max_dims': Options.buffer_max_dims, 'max_dims': Options.buffer_max_dims,
......
...@@ -591,6 +591,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -591,6 +591,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)
#endif #endif
#if PY_MAJOR_VERSION < 3 && PY_MINOR_VERSION < 6
#define PyUnicode_FromString(s) PyUnicode_Decode(s, strlen(s), "UTF-8", "strict")
#endif
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
#define Py_TPFLAGS_CHECKTYPES 0 #define Py_TPFLAGS_CHECKTYPES 0
#define Py_TPFLAGS_HAVE_INDEX 0 #define Py_TPFLAGS_HAVE_INDEX 0
...@@ -985,6 +989,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -985,6 +989,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# Generate struct declaration for an extension type's vtable. # Generate struct declaration for an extension type's vtable.
type = entry.type type = entry.type
scope = type.scope scope = type.scope
self.specialize_fused_types(scope)
if type.vtabstruct_cname: if type.vtabstruct_cname:
code.putln("") code.putln("")
code.putln( code.putln(
...@@ -1128,6 +1135,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1128,6 +1135,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cfunction_declarations(self, env, code, definition): def generate_cfunction_declarations(self, env, code, definition):
for entry in env.cfunc_entries: for entry in env.cfunc_entries:
if entry.used:
generate_cfunction_declaration(entry, env, code, definition) generate_cfunction_declaration(entry, env, code, definition)
def generate_variable_definitions(self, env, code): def generate_variable_definitions(self, env, code):
...@@ -1800,6 +1808,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1800,6 +1808,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"static PyMethodDef %s[] = {" % "static PyMethodDef %s[] = {" %
env.method_table_cname) env.method_table_cname)
for entry in env.pyfunc_entries: for entry in env.pyfunc_entries:
if not entry.fused_cfunction:
code.put_pymethoddef(entry, ",") code.put_pymethoddef(entry, ",")
code.putln( code.putln(
"{0, 0, 0, 0}") "{0, 0, 0, 0}")
...@@ -1928,6 +1937,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1928,6 +1937,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (__Pyx_CyFunction_init() < 0) %s" % code.error_goto(self.pos)) code.putln("if (__Pyx_CyFunction_init() < 0) %s" % code.error_goto(self.pos))
code.putln("#endif") code.putln("#endif")
code.putln("#ifdef __Pyx_FusedFunction_USED")
code.putln("if (__pyx_FusedFunction_init() < 0) %s" % code.error_goto(self.pos))
code.putln("#endif")
code.putln("/*--- Library function declarations ---*/") code.putln("/*--- Library function declarations ---*/")
env.generate_library_function_declarations(code) env.generate_library_function_declarations(code)
...@@ -1983,6 +1996,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1983,6 +1996,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/") code.putln("/*--- Function import code ---*/")
for module in imported_modules: for module in imported_modules:
self.specialize_fused_types(module)
self.generate_c_function_import_code_for_module(module, env, code) self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
...@@ -2200,6 +2214,18 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2200,6 +2214,18 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd: if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code) self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env):
"""
If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and
not specialized ones. This method replaces any fused entries with their
specialized ones.
"""
for entry in pxd_env.cfunc_entries[:]:
if entry.type.is_fused:
# This call modifies the cfunc_entries in-place
entry.type.get_all_specific_function_types()
def generate_c_variable_import_code_for_module(self, module, env, code): def generate_c_variable_import_code_for_module(self, module, env, code):
# Generate import code for all exported C functions in a cimported module. # Generate import code for all exported C functions in a cimported module.
entries = [] entries = []
...@@ -2232,7 +2258,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2232,7 +2258,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# Generate import code for all exported C functions in a cimported module. # Generate import code for all exported C functions in a cimported module.
entries = [] entries = []
for entry in module.cfunc_entries: for entry in module.cfunc_entries:
if entry.defined_in_pxd: if entry.defined_in_pxd and entry.used:
entries.append(entry) entries.append(entry)
if entries: if entries:
env.use_utility_code(import_module_utility_code) env.use_utility_code(import_module_utility_code)
...@@ -2441,7 +2467,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2441,7 +2467,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cfunction_declaration(entry, env, code, definition): def generate_cfunction_declaration(entry, env, code, definition):
from_cy_utility = entry.used and entry.utility_code_definition from_cy_utility = entry.used and entry.utility_code_definition
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition if entry.used and entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
or entry.defined_in_pxd or entry.visibility == 'extern' or from_cy_utility)): or entry.defined_in_pxd or entry.visibility == 'extern' or from_cy_utility)):
if entry.visibility == 'extern': if entry.visibility == 'extern':
storage_class = "%s " % Naming.extern_c_macro storage_class = "%s " % Naming.extern_c_macro
......
...@@ -92,6 +92,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope" ...@@ -92,6 +92,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope"
frame_cname = pyrex_prefix + "frame" frame_cname = pyrex_prefix + "frame"
frame_code_cname = pyrex_prefix + "frame_code" frame_code_cname = pyrex_prefix + "frame_code"
binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
fused_func_prefix = pyrex_prefix + 'fuse_'
quick_temp_cname = pyrex_prefix + "temp" # temp variable for quick'n'dirty temping quick_temp_cname = pyrex_prefix + "temp" # temp variable for quick'n'dirty temping
genexpr_id_ref = 'genexpr' genexpr_id_ref = 'genexpr'
......
This diff is collapsed.
...@@ -2992,8 +2992,18 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2992,8 +2992,18 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
literal nodes at each step. Non-literal nodes are never merged literal nodes at each step. Non-literal nodes are never merged
into a single node. into a single node.
""" """
def __init__(self, reevaluate=False):
"""
The reevaluate argument specifies whether constant values that were
previously computed should be recomputed.
"""
super(ConstantFolding, self).__init__()
self.reevaluate = reevaluate
def _calculate_const(self, node): def _calculate_const(self, node):
if node.constant_result is not ExprNodes.constant_value_not_set: if (not self.reevaluate and
node.constant_result is not ExprNodes.constant_value_not_set):
return return
# make sure we always set the value # make sure we always set the value
......
...@@ -629,8 +629,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -629,8 +629,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma' : ExprNodes.c_binop_constructor(','), 'operator.comma' : ExprNodes.c_binop_constructor(','),
} }
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', special_methods = set(['declare', 'union', 'struct', 'typedef',
'cast', 'pointer', 'compiled', 'NULL', 'parallel']) 'sizeof', 'cast', 'pointer', 'compiled',
'NULL', 'fused_type', 'parallel'])
special_methods.update(unop_method_nodes.keys()) special_methods.update(unop_method_nodes.keys())
valid_parallel_directives = set([ valid_parallel_directives = set([
...@@ -1381,10 +1382,13 @@ if VALUE is not None: ...@@ -1381,10 +1382,13 @@ if VALUE is not None:
count += 1 count += 1
""") """)
fused_function = None
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope] self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used. # needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = [] self.seen_vars_stack = []
self.fused_error_funcs = set()
return super(AnalyseDeclarationsTransform, self).__call__(root) return super(AnalyseDeclarationsTransform, self).__call__(root)
def visit_NameNode(self, node): def visit_NameNode(self, node):
...@@ -1424,9 +1428,20 @@ if VALUE is not None: ...@@ -1424,9 +1428,20 @@ if VALUE is not None:
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
env = self.env_stack[-1]
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(set())
lenv = node.local_scope lenv = node.local_scope
node.declare_arguments(lenv) node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items(): for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv) type = type_node.analyse_as_type(lenv)
...@@ -1434,6 +1449,37 @@ if VALUE is not None: ...@@ -1434,6 +1449,37 @@ if VALUE is not None:
lenv.declare_var(var, type, type_node.pos) lenv.declare_var(var, type, type_node.pos)
else: else:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
if node.is_generator and node.has_fused_arguments:
node.has_fused_arguments = False
error(node.pos, "Fused generators not supported")
node.gbody = Nodes.StatListNode(node.pos,
stats=[],
body=Nodes.PassStatNode(node.pos))
if node.has_fused_arguments:
if self.fused_function:
if self.fused_function not in self.fused_error_funcs:
error(node.pos, "Cannot nest fused functions")
self.fused_error_funcs.add(self.fused_function)
# env.declare_var(node.name, PyrexTypes.py_object_type, node.pos)
node = Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos, name=node.name),
rhs=ExprNodes.NoneNode(node.pos))
node.analyse_declarations(env)
return node
node = Nodes.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node)
self.fused_function = None
if node.py_func:
node.stats.insert(0, node.py_func)
else:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
if lenv.nogil and lenv.has_with_gil_block: if lenv.nogil and lenv.has_with_gil_block:
...@@ -1450,6 +1496,7 @@ if VALUE is not None: ...@@ -1450,6 +1496,7 @@ if VALUE is not None:
self.env_stack.append(lenv) self.env_stack.append(lenv)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.env_stack.pop()
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
return node return node
...@@ -1628,15 +1675,18 @@ if VALUE is not None: ...@@ -1628,15 +1675,18 @@ if VALUE is not None:
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.env_stack = [node.scope]
node.scope.infer_types() node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
node.local_scope.infer_types() node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope) node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop()
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
...@@ -1646,6 +1696,24 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1646,6 +1696,24 @@ class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_IndexNode(self, node):
"""
Replace index nodes used to specialize cdef functions with fused
argument types with the Attribute- or NameNode referring to the
function. We then need to copy over the specialization properties to
the attribute or name node.
Because the indexing might be a Python indexing operation on a fused
function, or (usually) a Cython indexing operation, we need to
re-analyse the types.
"""
self.visit_Node(node)
if node.is_fused_index and node.type is not PyrexTypes.error_type:
node = node.base
return node
class ExpandInplaceOperators(EnvTransform): class ExpandInplaceOperators(EnvTransform):
...@@ -2084,6 +2152,10 @@ class CreateClosureClasses(CythonTransform): ...@@ -2084,6 +2152,10 @@ class CreateClosureClasses(CythonTransform):
target_module_scope.check_c_class(func_scope.scope_class) target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
if not isinstance(node.def_node, Nodes.DefNode):
# fused function, an error has been previously issued
return node
was_in_lambda = self.in_lambda was_in_lambda = self.in_lambda
self.in_lambda = True self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node) self.create_class_from_scope(node.def_node, self.module_scope, node)
...@@ -2396,6 +2468,95 @@ class TransformBuiltinMethods(EnvTransform): ...@@ -2396,6 +2468,95 @@ class TransformBuiltinMethods(EnvTransform):
return node return node
class ReplaceFusedTypeChecks(VisitorTransform):
"""
This is not a transform in the pipeline. It is invoked on the specific
versions of a cdef function with fused argument types. It filters out any
type branches that don't match. e.g.
if fused_t is mytype:
...
elif fused_t in other_fused_type:
...
"""
# Defer the import until now to avoid circularity...
from Cython.Compiler import Optimize
transform = Optimize.ConstantFolding(reevaluate=True)
def __init__(self, local_scope):
super(ReplaceFusedTypeChecks, self).__init__()
self.local_scope = local_scope
def visit_IfStatNode(self, node):
"""
Filters out any if clauses with false compile time type check
expression.
"""
self.visitchildren(node)
return self.transform(node)
def visit_PrimaryCmpNode(self, node):
type1 = node.operand1.analyse_as_type(self.local_scope)
type2 = node.operand2.analyse_as_type(self.local_scope)
if type1 and type2:
false_node = ExprNodes.BoolNode(node.pos, value=False)
true_node = ExprNodes.BoolNode(node.pos, value=True)
type1 = self.specialize_type(type1, node.operand1.pos)
op = node.operator
if op in ('is', 'is_not', '==', '!='):
type2 = self.specialize_type(type2, node.operand2.pos)
is_same = type1.same_as(type2)
eq = op in ('is', '==')
if (is_same and eq) or (not is_same and not eq):
return true_node
elif op in ('in', 'not_in'):
# We have to do an instance check directly, as operand2
# needs to be a fused type and not a type with a subtype
# that is fused. First unpack the typedef
if isinstance(type2, PyrexTypes.CTypedefType):
type2 = type2.typedef_base_type
if type1.is_fused:
error(node.operand1.pos, "Type is fused")
elif not type2.is_fused:
error(node.operand2.pos,
"Can only use 'in' or 'not in' on a fused type")
else:
types = PyrexTypes.get_specialized_types(type2)
for specific_type in types:
if type1.same_as(specific_type):
if op == 'in':
return true_node
else:
return false_node
if op == 'not_in':
return true_node
return false_node
return node
def specialize_type(self, type, pos):
try:
return type.specialize(self.local_scope.fused_to_specific)
except KeyError:
error(pos, "Type is not specific")
return type
def visit_Node(self, node):
self.visitchildren(node)
return node
class DebugTransform(CythonTransform): class DebugTransform(CythonTransform):
""" """
Write debug information for this Cython module. Write debug information for this Cython module.
......
...@@ -134,9 +134,10 @@ cdef p_buffer_or_template(PyrexScanner s, base_type_node, templates) ...@@ -134,9 +134,10 @@ cdef p_buffer_or_template(PyrexScanner s, base_type_node, templates)
cdef is_memoryviewslice_access(PyrexScanner s) cdef is_memoryviewslice_access(PyrexScanner s)
cdef p_memoryviewslice_access(PyrexScanner s, base_type_node) cdef p_memoryviewslice_access(PyrexScanner s, base_type_node)
cdef bint looking_at_name(PyrexScanner s) except -2 cdef bint looking_at_name(PyrexScanner s) except -2
cdef bint looking_at_expr(PyrexScanner s) except -2 cdef object looking_at_expr(PyrexScanner s)# except -2
cdef bint looking_at_base_type(PyrexScanner s) except -2 cdef bint looking_at_base_type(PyrexScanner s) except -2
cdef bint looking_at_dotted_name(PyrexScanner s) except -2 cdef bint looking_at_dotted_name(PyrexScanner s) except -2
cdef bint looking_at_call(PyrexScanner s) except -2
cdef p_sign_and_longness(PyrexScanner s) cdef p_sign_and_longness(PyrexScanner s)
cdef p_opt_cname(PyrexScanner s) cdef p_opt_cname(PyrexScanner s)
cpdef p_c_declarator(PyrexScanner s, ctx = *, bint empty = *, bint is_type = *, bint cmethod_flag = *, cpdef p_c_declarator(PyrexScanner s, ctx = *, bint empty = *, bint is_type = *, bint cmethod_flag = *,
...@@ -161,6 +162,7 @@ cdef p_c_enum_definition(PyrexScanner s, pos, ctx) ...@@ -161,6 +162,7 @@ cdef p_c_enum_definition(PyrexScanner s, pos, ctx)
cdef p_c_enum_line(PyrexScanner s, ctx, list items) cdef p_c_enum_line(PyrexScanner s, ctx, list items)
cdef p_c_enum_item(PyrexScanner s, ctx, list items) cdef p_c_enum_item(PyrexScanner s, ctx, list items)
cdef p_c_struct_or_union_definition(PyrexScanner s, pos, ctx) cdef p_c_struct_or_union_definition(PyrexScanner s, pos, ctx)
cdef p_fused_definition(PyrexScanner s, pos, ctx)
cdef p_visibility(PyrexScanner s, prev_visibility) cdef p_visibility(PyrexScanner s, prev_visibility)
cdef p_c_modifiers(PyrexScanner s) cdef p_c_modifiers(PyrexScanner s)
cdef p_c_func_or_var_declaration(PyrexScanner s, pos, ctx) cdef p_c_func_or_var_declaration(PyrexScanner s, pos, ctx)
......
...@@ -2121,10 +2121,12 @@ def looking_at_expr(s): ...@@ -2121,10 +2121,12 @@ def looking_at_expr(s):
name = s.systring name = s.systring
dotted_path = [] dotted_path = []
s.next() s.next()
while s.sy == '.': while s.sy == '.':
s.next() s.next()
dotted_path.append(s.systring) dotted_path.append(s.systring)
s.expect('IDENT') s.expect('IDENT')
saved = s.sy, s.systring saved = s.sy, s.systring
if s.sy == 'IDENT': if s.sy == 'IDENT':
is_type = True is_type = True
...@@ -2140,12 +2142,14 @@ def looking_at_expr(s): ...@@ -2140,12 +2142,14 @@ def looking_at_expr(s):
s.next() s.next()
is_type = s.sy == ']' is_type = s.sy == ']'
s.put_back(*saved) s.put_back(*saved)
dotted_path.reverse() dotted_path.reverse()
for p in dotted_path: for p in dotted_path:
s.put_back('IDENT', p) s.put_back('IDENT', p)
s.put_back('.', '.') s.put_back('.', '.')
s.put_back('IDENT', name) s.put_back('IDENT', name)
return not is_type return not is_type and saved[0]
else: else:
return True return True
...@@ -2163,6 +2167,17 @@ def looking_at_dotted_name(s): ...@@ -2163,6 +2167,17 @@ def looking_at_dotted_name(s):
else: else:
return 0 return 0
def looking_at_call(s):
"See if we're looking at a.b.c("
# Don't mess up the original position, so save and restore it.
# Unfortunately there's no good way to handle this, as a subsequent call
# to next() will not advance the position until it reads a new token.
position = s.start_line, s.start_col
result = looking_at_expr(s) == u'('
if not result:
s.start_line, s.start_col = position
return result
basic_c_type_names = ("void", "char", "int", "float", "double", "bint") basic_c_type_names = ("void", "char", "int", "float", "double", "bint")
special_basic_c_types = { special_basic_c_types = {
...@@ -2179,6 +2194,8 @@ sign_and_longness_words = ("short", "long", "signed", "unsigned") ...@@ -2179,6 +2194,8 @@ sign_and_longness_words = ("short", "long", "signed", "unsigned")
base_type_start_words = \ base_type_start_words = \
basic_c_type_names + sign_and_longness_words + tuple(special_basic_c_types) basic_c_type_names + sign_and_longness_words + tuple(special_basic_c_types)
struct_enum_union = ("struct", "union", "enum", "packed")
def p_sign_and_longness(s): def p_sign_and_longness(s):
signed = 1 signed = 1
longness = 0 longness = 0
...@@ -2485,15 +2502,14 @@ def p_cdef_statement(s, ctx): ...@@ -2485,15 +2502,14 @@ def p_cdef_statement(s, ctx):
if ctx.visibility != 'extern': if ctx.visibility != 'extern':
error(pos, "C++ classes need to be declared extern") error(pos, "C++ classes need to be declared extern")
return p_cpp_class_definition(s, pos, ctx) return p_cpp_class_definition(s, pos, ctx)
elif s.sy == 'IDENT' and s.systring in ("struct", "union", "enum", "packed"): elif s.sy == 'IDENT' and s.systring in struct_enum_union:
if ctx.level not in ('module', 'module_pxd'): if ctx.level not in ('module', 'module_pxd'):
error(pos, "C struct/union/enum definition not allowed here") error(pos, "C struct/union/enum definition not allowed here")
if ctx.overridable: if ctx.overridable:
error(pos, "C struct/union/enum cannot be declared cpdef") error(pos, "C struct/union/enum cannot be declared cpdef")
if s.systring == "enum": return p_struct_enum(s, pos, ctx)
return p_c_enum_definition(s, pos, ctx) elif s.sy == 'IDENT' and s.systring == 'fused':
else: return p_fused_definition(s, pos, ctx)
return p_c_struct_or_union_definition(s, pos, ctx)
else: else:
return p_c_func_or_var_declaration(s, pos, ctx) return p_c_func_or_var_declaration(s, pos, ctx)
...@@ -2610,6 +2626,46 @@ def p_c_struct_or_union_definition(s, pos, ctx): ...@@ -2610,6 +2626,46 @@ def p_c_struct_or_union_definition(s, pos, ctx):
typedef_flag = ctx.typedef_flag, visibility = ctx.visibility, typedef_flag = ctx.typedef_flag, visibility = ctx.visibility,
api = ctx.api, in_pxd = ctx.level == 'module_pxd', packed = packed) api = ctx.api, in_pxd = ctx.level == 'module_pxd', packed = packed)
def p_fused_definition(s, pos, ctx):
"""
c(type)def fused my_fused_type:
...
"""
# s.systring == 'fused'
if ctx.level not in ('module', 'module_pxd'):
error(pos, "Fused type definition not allowed here")
s.next()
name = p_ident(s)
s.expect(":")
s.expect_newline()
s.expect_indent()
types = []
while s.sy != 'DEDENT':
if s.sy != 'pass':
#types.append(p_c_declarator(s))
types.append(p_c_base_type(s)) #, nonempty=1))
else:
s.next()
s.expect_newline()
s.expect_dedent()
if not types:
error(pos, "Need at least one type")
return Nodes.FusedTypeNode(pos, name=name, types=types)
def p_struct_enum(s, pos, ctx):
if s.systring == 'enum':
return p_c_enum_definition(s, pos, ctx)
else:
return p_c_struct_or_union_definition(s, pos, ctx)
def p_visibility(s, prev_visibility): def p_visibility(s, prev_visibility):
pos = s.position() pos = s.position()
visibility = prev_visibility visibility = prev_visibility
...@@ -2680,11 +2736,10 @@ def p_ctypedef_statement(s, ctx): ...@@ -2680,11 +2736,10 @@ def p_ctypedef_statement(s, ctx):
ctx.api = 1 ctx.api = 1
if s.sy == 'class': if s.sy == 'class':
return p_c_class_definition(s, pos, ctx) return p_c_class_definition(s, pos, ctx)
elif s.sy == 'IDENT' and s.systring in ('packed', 'struct', 'union', 'enum'): elif s.sy == 'IDENT' and s.systring in struct_enum_union:
if s.systring == 'enum': return p_struct_enum(s, pos, ctx)
return p_c_enum_definition(s, pos, ctx) elif s.sy == 'IDENT' and s.systring == 'fused':
else: return p_fused_definition(s, pos, ctx)
return p_c_struct_or_union_definition(s, pos, ctx)
else: else:
base_type = p_c_base_type(s, nonempty = 1) base_type = p_c_base_type(s, nonempty = 1)
declarator = p_c_declarator(s, ctx, is_type = 1, nonempty = 1) declarator = p_c_declarator(s, ctx, is_type = 1, nonempty = 1)
......
...@@ -62,17 +62,25 @@ def inject_pxd_code_stage_factory(context): ...@@ -62,17 +62,25 @@ def inject_pxd_code_stage_factory(context):
return module_node return module_node
return inject_pxd_code_stage return inject_pxd_code_stage
def use_utility_code_definitions(scope, target): def use_utility_code_definitions(scope, target, seen=None):
if seen is None:
seen = set()
for entry in scope.entries.itervalues(): for entry in scope.entries.itervalues():
if entry in seen:
continue
seen.add(entry)
if entry.used and entry.utility_code_definition: if entry.used and entry.utility_code_definition:
target.use_utility_code(entry.utility_code_definition) target.use_utility_code(entry.utility_code_definition)
for required_utility in entry.utility_code_definition.requires: for required_utility in entry.utility_code_definition.requires:
target.use_utility_code(required_utility) target.use_utility_code(required_utility)
elif entry.as_module: elif entry.as_module:
use_utility_code_definitions(entry.as_module, target) use_utility_code_definitions(entry.as_module, target, seen)
def inject_utility_code_stage_factory(context): def inject_utility_code_stage_factory(context):
def inject_utility_code_stage(module_node): def inject_utility_code_stage(module_node):
use_utility_code_definitions(context.cython_scope, module_node.scope)
added = [] added = []
# Note: the list might be extended inside the loop (if some utility code # Note: the list might be extended inside the loop (if some utility code
# pulls in other utility code, explicitly or implicitly) # pulls in other utility code, explicitly or implicitly)
......
This diff is collapsed.
...@@ -115,6 +115,9 @@ class EncodedString(_unicode): ...@@ -115,6 +115,9 @@ class EncodedString(_unicode):
# otherwise # otherwise
encoding = None encoding = None
def __deepcopy__(self, memo):
return self
def byteencode(self): def byteencode(self):
assert self.encoding is not None assert self.encoding is not None
return self.encode(self.encoding) return self.encode(self.encoding)
...@@ -131,6 +134,9 @@ class BytesLiteral(_bytes): ...@@ -131,6 +134,9 @@ class BytesLiteral(_bytes):
# bytes subclass that is compatible with EncodedString # bytes subclass that is compatible with EncodedString
encoding = None encoding = None
def __deepcopy__(self, memo):
return self
def byteencode(self): def byteencode(self):
if IS_PYTHON3: if IS_PYTHON3:
return _bytes(self) return _bytes(self)
......
...@@ -180,6 +180,7 @@ class Entry(object): ...@@ -180,6 +180,7 @@ class Entry(object):
buffer_aux = None buffer_aux = None
prev_entry = None prev_entry = None
might_overflow = 0 might_overflow = 0
fused_cfunction = None
utility_code_definition = None utility_code_definition = None
in_with_gil_block = 0 in_with_gil_block = 0
from_cython_utility_code = None from_cython_utility_code = None
...@@ -250,6 +251,7 @@ class Scope(object): ...@@ -250,6 +251,7 @@ class Scope(object):
scope_prefix = "" scope_prefix = ""
in_cinclude = 0 in_cinclude = 0
nogil = 0 nogil = 0
fused_to_specific = None
def __init__(self, name, outer_scope, parent_scope): def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain. # The outer_scope is the next scope in the lookup chain.
...@@ -286,6 +288,9 @@ class Scope(object): ...@@ -286,6 +288,9 @@ class Scope(object):
self.return_type = None self.return_type = None
self.id_counters = {} self.id_counters = {}
def __deepcopy__(self, memo):
return self
def merge_in(self, other, merge_unused=True): def merge_in(self, other, merge_unused=True):
# Use with care... # Use with care...
entries = [(name, entry) entries = [(name, entry)
...@@ -415,6 +420,9 @@ class Scope(object): ...@@ -415,6 +420,9 @@ class Scope(object):
entry.api = api entry.api = api
if defining: if defining:
self.type_entries.append(entry) self.type_entries.append(entry)
type.entry = entry
# here we would set as_variable to an object representing this type # here we would set as_variable to an object representing this type
return entry return entry
...@@ -670,6 +678,7 @@ class Scope(object): ...@@ -670,6 +678,7 @@ class Scope(object):
if modifiers: if modifiers:
entry.func_modifiers = modifiers entry.func_modifiers = modifiers
entry.utility_code = utility_code entry.utility_code = utility_code
type.entry = entry
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
...@@ -726,6 +735,8 @@ class Scope(object): ...@@ -726,6 +735,8 @@ class Scope(object):
def lookup_type(self, name): def lookup_type(self, name):
entry = self.lookup(name) entry = self.lookup(name)
if entry and entry.is_type: if entry and entry.is_type:
if entry.type.is_fused and self.fused_to_specific:
return entry.type.specialize(self.fused_to_specific)
return entry.type return entry.type
def lookup_operator(self, operator, operands): def lookup_operator(self, operator, operands):
...@@ -770,6 +781,7 @@ class Scope(object): ...@@ -770,6 +781,7 @@ class Scope(object):
def add_include_file(self, filename): def add_include_file(self, filename):
self.outer_scope.add_include_file(filename) self.outer_scope.add_include_file(filename)
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname namespace_cname = Naming.preimport_cname
...@@ -1562,6 +1574,7 @@ class ClosureScope(LocalScope): ...@@ -1562,6 +1574,7 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private') return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
...@@ -1850,12 +1863,16 @@ class CClassScope(ClassScope): ...@@ -1850,12 +1863,16 @@ class CClassScope(ClassScope):
if defining: if defining:
entry.func_cname = self.mangle(Naming.func_prefix, name) entry.func_cname = self.mangle(Naming.func_prefix, name)
entry.utility_code = utility_code entry.utility_code = utility_code
type.entry = entry
if u'inline' in modifiers: if u'inline' in modifiers:
entry.is_inline_cmethod = True entry.is_inline_cmethod = True
if (self.parent_type.is_final_type or entry.is_inline_cmethod or if (self.parent_type.is_final_type or entry.is_inline_cmethod or
self.directives.get('final')): self.directives.get('final')):
entry.is_final_cmethod = True entry.is_final_cmethod = True
entry.final_func_cname = entry.func_cname entry.final_func_cname = entry.func_cname
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
...@@ -1898,12 +1915,21 @@ class CClassScope(ClassScope): ...@@ -1898,12 +1915,21 @@ class CClassScope(ClassScope):
# to work with this type. # to work with this type.
def adapt(cname): def adapt(cname):
return "%s.%s" % (Naming.obj_base_cname, base_entry.cname) return "%s.%s" % (Naming.obj_base_cname, base_entry.cname)
for base_entry in \
base_scope.inherited_var_entries + base_scope.var_entries: entries = base_scope.inherited_var_entries + base_scope.var_entries
for base_entry in entries:
entry = self.declare(base_entry.name, adapt(base_entry.cname), entry = self.declare(base_entry.name, adapt(base_entry.cname),
base_entry.type, None, 'private') base_entry.type, None, 'private')
entry.is_variable = 1 entry.is_variable = 1
self.inherited_var_entries.append(entry) self.inherited_var_entries.append(entry)
# If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate
for base_entry in base_scope.cfunc_entries[:]:
if base_entry.type.is_fused:
base_entry.type.get_all_specific_function_types()
for base_entry in base_scope.cfunc_entries: for base_entry in base_scope.cfunc_entries:
cname = base_entry.cname cname = base_entry.cname
var_entry = base_entry.as_variable var_entry = base_entry.as_variable
...@@ -1993,6 +2019,7 @@ class CppClassScope(Scope): ...@@ -1993,6 +2019,7 @@ class CppClassScope(Scope):
if prev_entry: if prev_entry:
entry.overloaded_alternatives = prev_entry.all_alternatives() entry.overloaded_alternatives = prev_entry.all_alternatives()
entry.utility_code = utility_code entry.utility_code = utility_code
type.entry = entry
return entry return entry
def declare_inherited_cpp_attributes(self, base_scope): def declare_inherited_cpp_attributes(self, base_scope):
......
...@@ -372,7 +372,7 @@ class SimpleAssignmentTypeInferer(object): ...@@ -372,7 +372,7 @@ class SimpleAssignmentTypeInferer(object):
while ready_to_infer: while ready_to_infer:
entry = ready_to_infer.pop() entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
if types: if types and Utils.all(types):
entry.type = spanning_type(types, entry.might_overflow) entry.type = spanning_type(types, entry.might_overflow)
else: else:
# FIXME: raise a warning? # FIXME: raise a warning?
......
...@@ -55,16 +55,6 @@ class Optimization(object): ...@@ -55,16 +55,6 @@ class Optimization(object):
optimization = Optimization() optimization = Optimization()
try:
any
except NameError:
def any(it):
for x in it:
if x:
return True
return False
class build_ext(_build_ext.build_ext): class build_ext(_build_ext.build_ext):
...@@ -128,8 +118,8 @@ class build_ext(_build_ext.build_ext): ...@@ -128,8 +118,8 @@ class build_ext(_build_ext.build_ext):
# If --pyrex-gdb is in effect as a command line option or as option # If --pyrex-gdb is in effect as a command line option or as option
# of any Extension module, disable optimization for the C or C++ # of any Extension module, disable optimization for the C or C++
# compiler. # compiler.
if (self.pyrex_gdb or any([getattr(ext, 'pyrex_gdb', False) if self.pyrex_gdb or [1 for ext in self.extensions
for ext in self.extensions])): if getattr(ext, 'pyrex_gdb', False)]:
optimization.disable_optimization() optimization.disable_optimization()
_build_ext.build_ext.run(self) _build_ext.build_ext.run(self)
......
...@@ -66,7 +66,8 @@ def sizeof(arg): ...@@ -66,7 +66,8 @@ def sizeof(arg):
return 1 return 1
def typeof(arg): def typeof(arg):
return type(arg) return arg.__class__.__name__
# return type(arg)
def address(arg): def address(arg):
return pointer(type(arg))([arg]) return pointer(type(arg))([arg])
...@@ -138,6 +139,9 @@ class PointerType(CythonType): ...@@ -138,6 +139,9 @@ class PointerType(CythonType):
else: else:
return not self._items and not value._items return not self._items and not value._items
def __repr__(self):
return "%s *" % (self._basetype,)
class ArrayType(PointerType): class ArrayType(PointerType):
def __init__(self): def __init__(self):
...@@ -221,22 +225,54 @@ def union(**members): ...@@ -221,22 +225,54 @@ def union(**members):
class typedef(CythonType): class typedef(CythonType):
def __init__(self, type): def __init__(self, type, name=None):
self._basetype = type self._basetype = type
self.name = name
def __call__(self, *arg): def __call__(self, *arg):
value = cast(self._basetype, *arg) value = cast(self._basetype, *arg)
return value return value
def __repr__(self):
return self.name or str(self._basetype)
class _FusedType(CythonType):
pass
def fused_type(*args):
if not args:
raise TypeError("Expected at least one type as argument")
# Find the numeric type with biggest rank if all types are numeric
rank = -1
for type in args:
if type not in (py_int, py_long, py_float, py_complex):
break
if type_ordering.index(type) > rank:
result_type = type
else:
return result_type
# Not a simple numeric type, return a fused type instance. The result
# isn't really meant to be used, as we can't keep track of the context in
# pure-mode. Casting won't do anything in this case.
return _FusedType()
def _specialized_from_args(signatures, args, kwargs):
"Perhaps this should be implemented in a TreeFragment in Cython code"
raise Exception("yet to be implemented")
py_int = int py_int = typedef(int, "int")
try: try:
py_long = long py_long = typedef(long, "long")
except NameError: # Py3 except NameError: # Py3
py_long = int py_long = typedef(int, "long")
py_float = float py_float = typedef(float, "float")
py_complex = complex py_complex = typedef(complex, "double complex")
# Predefined types # Predefined types
...@@ -246,30 +282,43 @@ float_types = ['longdouble', 'double', 'float'] ...@@ -246,30 +282,43 @@ float_types = ['longdouble', 'double', 'float']
complex_types = ['longdoublecomplex', 'doublecomplex', 'floatcomplex', 'complex'] complex_types = ['longdoublecomplex', 'doublecomplex', 'floatcomplex', 'complex']
other_types = ['bint', 'void'] other_types = ['bint', 'void']
to_repr = {
'longlong': 'long long',
'longdouble': 'long double',
'longdoublecomplex': 'long double complex',
'doublecomplex': 'double complex',
'floatcomplex': 'float complex',
}.get
gs = globals() gs = globals()
for name in int_types: for name in int_types:
gs[name] = typedef(py_int) reprname = to_repr(name, name)
gs[name] = typedef(py_int, reprname)
if name != 'Py_UNICODE' and not name.endswith('size_t'): if name != 'Py_UNICODE' and not name.endswith('size_t'):
gs['u'+name] = typedef(py_int) gs['u'+name] = typedef(py_int, "unsigned " + reprname)
gs['s'+name] = typedef(py_int) gs['s'+name] = typedef(py_int, "signed " + reprname)
for name in float_types: for name in float_types:
gs[name] = typedef(py_float) gs[name] = typedef(py_float, to_repr(name, name))
for name in complex_types: for name in complex_types:
gs[name] = typedef(py_complex) gs[name] = typedef(py_complex, to_repr(name, name))
bint = typedef(bool) bint = typedef(bool, "bint")
void = typedef(int) void = typedef(int, "void")
for t in int_types + float_types + complex_types + other_types: for t in int_types + float_types + complex_types + other_types:
for i in range(1, 4): for i in range(1, 4):
gs["%s_%s" % ('p'*i, t)] = globals()[t]._pointer(i) gs["%s_%s" % ('p'*i, t)] = globals()[t]._pointer(i)
void = typedef(None) void = typedef(None, "void")
NULL = p_void(0) NULL = p_void(0)
integral = floating = numeric = _FusedType()
type_ordering = [py_int, py_long, py_float, py_complex]
class CythonDotParallel(object): class CythonDotParallel(object):
""" """
The cython.parallel module. The cython.parallel module.
......
This diff is collapsed.
...@@ -216,3 +216,22 @@ def long_literal(value): ...@@ -216,3 +216,22 @@ def long_literal(value):
if isinstance(value, basestring): if isinstance(value, basestring):
value = str_to_number(value) value = str_to_number(value)
return not -2**31 <= value < 2**31 return not -2**31 <= value < 2**31
# all() and any() are new in 2.5
try:
# Make sure to bind them on the module, as they will be accessed as
# attributes
all = all
any = any
except NameError:
def all(items):
for item in items:
if not item:
return False
return True
def any(items):
for item in items:
if item:
return True
return False
.. highlight:: cython
.. _fusedtypes:
**************************
Fused Types (Templates)
**************************
Fused types can be used to fuse multiple types into a single type, to allow a single
algorithm to operate on values of multiple types. They are somewhat akin to templates
or generics.
.. Note:: Support is experimental and new in this release, there may be bugs!
Declaring Fused Types
=====================
Fused types may be declared as follows::
cimport cython
ctypedef fused my_fused_type:
cython.p_int
cython.p_float
This declares a new type called ``my_fused_type`` which is composed of a ``int *`` and a ``double *``.
Alternatively, the declaration may be written as::
my_fused_type = cython.fused_type(cython.p_int, cython.p_float)
Only names may be used for the constituent types, but they may be any (non-fused) type, including a typedef.
i.e. one may write::
ctypedef double *doublep
my_fused_type = cython.fused_type(cython.p_int, doublep)
Using Fused Types
=================
Fused types can be used to declare parameters of functions or methods::
cdef cfunc(my_fused_type arg1, my_fused_type arg2):
return cython.typeof(arg1) == cython.typeof(arg2)
This declares a function with two parameters. The type of both parameters is either a pointer to an int,
or a pointer to a float (according to the previous examples). So this function always True for every possible
invocation. You are allowed to mix fused types however::
def func(A x, B y):
...
where ``A`` and ``B`` are different fused types. This will result in all combination of types.
Note that specializations of only numeric types may not be very useful, as one can usually rely on
promotion of types. This is not true for arrays, pointers and typed views of memory however.
Indeed, one may write::
def myfunc(A[:, :] x):
...
# and
cdef otherfunc(A *x):
...
Selecting Specializations
=========================
You can select a specialization (an instance of the function with specific or specialized (i.e.,
non-fused) argument types) in two ways: either by indexing or by calling.
Indexing
--------
You can index functions with types to get certain specializations, i.e.::
cfunc[cython.p_double](p1, p2)
# From Cython space
func[float, double](myfloat, mydouble)
# From Python space
func[cython.float, cython.double](myfloat, mydouble)
If a fused type is used as a base type, this will mean that the base type is the fused type, so the
base type is what needs to be specialized::
cdef myfunc(A *x):
...
# Specialize using int, not int *
myfunc[int](myint)
Calling
-------
A fused function can also be called with arguments, where the dispatch is figured out automatically::
cfunc(p1, p2)
func(myfloat, mydouble)
For a ``cdef`` or ``cpdef`` function called from Cython this means that the specialization is figured
out at compile time. For ``def`` functions the arguments are typechecked at runtime, and a best-effort
approach is performed to figure out which specialization is needed. This means that this may result in
a runtime ``TypeError`` if no specialization was found. A ``cpdef`` function is treated the same way as
a ``def`` function if the type of the function is unknown (e.g. if it is external and there is no cimport
for it).
The automatic dispatching rules are typically as follows, in order of preference:
* try to find an exact match
* choose the biggest corresponding numerical type (biggest float, biggest complex, biggest int)
Built-in Fused Types
====================
There are some built-in fused types available for convenience, these are::
cython.integral # int, long
cython.floating # float, double
cython.numeric # long, double, double complex
Casting Fused Functions
=======================
Fused ``cdef`` and ``cpdef`` functions may be cast or assigned to C function pointers as follows::
cdef myfunc(cython.floating, cython.integral):
...
# assign directly
cdef object (*funcp)(float, int)
funcp = myfunc
funcp(f, i)
# alternatively, cast it
(<object (*)(float, int)> myfunc)(f, i)
# This is also valid
funcp = myfunc[float, int]
funcp(f, i)
Type Checking Specializations
=============================
Decisions can be made based on the specializations of the fused parameters. False conditions are pruned
to avoid invalid code. One may check with ``is``, ``is not`` and ``==`` and ``!=`` to see if a fused type
is equal to a certain other non-fused type (to check the specialization), or use ``in`` and ``not in`` to
figure out whether a specialization is part of another set of types (specified as a fused type). In
example::
ctypedef fused bunch_of_types:
...
ctypedef fused string_t:
cython.p_char
bytes
unicode
cdef cython.integral myfunc(cython.integral i, bunch_of_types s):
cdef int *int_pointer
cdef long *long_pointer
# Only one of these branches will be compiled for each specialization!
if cython.integral is int:
int_pointer = &i
else:
long_pointer = &i
if bunch_of_types in string_t:
print "s is a string!"
__signatures__
==============
Finally, function objects from ``def`` or ``cpdef`` functions have an attribute __signatures__, which maps
the signature strings to the actual specialized functions. This may be useful for inspection.
Listed signature strings may also be used as indices to the fused function::
specialized_function = fused_function["MyExtensionClass, int, float"]
It would usually be preferred to index like this, however::
specialized_function = fused_function[MyExtensionClass, int, float]
Although the latter will select the biggest types for ``int`` and ``float`` from Python space, as they are
not type identifiers but builtin types there. Passing ``cython.int`` and ``cython.float`` would resolve that,
however.
...@@ -15,6 +15,7 @@ Contents: ...@@ -15,6 +15,7 @@ Contents:
external_C_code external_C_code
source_files_and_compilation source_files_and_compilation
wrapping_CPlusPlus wrapping_CPlusPlus
fusedtypes
limitations limitations
pyrex_differences pyrex_differences
early_binding_for_speed early_binding_for_speed
......
...@@ -716,6 +716,7 @@ def run_forked_test(result, run_func, test_name, fork=True): ...@@ -716,6 +716,7 @@ def run_forked_test(result, run_func, test_name, fork=True):
gc.collect() gc.collect()
return return
module_name = test_name.split()[-1]
# fork to make sure we do not keep the tested module loaded # fork to make sure we do not keep the tested module loaded
result_handle, result_file = tempfile.mkstemp() result_handle, result_file = tempfile.mkstemp()
os.close(result_handle) os.close(result_handle)
......
...@@ -9,5 +9,5 @@ cdef extern from *: ...@@ -9,5 +9,5 @@ cdef extern from *:
new Foo(1, 2) new Foo(1, 2)
_ERRORS = u""" _ERRORS = u"""
9:7: no suitable method found 9:7: Call with wrong number of arguments (expected 1, got 2)
""" """
# mode: error
cimport cython
def closure(cython.integral i):
def inner(cython.floating f):
pass
def closure2(cython.integral i):
return lambda cython.integral i: i
def closure3(cython.integral i):
def inner():
return lambda cython.floating f: f
def generator(cython.integral i):
yield i
_ERRORS = u"""
e_fused_closure.pyx:6:4: Cannot nest fused functions
e_fused_closure.pyx:10:11: Cannot nest fused functions
e_fused_closure.pyx:14:15: Cannot nest fused functions
e_fused_closure.pyx:16:0: Fused generators not supported
"""
# mode: error
ctypedef char *string_t
ctypedef public char *public_string_t
ctypedef api char *api_string_t
# This should all fail
cdef public pub_func1(string_t x):
pass
cdef api api_func1(string_t x):
pass
cdef public string_t pub_func2():
pass
cdef api string_t api_func2():
pass
cdef public opt_pub_func(x = None):
pass
cdef api opt_api_func(x = None):
pass
# This should all work
cdef public pub_func3(public_string_t x, api_string_t y):
pass
cdef api api_func3(public_string_t x, api_string_t y):
pass
cdef opt_func(x = None):
pass
_ERRORS = u"""
e_public_cdef_private_types.pyx:8:22: Function declared public or api may not have private types
e_public_cdef_private_types.pyx:11:19: Function declared public or api may not have private types
e_public_cdef_private_types.pyx:14:5: Function declared public or api may not have private types
e_public_cdef_private_types.pyx:17:5: Function declared public or api may not have private types
e_public_cdef_private_types.pyx:20:25: Function with optional arguments may not be declared public or api
e_public_cdef_private_types.pyx:23:22: Function with optional arguments may not be declared public or api
"""
# mode: error
cdef fused my_fused_type: int a; char b
_ERRORS = u"""
fused_syntax.pyx:3:26: Expected a newline
"""
# mode: error
cimport cython
ctypedef cython.fused_type(int, float) fused_t
_ERRORS = u"""
fused_syntax_ctypedef.pyx:5:39: Syntax error in ctypedef statement
"""
# mode: error
cimport cython
from cython import fused_type
# This is all invalid
# ctypedef foo(int) dtype1
# ctypedef foo.bar(float) dtype2
# ctypedef fused_type(foo) dtype3
dtype4 = cython.fused_type(int, long, kw=None)
# ctypedef public cython.fused_type(int, long) dtype7
# ctypedef api cython.fused_type(int, long) dtype8
int_t = cython.fused_type(short, short, int)
int2_t = cython.fused_type(int, long)
dtype9 = cython.fused_type(int2_t, int)
floating = cython.fused_type(float, double)
cdef func(floating x, int2_t y):
print x, y
cdef float x = 10.0
cdef int y = 10
func[float](x, y)
func[float][int](x, y)
func[float, int](x)
func[float, int](x, y, y)
func(x, y=y)
ctypedef fused memslice_dtype_t:
cython.p_int # invalid dtype
cython.long
def f(memslice_dtype_t[:, :] a):
pass
# This is all valid
dtype5 = fused_type(int, long, float)
dtype6 = cython.fused_type(int, long)
func[float, int](x, y)
cdef fused fused1:
int
long long
ctypedef fused fused2:
int
long long
func(x, y)
_ERRORS = u"""
fused_types.pyx:10:15: fused_type does not take keyword arguments
fused_types.pyx:15:38: Type specified multiple times
fused_types.pyx:17:33: Cannot fuse a fused type
fused_types.pyx:26:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:27:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:28:16: Call with wrong number of arguments (expected 2, got 1)
fused_types.pyx:29:16: Call with wrong number of arguments (expected 2, got 3)
fused_types.pyx:30:4: Keyword and starred arguments not allowed in cdef functions.
fused_types.pyx:36:6: Invalid base type for memoryview slice: int *
"""
...@@ -59,10 +59,10 @@ _ERRORS = u''' ...@@ -59,10 +59,10 @@ _ERRORS = u'''
20:22: Invalid axis specification. 20:22: Invalid axis specification.
21:25: Invalid axis specification. 21:25: Invalid axis specification.
22:22: no expressions allowed in axis spec, only names and literals. 22:22: no expressions allowed in axis spec, only names and literals.
25:51: Memoryview 'object[::contiguous, :]' not conformable to memoryview 'object[:, ::contiguous]'. 25:51: Memoryview 'object[::1, :]' not conformable to memoryview 'object[:, ::1]'.
28:36: Different base types for memoryviews (int, Python object) 28:36: Different base types for memoryviews (int, Python object)
31:9: Dimension may not be contiguous 31:9: Dimension may not be contiguous
37:9: Only one direct contiguous axis may be specified. 37:9: Only one direct contiguous axis may be specified.
38:9:Only dimensions 3 and 2 may be contiguous and direct 38:9:Only dimensions 3 and 2 may be contiguous and direct
44:10: Invalid base type for memoryview slice 44:10: Invalid base type for memoryview slice: intp
''' '''
cimport cython
cimport check_fused_types_pxd
import math
ctypedef char *string_t
fused_t = cython.fused_type(int, long, float, string_t)
other_t = cython.fused_type(int, long)
base_t = cython.fused_type(short, int)
# complex_t = cython.fused_type(cython.floatcomplex, cython.doublecomplex)
cdef fused complex_t:
float complex
double complex
ctypedef base_t **base_t_p_p
# ctypedef cython.fused_type(char, base_t_p_p, fused_t, complex_t) composed_t
cdef fused composed_t:
char
int
float
string_t
cython.pp_int
float complex
double complex
int complex
long complex
cdef func(fused_t a, other_t b):
cdef int int_a
cdef string_t string_a
cdef other_t other_a
if fused_t is other_t:
print 'fused_t is other_t'
other_a = a
if fused_t is int:
print 'fused_t is int'
int_a = a
if fused_t is string_t:
print 'fused_t is string_t'
string_a = a
if fused_t in check_fused_types_pxd.unresolved_t:
print 'fused_t in unresolved_t'
if int in check_fused_types_pxd.unresolved_t:
print 'int in unresolved_t'
if string_t in check_fused_types_pxd.unresolved_t:
print 'string_t in unresolved_t'
def test_int_int():
"""
>>> test_int_int()
fused_t is other_t
fused_t is int
fused_t in unresolved_t
int in unresolved_t
"""
cdef int x = 1
cdef int y = 2
func(x, y)
def test_int_long():
"""
>>> test_int_long()
fused_t is int
fused_t in unresolved_t
int in unresolved_t
"""
cdef int x = 1
cdef long y = 2
func(x, y)
def test_float_int():
"""
>>> test_float_int()
fused_t in unresolved_t
int in unresolved_t
"""
cdef float x = 1
cdef int y = 2
func(x, y)
def test_string_int():
"""
>>> test_string_int()
fused_t is string_t
int in unresolved_t
"""
cdef string_t x = b"spam"
cdef int y = 2
func(x, y)
cdef if_then_else(fused_t a, other_t b):
cdef other_t other_a
cdef string_t string_a
cdef fused_t specific_a
if fused_t is other_t:
print 'fused_t is other_t'
other_a = a
elif fused_t is string_t:
print 'fused_t is string_t'
string_a = a
else:
print 'none of the above'
specific_a = a
def test_if_then_else_long_long():
"""
>>> test_if_then_else_long_long()
fused_t is other_t
"""
cdef long x = 0, y = 0
if_then_else(x, y)
def test_if_then_else_string_int():
"""
>>> test_if_then_else_string_int()
fused_t is string_t
"""
cdef string_t x = b"spam"
cdef int y = 0
if_then_else(x, y)
def test_if_then_else_float_int():
"""
>>> test_if_then_else_float_int()
none of the above
"""
cdef float x = 0.0
cdef int y = 1
if_then_else(x, y)
cdef composed_t composed(composed_t x, composed_t y):
if composed_t in base_t_p_p or composed_t is string_t:
if string_t == composed_t:
print x.decode('ascii'), y.decode('ascii')
else:
print x[0][0], y[0][0]
return x
elif composed_t == string_t:
print 'this is never executed'
elif list():
print 'neither is this one'
else:
if composed_t not in complex_t:
print 'not a complex number'
print <int> x, <int> y
else:
print 'it is a complex number'
print x.real, x.imag
return x + y
def test_composed_types():
"""
>>> test_composed_types()
it is a complex number
0.5 0.6
9 4
<BLANKLINE>
not a complex number
7 8
15
<BLANKLINE>
7 8
<BLANKLINE>
spam eggs
spam
"""
cdef double complex a = 0.5 + 0.6j, b = 0.4 -0.2j, result
cdef int c = 7, d = 8
cdef int *cp = &c, *dp = &d
cdef string_t e = "spam", f = "eggs"
result = composed(a, b)
print int(math.ceil(result.real * 10)), int(math.ceil(result.imag * 10))
print
print composed(c, d)
print
composed(&cp, &dp)
print
print composed(e, f).decode('ascii')
cimport cython
unresolved_t = cython.fused_type(int, float)
cimport cython
cy = __import__("cython")
cpdef func1(self, cython.integral x):
print "%s," % (self,),
if cython.integral is int:
print 'x is int', x, cython.typeof(x)
else:
print 'x is long', x, cython.typeof(x)
class A(object):
meth = func1
def __str__(self):
return "A"
pyfunc = func1
def test_fused_cpdef():
"""
>>> test_fused_cpdef()
None, x is int 2 int
None, x is long 2 long
None, x is long 2 long
<BLANKLINE>
None, x is int 2 int
None, x is long 2 long
<BLANKLINE>
A, x is int 2 int
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
"""
func1[int](None, 2)
func1[long](None, 2)
func1(None, 2)
print
pyfunc[cy.int](None, 2)
pyfunc(None, 2)
print
A.meth[cy.int](A(), 2)
A.meth(A(), 2)
A().meth[cy.long](2)
A().meth(2)
def assert_raise(func, *args):
try:
func(*args)
except TypeError:
pass
else:
assert False, "Function call did not raise TypeError"
def test_badcall():
"""
>>> test_badcall()
"""
assert_raise(pyfunc)
assert_raise(pyfunc, 1, 2, 3)
assert_raise(pyfunc[cy.int], 10, 11, 12)
assert_raise(pyfunc, None, object())
assert_raise(A().meth)
assert_raise(A.meth)
assert_raise(A().meth[cy.int])
assert_raise(A.meth[cy.int])
ctypedef long double long_double
cpdef multiarg(cython.integral x, cython.floating y):
if cython.integral is int:
print "x is an int,",
else:
print "x is a long,",
if cython.floating is long_double:
print "y is a long double:",
elif float is cython.floating:
print "y is a float:",
else:
print "y is a double:",
print x, y
def test_multiarg():
"""
>>> test_multiarg()
x is an int, y is a float: 1 2.0
x is an int, y is a float: 1 2.0
x is a long, y is a double: 4 5.0
"""
multiarg[int, float](1, 2.0)
multiarg[cy.int, cy.float](1, 2.0)
multiarg(4, 5.0)
# mode: run
"""
Test Python def functions without extern types
"""
cy = __import__("cython")
cimport cython
cdef class Base(object):
def __repr__(self):
return type(self).__name__
cdef class ExtClassA(Base):
pass
cdef class ExtClassB(Base):
pass
cdef enum MyEnum:
entry0
entry1
entry2
entry3
entry4
ctypedef fused fused_t:
str
int
long
complex
ExtClassA
ExtClassB
MyEnum
f = 5.6
i = 9
def opt_func(fused_t obj, cython.floating myf = 1.2, cython.integral myi = 7):
"""
Test runtime dispatch, indexing of various kinds and optional arguments
>>> opt_func("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func["str, double, long"]("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, cy.int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func(ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func[ExtClassA, float, int](ExtClassA(), f, i)
ExtClassA float int
ExtClassA 5.60 9 5.60 9
>>> opt_func["ExtClassA, double, long"](ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func(ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func[ExtClassB, cy.double, cy.long](ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func(10, f)
long double long
10 5.60 7 5.60 9
>>> opt_func[int, float, int](10, f)
int float int
10 5.60 7 5.60 9
>>> opt_func(10 + 2j, myf = 2.6)
double complex double long
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.py_complex, float, int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.doublecomplex, cy.float, cy.int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func(object(), f)
Traceback (most recent call last):
...
TypeError: Function call with ambiguous argument types
>>> opt_func[ExtClassA, cy.float, cy.long](object(), f)
Traceback (most recent call last):
...
TypeError: Argument 'obj' has incorrect type (expected fused_def.ExtClassA, got object)
"""
print cython.typeof(obj), cython.typeof(myf), cython.typeof(myi)
print obj, "%.2f" % myf, myi, "%.2f" % f, i
def test_opt_func():
"""
>>> test_opt_func()
str object double long
ham 5.60 4 5.60 9
"""
opt_func("ham", f, entry4)
def args_kwargs(fused_t obj, cython.floating myf = 1.2, *args, **kwargs):
"""
>>> args_kwargs("foo")
str object double
foo 1.20 5.60 () {}
>>> args_kwargs("eggs", f, 1, 2, [], d={})
str object double
eggs 5.60 5.60 (1, 2, []) {'d': {}}
>>> args_kwargs[str, float]("eggs", f, 1, 2, [], d={})
str object float
eggs 5.60 5.60 (1, 2, []) {'d': {}}
"""
print cython.typeof(obj), cython.typeof(myf)
print obj, "%.2f" % myf, "%.2f" % f, args, kwargs
# mode: run
cimport cython
from cython cimport integral
from cpython cimport Py_INCREF
from Cython import Shadow as pure_cython
ctypedef char * string_t
# floating = cython.fused_type(float, double) floating
# integral = cython.fused_type(int, long) integral
ctypedef cython.floating floating
fused_type1 = cython.fused_type(int, long, float, double, string_t)
fused_type2 = cython.fused_type(string_t)
ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double)
ctypedef double *p_double
ctypedef int *p_int
def test_pure():
"""
>>> test_pure()
10
"""
mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
print mytype(10)
cdef cdef_func_with_fused_args(fused_type1 x, fused_type1 y, fused_type2 z):
if fused_type1 is string_t:
print x.decode('ascii'), y.decode('ascii'), z.decode('ascii')
else:
print x, y, z.decode('ascii')
return x + y
def test_cdef_func_with_fused_args():
"""
>>> test_cdef_func_with_fused_args()
spam ham eggs
spamham
10 20 butter
30
4.2 8.6 bunny
12.8
"""
print cdef_func_with_fused_args('spam', 'ham', 'eggs').decode('ascii')
print cdef_func_with_fused_args(10, 20, 'butter')
print cdef_func_with_fused_args(4.2, 8.6, 'bunny')
cdef fused_type1 fused_with_pointer(fused_type1 *array):
for i in range(5):
if fused_type1 is string_t:
print array[i].decode('ascii')
else:
print array[i]
obj = array[0] + array[1] + array[2] + array[3] + array[4]
# if cython.typeof(fused_type1) is string_t:
Py_INCREF(obj)
return obj
def test_fused_with_pointer():
"""
>>> test_fused_with_pointer()
0
1
2
3
4
10
<BLANKLINE>
0
1
2
3
4
10
<BLANKLINE>
0.0
1.0
2.0
3.0
4.0
10.0
<BLANKLINE>
humpty
dumpty
fall
splatch
breakfast
humptydumptyfallsplatchbreakfast
"""
cdef int int_array[5]
cdef long long_array[5]
cdef float float_array[5]
cdef string_t string_array[5]
cdef char *s
strings = [b"humpty", b"dumpty", b"fall", b"splatch", b"breakfast"]
for i in range(5):
int_array[i] = i
long_array[i] = i
float_array[i] = i
s = strings[i]
string_array[i] = s
print fused_with_pointer(int_array)
print
print fused_with_pointer(long_array)
print
print fused_with_pointer(float_array)
print
print fused_with_pointer(string_array).decode('ascii')
include "cythonarrayutil.pxi"
cpdef cython.integral test_fused_memoryviews(cython.integral[:, ::1] a):
"""
>>> import cython
>>> a = create_array((3, 5), mode="c")
>>> test_fused_memoryviews[cython.int](a)
7
"""
return a[1, 2]
ctypedef int[:, ::1] memview_int
ctypedef long[:, ::1] memview_long
memview_t = cython.fused_type(memview_int, memview_long)
def test_fused_memoryview_def(memview_t a):
"""
>>> a = create_array((3, 5), mode="c")
>>> test_fused_memoryview_def["memview_int"](a)
7
"""
return a[1, 2]
cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
cdef fused_type1 result
if composed_t is p_double:
print "double pointer"
if fused_type1 in floating:
result = x + y[0] + z[0] + a[0]
return result
def test_specializations():
"""
>>> test_specializations()
double pointer
double pointer
double pointer
double pointer
double pointer
"""
cdef object (*f)(double, double *, double *, int *)
cdef double somedouble = 2.2
cdef double otherdouble = 3.3
cdef int someint = 4
cdef p_double somedouble_p = &somedouble
cdef p_double otherdouble_p = &otherdouble
cdef p_int someint_p = &someint
f = test_specialize
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
f = <object (*)(double, double *, double *, int *)> test_specialize
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
assert (<object (*)(double, double *, double *, int *)>
test_specialize)(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
f = test_specialize[double, int]
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
assert test_specialize[double, int](1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
# The following cases are not supported
# f = test_specialize[double][p_int]
# print f(1.1, somedouble_p, otherdouble_p)
# print
# print test_specialize[double][p_int](1.1, somedouble_p, otherdouble_p)
# print
# print test_specialize[double](1.1, somedouble_p, otherdouble_p)
# print
cdef opt_args(integral x, floating y = 4.0):
print x, y
def test_opt_args():
"""
>>> test_opt_args()
3 4.0
3 4.0
3 4.0
3 4.0
"""
opt_args[int, float](3)
opt_args[int, double](3)
opt_args[int, float](3, 4.0)
opt_args[int, double](3, 4.0)
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment