Commit e7426088 authored by Mark Florisson's avatar Mark Florisson

Support sharing of cdef functions with fused types

parent ae59c5a0
......@@ -2999,9 +2999,9 @@ class SimpleCallNode(CallNode):
overloaded_entry = None
if overloaded_entry:
if overloaded_entry.fused_cfunction:
specific_cdef_funcs = overloaded_entry.fused_cfunction.nodes
alternatives = [n.entry for n in specific_cdef_funcs]
if self.function.type.is_fused:
alternatives = []
self.function.type.map_with_specific_entries(alternatives.append)
else:
alternatives = overloaded_entry.all_alternatives()
......
......@@ -156,19 +156,14 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f.close()
def generate_public_declaration(self, entry, h_code, i_code):
if entry.fused_cfunction:
for cfunction in entry.fused_cfunction.nodes:
self._generate_public_declaration(cfunction.entry,
cfunction.entry.cname, h_code, i_code)
else:
self._generate_public_declaration(entry, entry.cname,
entry.type.map_with_specific_entries(self._generate_public_declaration,
h_code, i_code)
def _generate_public_declaration(self, entry, cname, h_code, i_code):
def _generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % (
Naming.extern_c_macro,
entry.type.declaration_code(
cname, dll_linkage = "DL_IMPORT")))
entry.cname, dll_linkage = "DL_IMPORT")))
if i_code:
i_code.putln("cdef extern %s" %
entry.type.declaration_code(cname, pyrex = 1))
......@@ -996,15 +991,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cfunction_predeclarations(self, env, code, definition):
for entry in env.cfunc_entries:
if entry.fused_cfunction:
for node in entry.fused_cfunction.nodes:
self._generate_cfunction_predeclaration(
code, definition, node.entry)
else:
self._generate_cfunction_predeclaration(code, definition, entry)
entry.type.map_with_specific_entries(
self._generate_cfunction_predeclaration, code, definition)
def _generate_cfunction_predeclaration(self, code, definition, entry):
def _generate_cfunction_predeclaration(self, entry, code, definition):
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
or entry.defined_in_pxd or entry.visibility == 'extern')):
if entry.visibility == 'public':
......@@ -2056,15 +2046,27 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_function_export_code(self, env, code):
# Generate code to create PyCFunction wrappers for exported C functions.
func = self._generate_c_function_export_code
for entry in env.cfunc_entries:
from_fused = entry.type.is_fused
if entry.api or entry.defined_in_pxd:
entry.type.map_with_specific_entries(func, env, code, from_fused)
def _generate_c_function_export_code(self, entry, env, code, from_fused):
env.use_utility_code(function_export_utility_code)
signature = entry.type.signature_string()
code.putln('if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s' % (
entry.name,
entry.cname,
signature,
code.error_goto(self.pos)))
s = 'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
if from_fused:
# Specific version of a fused function. Fused functions can never
# be declared public or api, but they may need to be exported when
# declared in a .pxd. We need to give them a unique name in that
# case
name = entry.cname
else:
name = entry.name
code.putln(s % (name, entry.cname, signature, code.error_goto(self.pos)))
def generate_type_import_code_for_module(self, module, env, code):
# Generate type import code for all exported extension types in
......@@ -2090,15 +2092,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
module.qualified_name,
temp,
code.error_goto(self.pos)))
for entry in entries:
entry.type.map_with_specific_entries(self._import_cdef_func,
code,
temp,
entry.type.is_fused)
code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
def _import_cdef_func(self, entry, code, temp, from_fused):
if from_fused:
name = entry.cname
else:
name = entry.name
code.putln(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s' % (
temp,
entry.name,
name,
entry.cname,
entry.type.signature_string(),
code.error_goto(self.pos)))
code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
def generate_type_init_code(self, env, code):
# Generate type import code for extern extension types
......
......@@ -22,6 +22,7 @@ import TypeSlots
from PyrexTypes import py_object_type, error_type, CFuncType
from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
from Cython.Compiler import Symtab
from Cython.Utils import open_new_file, replace_suffix
from Code import UtilityCode, ClosureTempAllocator
from StringEncoding import EncodedString, escape_byte_string, split_string_literal
......@@ -940,7 +941,14 @@ class FusedTypeNode(CBaseTypeNode):
if len(self.types) == 1:
return self.types[0]
return PyrexTypes.FusedType(self.types)
types = []
for type in self.types:
if type.is_fused:
types.extend(type.types)
else:
types.append(type)
return PyrexTypes.FusedType(types)
class CVarDefNode(StatNode):
......@@ -2001,7 +2009,7 @@ class FusedCFuncDefNode(StatListNode):
self.nodes = self.stats = []
self.node = node
self.copy_cdefs(node.type.get_fused_types(), env)
self.copy_cdefs(env)
# Perform some sanity checks. If anything fails, it's a bug
for n in self.nodes:
......@@ -2014,12 +2022,12 @@ class FusedCFuncDefNode(StatListNode):
node.entry.fused_cfunction = self
def copy_cdefs(self, fused_types, env):
def copy_cdefs(self, env):
"""
Gives a list of fused types and the parent environment, make copies
of the original cdef function.
"""
permutations = self.get_all_specific_permutations(fused_types)
permutations = self.node.type.get_all_specific_permutations()
for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node)
......@@ -2027,6 +2035,7 @@ class FusedCFuncDefNode(StatListNode):
newtype = copied_node.type.specialize(fused_to_specific)
copied_node.type = newtype
copied_node.entry.type = newtype
newtype.entry = copied_node.entry
copied_node.return_type = newtype.return_type
copied_node.create_local_scope(env)
......@@ -2041,33 +2050,9 @@ class FusedCFuncDefNode(StatListNode):
for arg in copied_node.cfunc_declarator.args:
arg.type = arg.type.specialize(fused_to_specific)
cname = '%s%s%s' % (Naming.fused_func_prefix,
cname,
self.node.entry.func_cname)
cname = self.node.type.get_specific_cname(cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname
def get_all_specific_permutations(self, fused_types):
"""
Permute all the types. For every specific instance of a fused type, we
want all other specific instances of all other fused types.
It returns an iterable of two-tuples of the cname that should prefix
the cname of the function, and a dict mapping any fused types to their
respective specific types.
"""
fused_type = fused_types[0]
for specific_type in fused_type.types:
cname = str(specific_type)
result_fused_to_specific = { fused_type: specific_type }
if len(fused_types) > 1:
it = self.get_all_specific_permutations(fused_types[1:])
for next_cname, fused_to_specific in it:
d = dict(fused_to_specific, **result_fused_to_specific)
yield '%s_%s' % (cname, next_cname), d
else:
yield cname, result_fused_to_specific
class PyArgDeclNode(Node):
# Argument which must be a Python object (used
......
......@@ -664,7 +664,7 @@ class CType(PyrexType):
return 0
class FusedType(CType):
class FusedType(PyrexType):
"""
Represents a Fused Type. All it needs to do is keep track of the types
it aggregates, as it will be replaced with its specific version wherever
......@@ -2005,6 +2005,87 @@ class CFuncType(CType):
def opt_arg_cname(self, arg_name):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
# Methods that deal with Fused Types
# All but map_with_specific_entries should be called only on functions
# with fused types (and not on their corresponding specific versions).
def get_all_specific_permutations(self, fused_types=None):
"""
Permute all the types. For every specific instance of a fused type, we
want all other specific instances of all other fused types.
It returns an iterable of two-tuples of the cname that should prefix
the cname of the function, and a dict mapping any fused types to their
respective specific types.
"""
assert self.is_fused
if fused_types is None:
fused_types = self.get_fused_types()
fused_type = fused_types[0]
for specific_type in fused_type.types:
cname = str(specific_type)
result_fused_to_specific = { fused_type: specific_type }
if len(fused_types) > 1:
it = self.get_all_specific_permutations(fused_types[1:])
for next_cname, fused_to_specific in it:
d = dict(fused_to_specific, **result_fused_to_specific)
yield '%s_%s' % (cname, next_cname), d
else:
yield cname, result_fused_to_specific
def get_all_specific_function_types(self):
"""
Get all the specific function types of this one.
"""
assert self.is_fused
permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific)
new_entry = copy.deepcopy(self.entry)
new_entry.cname = self.get_specific_cname(cname)
new_entry.type = new_func_type
new_func_type.entry = new_entry
yield new_func_type
def get_specific_cname(self, fused_cname):
"""
Given the cname for a permutation of fused types, return the cname
for the corresponding function with specific types.
The fused_cname is usually '_'.join(str(t) for t in specific_types)
"""
assert self.is_fused
return '%s%s%s' % (Naming.fused_func_prefix,
fused_cname,
self.entry.func_cname)
def map_with_specific_entries(self, func, *args, **kwargs):
"""
Call func for every specific function instance. If this is not a
signature with fused types, call it with the entry for this cdef
function.
"""
entry = self.entry
if entry.fused_cfunction:
# cdef with fused types defined in this file
for cfunction in entry.fused_cfunction.nodes:
func(cfunction.entry, *args, **kwargs)
elif entry.type.is_fused:
# cdef with fused types defined in another file, create their
# signatures
for func_type in self.get_all_specific_function_types():
func(func_type.entry, *args, **kwargs)
else:
# a normal cdef
return func(entry, *args, **kwargs)
class CFuncTypeArg(BaseType):
# name string
......@@ -2641,7 +2722,7 @@ def best_match(args, functions, pos=None, env=None):
bad_types = []
needed_coercions = {}
for func, func_type in candidates:
score = [0,0,0]
score = [0,0,0,0]
for i in range(min(len(args), len(func_type.args))):
src_type = args[i].type
dst_type = func_type.args[i].type
......@@ -2669,6 +2750,9 @@ def best_match(args, functions, pos=None, env=None):
if src_type == dst_type or dst_type.same_as(src_type):
pass # score 0
elif is_promotion(src_type, dst_type):
score[3] += 1
elif ((src_type.is_int and dst_type.is_int) or
(src_type.is_float and dst_type.is_float)):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
......
......@@ -629,6 +629,7 @@ class Scope(object):
if modifiers:
entry.func_modifiers = modifiers
entry.utility_code = utility_code
type.entry = entry
return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
......
PYTHON setup.py build_ext --inplace
PYTHON -c "import b"
######## setup.py ########
from Cython.Build import cythonize
from distutils.core import setup
setup(
ext_modules = cythonize("*.pyx"),
)
######## a.pxd ########
cimport cython
cdef extern from "header.h":
ctypedef int extern_int
ctypedef long extern_long
cdef struct mystruct_t:
extern_int a
ctypedef union myunion_t:
extern_long a
cdef public class MyExt [ type MyExtType, object MyExtObject ]:
cdef unsigned char a
ctypedef char *string_t
ctypedef cython.fused_type(int, float) simple_t
ctypedef cython.fused_type(simple_t, string_t) less_simple_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t
ctypedef cython.fused_type(str, unicode, bytes) builtin_t
cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
######## header.h ########
typedef int extern_int;
typedef long extern_long;
######## a.pyx ########
cimport cython
cdef object_t add_simple(object_t obj, simple_t simple):
obj.a = <int> (obj.a + simple)
return obj
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
return obj.a + simple
######## b.pyx ########
from a cimport *
cdef mystruct_t mystruct
cdef myunion_t myunion
cdef MyExt myext = MyExt()
mystruct.a = 5
myunion.a = 5
myext.a = 5
assert add_simple(mystruct, 5).a == 10
assert add_simple(myunion, 5.0).a == 10.0
assert add_to_simple(mystruct, 5.0) == 10.0
assert add_to_simple(myunion, b"spamhameggs") == b"ameggs"
assert add_to_simple(myext, 5) == 10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment