Commit e7426088 authored by Mark Florisson's avatar Mark Florisson

Support sharing of cdef functions with fused types

parent ae59c5a0
...@@ -2999,9 +2999,9 @@ class SimpleCallNode(CallNode): ...@@ -2999,9 +2999,9 @@ class SimpleCallNode(CallNode):
overloaded_entry = None overloaded_entry = None
if overloaded_entry: if overloaded_entry:
if overloaded_entry.fused_cfunction: if self.function.type.is_fused:
specific_cdef_funcs = overloaded_entry.fused_cfunction.nodes alternatives = []
alternatives = [n.entry for n in specific_cdef_funcs] self.function.type.map_with_specific_entries(alternatives.append)
else: else:
alternatives = overloaded_entry.all_alternatives() alternatives = overloaded_entry.all_alternatives()
......
...@@ -156,19 +156,14 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -156,19 +156,14 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f.close() f.close()
def generate_public_declaration(self, entry, h_code, i_code): def generate_public_declaration(self, entry, h_code, i_code):
if entry.fused_cfunction: entry.type.map_with_specific_entries(self._generate_public_declaration,
for cfunction in entry.fused_cfunction.nodes: h_code, i_code)
self._generate_public_declaration(cfunction.entry,
cfunction.entry.cname, h_code, i_code)
else:
self._generate_public_declaration(entry, entry.cname,
h_code, i_code)
def _generate_public_declaration(self, entry, cname, h_code, i_code): def _generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % ( h_code.putln("%s %s;" % (
Naming.extern_c_macro, Naming.extern_c_macro,
entry.type.declaration_code( entry.type.declaration_code(
cname, dll_linkage = "DL_IMPORT"))) entry.cname, dll_linkage = "DL_IMPORT")))
if i_code: if i_code:
i_code.putln("cdef extern %s" % i_code.putln("cdef extern %s" %
entry.type.declaration_code(cname, pyrex = 1)) entry.type.declaration_code(cname, pyrex = 1))
...@@ -996,15 +991,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -996,15 +991,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cfunction_predeclarations(self, env, code, definition): def generate_cfunction_predeclarations(self, env, code, definition):
for entry in env.cfunc_entries: for entry in env.cfunc_entries:
if entry.fused_cfunction: entry.type.map_with_specific_entries(
for node in entry.fused_cfunction.nodes: self._generate_cfunction_predeclaration, code, definition)
self._generate_cfunction_predeclaration(
code, definition, node.entry)
else:
self._generate_cfunction_predeclaration(code, definition, entry)
def _generate_cfunction_predeclaration(self, code, definition, entry): def _generate_cfunction_predeclaration(self, entry, code, definition):
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
or entry.defined_in_pxd or entry.visibility == 'extern')): or entry.defined_in_pxd or entry.visibility == 'extern')):
if entry.visibility == 'public': if entry.visibility == 'public':
...@@ -2056,15 +2046,27 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2056,15 +2046,27 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_function_export_code(self, env, code): def generate_c_function_export_code(self, env, code):
# Generate code to create PyCFunction wrappers for exported C functions. # Generate code to create PyCFunction wrappers for exported C functions.
func = self._generate_c_function_export_code
for entry in env.cfunc_entries: for entry in env.cfunc_entries:
from_fused = entry.type.is_fused
if entry.api or entry.defined_in_pxd: if entry.api or entry.defined_in_pxd:
env.use_utility_code(function_export_utility_code) entry.type.map_with_specific_entries(func, env, code, from_fused)
signature = entry.type.signature_string()
code.putln('if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s' % ( def _generate_c_function_export_code(self, entry, env, code, from_fused):
entry.name, env.use_utility_code(function_export_utility_code)
entry.cname, signature = entry.type.signature_string()
signature, s = 'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
code.error_goto(self.pos)))
if from_fused:
# Specific version of a fused function. Fused functions can never
# be declared public or api, but they may need to be exported when
# declared in a .pxd. We need to give them a unique name in that
# case
name = entry.cname
else:
name = entry.name
code.putln(s % (name, entry.cname, signature, code.error_goto(self.pos)))
def generate_type_import_code_for_module(self, module, env, code): def generate_type_import_code_for_module(self, module, env, code):
# Generate type import code for all exported extension types in # Generate type import code for all exported extension types in
...@@ -2090,16 +2092,29 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2090,16 +2092,29 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
module.qualified_name, module.qualified_name,
temp, temp,
code.error_goto(self.pos))) code.error_goto(self.pos)))
for entry in entries: for entry in entries:
code.putln( entry.type.map_with_specific_entries(self._import_cdef_func,
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s' % ( code,
temp, temp,
entry.name, entry.type.is_fused)
entry.cname,
entry.type.signature_string(),
code.error_goto(self.pos)))
code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp)) code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
def _import_cdef_func(self, entry, code, temp, from_fused):
if from_fused:
name = entry.cname
else:
name = entry.name
code.putln(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s' % (
temp,
name,
entry.cname,
entry.type.signature_string(),
code.error_goto(self.pos)))
def generate_type_init_code(self, env, code): def generate_type_init_code(self, env, code):
# Generate type import code for extern extension types # Generate type import code for extern extension types
# and type ready code for non-extern ones. # and type ready code for non-extern ones.
......
...@@ -22,6 +22,7 @@ import TypeSlots ...@@ -22,6 +22,7 @@ import TypeSlots
from PyrexTypes import py_object_type, error_type, CFuncType from PyrexTypes import py_object_type, error_type, CFuncType
from Symtab import ModuleScope, LocalScope, ClosureScope, \ from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CClassScope, CppClassScope StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
from Cython.Compiler import Symtab
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
from Code import UtilityCode, ClosureTempAllocator from Code import UtilityCode, ClosureTempAllocator
from StringEncoding import EncodedString, escape_byte_string, split_string_literal from StringEncoding import EncodedString, escape_byte_string, split_string_literal
...@@ -940,7 +941,14 @@ class FusedTypeNode(CBaseTypeNode): ...@@ -940,7 +941,14 @@ class FusedTypeNode(CBaseTypeNode):
if len(self.types) == 1: if len(self.types) == 1:
return self.types[0] return self.types[0]
return PyrexTypes.FusedType(self.types) types = []
for type in self.types:
if type.is_fused:
types.extend(type.types)
else:
types.append(type)
return PyrexTypes.FusedType(types)
class CVarDefNode(StatNode): class CVarDefNode(StatNode):
...@@ -2001,7 +2009,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2001,7 +2009,7 @@ class FusedCFuncDefNode(StatListNode):
self.nodes = self.stats = [] self.nodes = self.stats = []
self.node = node self.node = node
self.copy_cdefs(node.type.get_fused_types(), env) self.copy_cdefs(env)
# Perform some sanity checks. If anything fails, it's a bug # Perform some sanity checks. If anything fails, it's a bug
for n in self.nodes: for n in self.nodes:
...@@ -2014,12 +2022,12 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2014,12 +2022,12 @@ class FusedCFuncDefNode(StatListNode):
node.entry.fused_cfunction = self node.entry.fused_cfunction = self
def copy_cdefs(self, fused_types, env): def copy_cdefs(self, env):
""" """
Gives a list of fused types and the parent environment, make copies Gives a list of fused types and the parent environment, make copies
of the original cdef function. of the original cdef function.
""" """
permutations = self.get_all_specific_permutations(fused_types) permutations = self.node.type.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
...@@ -2027,6 +2035,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2027,6 +2035,7 @@ class FusedCFuncDefNode(StatListNode):
newtype = copied_node.type.specialize(fused_to_specific) newtype = copied_node.type.specialize(fused_to_specific)
copied_node.type = newtype copied_node.type = newtype
copied_node.entry.type = newtype copied_node.entry.type = newtype
newtype.entry = copied_node.entry
copied_node.return_type = newtype.return_type copied_node.return_type = newtype.return_type
copied_node.create_local_scope(env) copied_node.create_local_scope(env)
...@@ -2041,33 +2050,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2041,33 +2050,9 @@ class FusedCFuncDefNode(StatListNode):
for arg in copied_node.cfunc_declarator.args: for arg in copied_node.cfunc_declarator.args:
arg.type = arg.type.specialize(fused_to_specific) arg.type = arg.type.specialize(fused_to_specific)
cname = '%s%s%s' % (Naming.fused_func_prefix, cname = self.node.type.get_specific_cname(cname)
cname,
self.node.entry.func_cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname copied_node.entry.func_cname = copied_node.entry.cname = cname
def get_all_specific_permutations(self, fused_types):
"""
Permute all the types. For every specific instance of a fused type, we
want all other specific instances of all other fused types.
It returns an iterable of two-tuples of the cname that should prefix
the cname of the function, and a dict mapping any fused types to their
respective specific types.
"""
fused_type = fused_types[0]
for specific_type in fused_type.types:
cname = str(specific_type)
result_fused_to_specific = { fused_type: specific_type }
if len(fused_types) > 1:
it = self.get_all_specific_permutations(fused_types[1:])
for next_cname, fused_to_specific in it:
d = dict(fused_to_specific, **result_fused_to_specific)
yield '%s_%s' % (cname, next_cname), d
else:
yield cname, result_fused_to_specific
class PyArgDeclNode(Node): class PyArgDeclNode(Node):
# Argument which must be a Python object (used # Argument which must be a Python object (used
......
...@@ -664,7 +664,7 @@ class CType(PyrexType): ...@@ -664,7 +664,7 @@ class CType(PyrexType):
return 0 return 0
class FusedType(CType): class FusedType(PyrexType):
""" """
Represents a Fused Type. All it needs to do is keep track of the types Represents a Fused Type. All it needs to do is keep track of the types
it aggregates, as it will be replaced with its specific version wherever it aggregates, as it will be replaced with its specific version wherever
...@@ -2005,6 +2005,87 @@ class CFuncType(CType): ...@@ -2005,6 +2005,87 @@ class CFuncType(CType):
def opt_arg_cname(self, arg_name): def opt_arg_cname(self, arg_name):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
# Methods that deal with Fused Types
# All but map_with_specific_entries should be called only on functions
# with fused types (and not on their corresponding specific versions).
def get_all_specific_permutations(self, fused_types=None):
"""
Permute all the types. For every specific instance of a fused type, we
want all other specific instances of all other fused types.
It returns an iterable of two-tuples of the cname that should prefix
the cname of the function, and a dict mapping any fused types to their
respective specific types.
"""
assert self.is_fused
if fused_types is None:
fused_types = self.get_fused_types()
fused_type = fused_types[0]
for specific_type in fused_type.types:
cname = str(specific_type)
result_fused_to_specific = { fused_type: specific_type }
if len(fused_types) > 1:
it = self.get_all_specific_permutations(fused_types[1:])
for next_cname, fused_to_specific in it:
d = dict(fused_to_specific, **result_fused_to_specific)
yield '%s_%s' % (cname, next_cname), d
else:
yield cname, result_fused_to_specific
def get_all_specific_function_types(self):
"""
Get all the specific function types of this one.
"""
assert self.is_fused
permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific)
new_entry = copy.deepcopy(self.entry)
new_entry.cname = self.get_specific_cname(cname)
new_entry.type = new_func_type
new_func_type.entry = new_entry
yield new_func_type
def get_specific_cname(self, fused_cname):
"""
Given the cname for a permutation of fused types, return the cname
for the corresponding function with specific types.
The fused_cname is usually '_'.join(str(t) for t in specific_types)
"""
assert self.is_fused
return '%s%s%s' % (Naming.fused_func_prefix,
fused_cname,
self.entry.func_cname)
def map_with_specific_entries(self, func, *args, **kwargs):
"""
Call func for every specific function instance. If this is not a
signature with fused types, call it with the entry for this cdef
function.
"""
entry = self.entry
if entry.fused_cfunction:
# cdef with fused types defined in this file
for cfunction in entry.fused_cfunction.nodes:
func(cfunction.entry, *args, **kwargs)
elif entry.type.is_fused:
# cdef with fused types defined in another file, create their
# signatures
for func_type in self.get_all_specific_function_types():
func(func_type.entry, *args, **kwargs)
else:
# a normal cdef
return func(entry, *args, **kwargs)
class CFuncTypeArg(BaseType): class CFuncTypeArg(BaseType):
# name string # name string
...@@ -2641,7 +2722,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2641,7 +2722,7 @@ def best_match(args, functions, pos=None, env=None):
bad_types = [] bad_types = []
needed_coercions = {} needed_coercions = {}
for func, func_type in candidates: for func, func_type in candidates:
score = [0,0,0] score = [0,0,0,0]
for i in range(min(len(args), len(func_type.args))): for i in range(min(len(args), len(func_type.args))):
src_type = args[i].type src_type = args[i].type
dst_type = func_type.args[i].type dst_type = func_type.args[i].type
...@@ -2669,6 +2750,9 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2669,6 +2750,9 @@ def best_match(args, functions, pos=None, env=None):
if src_type == dst_type or dst_type.same_as(src_type): if src_type == dst_type or dst_type.same_as(src_type):
pass # score 0 pass # score 0
elif is_promotion(src_type, dst_type): elif is_promotion(src_type, dst_type):
score[3] += 1
elif ((src_type.is_int and dst_type.is_int) or
(src_type.is_float and dst_type.is_float)):
score[2] += 1 score[2] += 1
elif not src_type.is_pyobject: elif not src_type.is_pyobject:
score[1] += 1 score[1] += 1
......
...@@ -629,6 +629,7 @@ class Scope(object): ...@@ -629,6 +629,7 @@ class Scope(object):
if modifiers: if modifiers:
entry.func_modifiers = modifiers entry.func_modifiers = modifiers
entry.utility_code = utility_code entry.utility_code = utility_code
type.entry = entry
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
......
PYTHON setup.py build_ext --inplace
PYTHON -c "import b"
######## setup.py ########
from Cython.Build import cythonize
from distutils.core import setup
setup(
ext_modules = cythonize("*.pyx"),
)
######## a.pxd ########
cimport cython
cdef extern from "header.h":
ctypedef int extern_int
ctypedef long extern_long
cdef struct mystruct_t:
extern_int a
ctypedef union myunion_t:
extern_long a
cdef public class MyExt [ type MyExtType, object MyExtObject ]:
cdef unsigned char a
ctypedef char *string_t
ctypedef cython.fused_type(int, float) simple_t
ctypedef cython.fused_type(simple_t, string_t) less_simple_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t
ctypedef cython.fused_type(str, unicode, bytes) builtin_t
cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
######## header.h ########
typedef int extern_int;
typedef long extern_long;
######## a.pyx ########
cimport cython
cdef object_t add_simple(object_t obj, simple_t simple):
obj.a = <int> (obj.a + simple)
return obj
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
return obj.a + simple
######## b.pyx ########
from a cimport *
cdef mystruct_t mystruct
cdef myunion_t myunion
cdef MyExt myext = MyExt()
mystruct.a = 5
myunion.a = 5
myext.a = 5
assert add_simple(mystruct, 5).a == 10
assert add_simple(myunion, 5.0).a == 10.0
assert add_to_simple(mystruct, 5.0) == 10.0
assert add_to_simple(myunion, b"spamhameggs") == b"ameggs"
assert add_to_simple(myext, 5) == 10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment