Commit edf04816 authored by Mark Florisson's avatar Mark Florisson

Support fused def functions + lambda + better runtime dispatch

parent d80c2c4c
...@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform): ...@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform):
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
if not node.doc: if not node.doc or node.fused_py_func:
return node return node
if not self.cdef_docstrings: if not self.cdef_docstrings:
if isinstance(node, CFuncDefNode) and not node.py_func: if isinstance(node, CFuncDefNode) and not node.py_func:
......
...@@ -8969,12 +8969,17 @@ static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject ...@@ -8969,12 +8969,17 @@ static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject
static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure); static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure);
static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure); static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure);
static PyObject *%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure);
static PyGetSetDef %(binding_cfunc)s_getsets[] = { static PyGetSetDef %(binding_cfunc)s_getsets[] = {
{(char *)"__name__", {(char *)"__name__",
(getter) %(binding_cfunc)s_get__name__, (getter) %(binding_cfunc)s_get__name__,
(setter) %(binding_cfunc)s_set__name__, (setter) %(binding_cfunc)s_set__name__,
NULL}, NULL},
{(char *)"__doc__",
(getter) %(binding_cfunc)s_get__doc__,
NULL,
NULL},
{NULL}, {NULL},
}; };
...@@ -9139,6 +9144,12 @@ static int ...@@ -9139,6 +9144,12 @@ static int
return PyDict_SetItemString(func->__dict__, "__name__", value); return PyDict_SetItemString(func->__dict__, "__name__", value);
} }
static PyObject *
%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure)
{
return PyUnicode_FromString(func->func.m_ml->ml_doc);
}
static PyObject * static PyObject *
%(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type) %(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type)
{ {
...@@ -9290,11 +9301,13 @@ static PyObject * ...@@ -9290,11 +9301,13 @@ static PyObject *
binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth; binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth;
func = new_func = meth(binding_func->__signatures__, args); func = new_func = meth(binding_func->__signatures__, args);
*/ */
PyObject *tup = PyTuple_Pack(2, binding_func->__signatures__, args); PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw);
if (!tup) if (!tup)
goto __pyx_err; goto __pyx_err;
func = new_func = PyCFunction_Call(func, tup, NULL); func = new_func = PyCFunction_Call(func, tup, NULL);
Py_DECREF(tup);
if (!new_func) if (!new_func)
goto __pyx_err; goto __pyx_err;
......
...@@ -936,6 +936,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -936,6 +936,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# Generate struct declaration for an extension type's vtable. # Generate struct declaration for an extension type's vtable.
type = entry.type type = entry.type
scope = type.scope scope = type.scope
self.specialize_fused_types(scope)
if type.vtabstruct_cname: if type.vtabstruct_cname:
code.putln("") code.putln("")
code.putln( code.putln(
...@@ -1942,7 +1945,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1942,7 +1945,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/") code.putln("/*--- Function import code ---*/")
for module in imported_modules: for module in imported_modules:
self.specialize_fused_types(module, env) self.specialize_fused_types(module)
self.generate_c_function_import_code_for_module(module, env, code) self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
...@@ -2160,7 +2163,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2160,7 +2163,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd: if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code) self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env, impl_env): def specialize_fused_types(self, pxd_env):
""" """
If fused c(p)def functions are defined in an imported pxd, but not If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and used in this implementation file, we still have fused entries and
......
...@@ -139,7 +139,6 @@ class Node(object): ...@@ -139,7 +139,6 @@ class Node(object):
cf_state = None cf_state = None
def __init__(self, pos, **kw): def __init__(self, pos, **kw):
self.pos = pos self.pos = pos
self.__dict__.update(kw) self.__dict__.update(kw)
...@@ -2039,7 +2038,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2039,7 +2038,8 @@ class FusedCFuncDefNode(StatListNode):
node FuncDefNode the original function node FuncDefNode the original function
nodes [FuncDefNode] list of copies of node with different specific types nodes [FuncDefNode] list of copies of node with different specific types
py_func DefNode the original python function (in case of a cpdef) py_func DefNode the fused python function subscriptable from
Python space
""" """
def __init__(self, node, env): def __init__(self, node, env):
...@@ -2048,42 +2048,78 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2048,42 +2048,78 @@ class FusedCFuncDefNode(StatListNode):
self.nodes = [] self.nodes = []
self.node = node self.node = node
self.copy_cdefs(env) is_def = isinstance(self.node, DefNode)
if is_def:
self.copy_def(env)
else:
self.copy_cdef(env)
# Perform some sanity checks. If anything fails, it's a bug # Perform some sanity checks. If anything fails, it's a bug
for n in self.nodes: for n in self.nodes:
assert not n.type.is_fused assert not n.entry.type.is_fused
assert not n.local_scope.return_type.is_fused assert not n.local_scope.return_type.is_fused
if node.return_type.is_fused: if node.return_type.is_fused:
assert not n.return_type.is_fused assert not n.return_type.is_fused
if n.cfunc_declarator.optional_arg_count: if not is_def and n.cfunc_declarator.optional_arg_count:
assert n.type.op_arg_struct assert n.type.op_arg_struct
assert n.type.entry
assert node.type.is_fused
node.entry.fused_cfunction = self node.entry.fused_cfunction = self
if self.py_func: if self.py_func:
self.py_func.entry.fused_cfunction = self self.py_func.entry.fused_cfunction = self
for node in self.nodes: for node in self.nodes:
node.py_func.fused_py_func = self.py_func if is_def:
node.entry.as_variable = self.py_func.entry node.fused_py_func = self.py_func
else:
node.py_func.fused_py_func = self.py_func
node.entry.as_variable = self.py_func.entry
# Copy the nodes as AnalyseDeclarationsTransform will append # Copy the nodes as AnalyseDeclarationsTransform will prepend
# self.py_func to self.stats, as we only want specialized # self.py_func to self.stats, as we only want specialized
# CFuncDefNodes in self.nodes # CFuncDefNodes in self.nodes
self.stats = self.nodes[:] self.stats = self.nodes[:]
def copy_cdefs(self, env): def copy_def(self, env):
""" """
Gives a list of fused types and the parent environment, make copies Create a copy of the original def or lambda function for specialized
of the original cdef function. versions.
""" """
from Cython.Compiler import ParseTreeTransforms fused_types = [arg.type for arg in self.node.args if arg.type.is_fused]
permutations = PyrexTypes.get_all_specific_permutations(fused_types)
if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry)
for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node)
for arg in copied_node.args:
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
copied_node.return_type = self.node.return_type.specialize(
fused_to_specific)
copied_node.analyse_declarations(env)
self.create_new_local_scope(copied_node, env, fused_to_specific)
self.specialize_copied_def(copied_node, cname, self.node.entry,
fused_to_specific, fused_types)
PyrexTypes.specialize_entry(copied_node.entry, cname)
copied_node.entry.used = True
env.entries[copied_node.entry.name] = copied_node.entry
if not self.replace_fused_typechecks(copied_node):
break
self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)
def copy_cdef(self, env):
"""
Create a copy of the original c(p)def function for all specialized
versions.
"""
permutations = self.node.type.get_all_specific_permutations() permutations = self.node.type.get_all_specific_permutations()
# print 'Node %s has %d specializations:' % (self.node.entry.name, # print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations)) # len(permutations))
...@@ -2120,13 +2156,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2120,13 +2156,7 @@ class FusedCFuncDefNode(StatListNode):
type, env, fused_cname=cname) type, env, fused_cname=cname)
copied_node.return_type = type.return_type copied_node.return_type = type.return_type
copied_node.create_local_scope(env) self.create_new_local_scope(copied_node, env, fused_to_specific)
copied_node.local_scope.fused_to_specific = fused_to_specific
# This is copied from the original function, set it to false to
# stop recursion
copied_node.has_fused_arguments = False
self.nodes.append(copied_node)
# Make the argument types in the CFuncDeclarator specific # Make the argument types in the CFuncDeclarator specific
for arg in copied_node.cfunc_declarator.args: for arg in copied_node.cfunc_declarator.args:
...@@ -2135,45 +2165,83 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2135,45 +2165,83 @@ class FusedCFuncDefNode(StatListNode):
type.specialize_entry(entry, cname) type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry) env.cfunc_entries.append(entry)
# If a cpdef, declare all specialized cpdefs # If a cpdef, declare all specialized cpdefs (this
# also calls analyse_declarations)
copied_node.declare_cpdef_wrapper(env) copied_node.declare_cpdef_wrapper(env)
if copied_node.py_func: if copied_node.py_func:
env.pyfunc_entries.remove(copied_node.py_func.entry) env.pyfunc_entries.remove(copied_node.py_func.entry)
# copied_node.py_func.self_in_stararg = True
type_strings = [
fused_to_specific[fused_type].typeof_name()
for fused_type in fused_types]
if len(type_strings) == 1:
sigstring = type_strings[0]
else:
sigstring = ', '.join(type_strings)
copied_node.py_func.specialized_signature_string = sigstring self.specialize_copied_def(
copied_node.py_func, cname, self.node.entry.as_variable,
fused_to_specific, fused_types)
e = copied_node.py_func.entry if not self.replace_fused_typechecks(copied_node):
e.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, e.pymethdef_cname)
num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
transform(copied_node)
if Errors.num_errors > num_errors:
break break
if orig_py_func: if orig_py_func:
self.py_func = self.make_fused_cpdef(orig_py_func, env) self.py_func = self.make_fused_cpdef(orig_py_func, env,
is_def=False)
else: else:
self.py_func = orig_py_func self.py_func = orig_py_func
def create_new_local_scope(self, node, env, f2s):
"""
Create a new local scope for the copied node and append it to
self.nodes. A new local scope is needed because the arguments with the
fused types are aready in the local scope, and we need the specialized
entries created after analyse_declarations on each specialized version
of the (CFunc)DefNode.
f2s is a dict mapping each fused type to its specialized version
"""
node.create_local_scope(env)
node.local_scope.fused_to_specific = f2s
def make_fused_cpdef(self, orig_py_func, env): # This is copied from the original function, set it to false to
# stop recursion
node.has_fused_arguments = False
self.nodes.append(node)
def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
"""Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry"""
type_strings = [f2s[fused_type].typeof_name()
for fused_type in fused_types]
node.specialized_signature_string = ', '.join(type_strings)
node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, node.entry.pymethdef_cname)
node.entry.doc = py_entry.doc
node.entry.doc_cname = py_entry.doc_cname
def replace_fused_typechecks(self, copied_node):
"""
Branch-prune fused type checks like
if fused_t is int:
...
Returns whether an error was issued and whether we should stop in
in order to prevent a flood of errors.
"""
from Cython.Compiler import ParseTreeTransforms
num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
transform(copied_node)
if Errors.num_errors > num_errors:
return False
return True
def make_fused_cpdef(self, orig_py_func, env, is_def):
""" """
This creates the function that is indexable from Python and does This creates the function that is indexable from Python and does
runtime dispatch based on the argument types. runtime dispatch based on the argument types. The function gets the
arg tuple and kwargs dict (or None) as arugments from the Binding
Fused Function's tp_call.
""" """
from Cython.Compiler import TreeFragment from Cython.Compiler import TreeFragment
from Cython.Compiler import ParseTreeTransforms from Cython.Compiler import ParseTreeTransforms
...@@ -2184,17 +2252,28 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2184,17 +2252,28 @@ class FusedCFuncDefNode(StatListNode):
# list of statements that do the instance checks # list of statements that do the instance checks
body_stmts = [] body_stmts = []
for i, arg_type in enumerate(self.node.type.args): args = self.node.args
arg_type = arg_type.type for i, arg in enumerate(args):
arg_type = arg.type
if arg_type.is_fused and arg_type not in seen_fused_types: if arg_type.is_fused and arg_type not in seen_fused_types:
seen_fused_types.add(arg_type) seen_fused_types.add(arg_type)
specialized_types = PyrexTypes.get_specific_types(arg_type) specialized_types = PyrexTypes.get_specialized_types(arg_type)
# Prefer long over int, etc # Prefer long over int, etc
specialized_types.sort() # specialized_types.sort()
seen_py_type_names = cython.set() seen_py_type_names = cython.set()
first_check = True first_check = True
body_stmts.append(u"""
if nargs >= %(nextidx)d or '%(argname)s' in kwargs:
if nargs >= %(nextidx)d:
arg = args[%(idx)d]
else:
arg = kwargs['%(argname)s']
""" % {'idx': i, 'nextidx': i + 1, 'argname': arg.name})
all_numeric = True
for specialized_type in specialized_types: for specialized_type in specialized_types:
py_type_name = specialized_type.py_type_name() py_type_name = specialized_type.py_type_name()
...@@ -2203,6 +2282,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2203,6 +2282,8 @@ class FusedCFuncDefNode(StatListNode):
seen_py_type_names.add(py_type_name) seen_py_type_names.add(py_type_name)
all_numeric = all_numeric and specialized_type.is_numeric
if first_check: if first_check:
if_ = 'if' if_ = 'if'
first_check = False first_check = False
...@@ -2216,24 +2297,43 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2216,24 +2297,43 @@ class FusedCFuncDefNode(StatListNode):
if py_type_name in ('long', 'unicode', 'bytes'): if py_type_name in ('long', 'unicode', 'bytes'):
instance_check_py_type_name += '_' instance_check_py_type_name += '_'
tup = (if_, i, instance_check_py_type_name, tup = (if_, instance_check_py_type_name,
len(seen_fused_types) - 1, len(seen_fused_types) - 1,
specialized_type.typeof_name()) specialized_type.typeof_name())
body_stmts.append( body_stmts.append(
" %s isinstance(args[%d], %s): " " %s isinstance(arg, %s): "
"dest_sig[%d] = '%s'" % tup) "dest_sig[%d] = '%s'" % tup)
if arg.default and all_numeric:
arg.default.analyse_types(env)
ts = specialized_types
if arg.default.type.is_complex:
typelist = [t for t in ts if t.is_complex]
elif arg.default.type.is_float:
typelist = [t for t in ts if t.is_float]
else:
typelist = [t for t in ts if t.is_int]
if typelist:
body_stmts.append(u"""\
else:
dest_sig[%d] = '%s'
""" % (i, typelist[0].typeof_name()))
fmt_dict = { fmt_dict = {
'body': '\n'.join(body_stmts), 'body': '\n'.join(body_stmts),
'nargs': len(self.node.type.args), 'nargs': len(args),
'name': orig_py_func.entry.name, 'name': orig_py_func.entry.name,
} }
fragment = TreeFragment.TreeFragment(u""" fragment_code = u"""
def __pyx_fused_cpdef(signatures, args): def __pyx_fused_cpdef(signatures, args, kwargs):
if len(args) < %(nargs)d: #if len(args) < %(nargs)d:
raise TypeError("Invalid number of arguments, expected %(nargs)d, " # raise TypeError("Invalid number of arguments, expected %(nargs)d, "
"got %%d" %% len(args)) # "got %%d" %% len(args))
cdef int nargs
nargs = len(args)
import sys import sys
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
...@@ -2245,7 +2345,10 @@ def __pyx_fused_cpdef(signatures, args): ...@@ -2245,7 +2345,10 @@ def __pyx_fused_cpdef(signatures, args):
unicode_ = unicode unicode_ = unicode
bytes_ = str bytes_ = str
dest_sig = [None] * len(args) dest_sig = [None] * %(nargs)d
if kwargs is None:
kwargs = {}
# instance check body # instance check body
%(body)s %(body)s
...@@ -2266,8 +2369,11 @@ def __pyx_fused_cpdef(signatures, args): ...@@ -2266,8 +2369,11 @@ def __pyx_fused_cpdef(signatures, args):
raise TypeError("Function call with ambiguous argument types") raise TypeError("Function call with ambiguous argument types")
else: else:
return signatures[candidates[0]] return signatures[candidates[0]]
""" % fmt_dict, level='module') """ % fmt_dict
# print fragment_code
fragment = TreeFragment.TreeFragment(fragment_code, level='module')
# analyse the declarations of our fragment ... # analyse the declarations of our fragment ...
py_func, = fragment.substitute(pos=self.node.pos).stats py_func, = fragment.substitute(pos=self.node.pos).stats
...@@ -2283,17 +2389,31 @@ def __pyx_fused_cpdef(signatures, args): ...@@ -2283,17 +2389,31 @@ def __pyx_fused_cpdef(signatures, args):
py_func.name = e.name = orig_e.name py_func.name = e.name = orig_e.name
e.cname, e.func_cname = orig_e.cname, orig_e.func_cname e.cname, e.func_cname = orig_e.cname, orig_e.func_cname
e.pymethdef_cname = orig_e.pymethdef_cname e.pymethdef_cname = orig_e.pymethdef_cname
e.doc, e.doc_cname = orig_e.doc, orig_e.doc_cname
# e.signature = TypeSlots.binaryfunc # e.signature = TypeSlots.binaryfunc
py_func.doc = orig_py_func.doc
# ... and the symbol table # ... and the symbol table
del env.entries['__pyx_fused_cpdef'] del env.entries['__pyx_fused_cpdef']
env.entries[e.name].as_variable = e if is_def:
env.entries[e.name] = e
else:
env.entries[e.name].as_variable = e
env.pyfunc_entries.append(e) env.pyfunc_entries.append(e)
py_func.specialized_cpdefs = [n.py_func for n in self.nodes] if is_def:
py_func.specialized_cpdefs = self.nodes[:]
else:
py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
return py_func return py_func
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
# Ensure the indexable fused function is generated first, so we can
# use its docstring
# self.stats.insert(0, self.stats.pop())
for stat in self.stats: for stat in self.stats:
# print stat.entry, stat.entry.used # print stat.entry, stat.entry.used
if stat.entry.used: if stat.entry.used:
...@@ -2482,6 +2602,7 @@ class DefNode(FuncDefNode): ...@@ -2482,6 +2602,7 @@ class DefNode(FuncDefNode):
self.declare_lambda_function(env) self.declare_lambda_function(env)
else: else:
self.declare_pyfunction(env) self.declare_pyfunction(env)
self.analyse_signature(env) self.analyse_signature(env)
self.return_type = self.entry.signature.return_type() self.return_type = self.entry.signature.return_type()
self.create_local_scope(env) self.create_local_scope(env)
...@@ -2498,6 +2619,10 @@ class DefNode(FuncDefNode): ...@@ -2498,6 +2619,10 @@ class DefNode(FuncDefNode):
arg.declarator.analyse(base_type, env) arg.declarator.analyse(base_type, env)
arg.name = name_declarator.name arg.name = name_declarator.name
arg.type = type arg.type = type
if type.is_fused:
self.has_fused_arguments = True
self.align_argument_type(env, arg) self.align_argument_type(env, arg)
if name_declarator and name_declarator.cname: if name_declarator and name_declarator.cname:
error(self.pos, error(self.pos,
...@@ -2707,6 +2832,10 @@ class DefNode(FuncDefNode): ...@@ -2707,6 +2832,10 @@ class DefNode(FuncDefNode):
def synthesize_assignment_node(self, env): def synthesize_assignment_node(self, env):
import ExprNodes import ExprNodes
if self.fused_py_func:
return
genv = env genv = env
while genv.is_py_class_scope or genv.is_c_class_scope: while genv.is_py_class_scope or genv.is_c_class_scope:
genv = genv.outer_scope genv = genv.outer_scope
...@@ -2766,22 +2895,29 @@ class DefNode(FuncDefNode): ...@@ -2766,22 +2895,29 @@ class DefNode(FuncDefNode):
# If we are the specialized version of the cpdef, we still # If we are the specialized version of the cpdef, we still
# want the prototype for the "fused cpdef", in case we're # want the prototype for the "fused cpdef", in case we're
# checking to see if our method was overridden in Python # checking to see if our method was overridden in Python
self.fused_py_func.generate_function_header(code, with_pymethdef, proto_only=True) self.fused_py_func.generate_function_header(
code, with_pymethdef, proto_only=True)
return return
if (Options.docstrings and self.entry.doc and if (Options.docstrings and self.entry.doc and
not self.fused_py_func and
not self.entry.scope.is_property_scope and not self.entry.scope.is_property_scope and
(not self.entry.is_special or self.entry.wrapperbase_cname)): (not self.entry.is_special or self.entry.wrapperbase_cname)):
# h_code = code.globalstate['h_code']
docstr = self.entry.doc docstr = self.entry.doc
if docstr.is_unicode: if docstr.is_unicode:
docstr = docstr.utf8encode() docstr = docstr.utf8encode()
code.putln( code.putln(
'static char %s[] = "%s";' % ( 'static char %s[] = "%s";' % (
self.entry.doc_cname, self.entry.doc_cname,
split_string_literal(escape_byte_string(docstr)))) split_string_literal(escape_byte_string(docstr))))
if self.entry.is_special: if self.entry.is_special:
code.putln( code.putln(
"struct wrapperbase %s;" % self.entry.wrapperbase_cname) "struct wrapperbase %s;" % self.entry.wrapperbase_cname)
if with_pymethdef or self.fused_py_func: if with_pymethdef or self.fused_py_func:
code.put( code.put(
"static PyMethodDef %s = " % "static PyMethodDef %s = " %
...@@ -2909,10 +3045,12 @@ class DefNode(FuncDefNode): ...@@ -2909,10 +3045,12 @@ class DefNode(FuncDefNode):
else: else:
func = arg.type.from_py_function func = arg.type.from_py_function
if func: if func:
code.putln("%s = %s(%s); %s" % ( rhs = "%s(%s)" % (func, item)
if arg.type.is_enum:
rhs = arg.type.cast_code(rhs)
code.putln("%s = %s; %s" % (
arg.entry.cname, arg.entry.cname,
func, rhs,
item,
code.error_goto_if(arg.type.error_condition(arg.entry.cname), arg.pos))) code.error_goto_if(arg.type.error_condition(arg.entry.cname), arg.pos)))
else: else:
error(arg.pos, "Cannot convert Python object argument to type '%s'" % arg.type) error(arg.pos, "Cannot convert Python object argument to type '%s'" % arg.type)
......
...@@ -1348,10 +1348,13 @@ if VALUE is not None: ...@@ -1348,10 +1348,13 @@ if VALUE is not None:
count += 1 count += 1
""") """)
fused_function = None
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope] self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used. # needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = [] self.seen_vars_stack = []
self.fused_error_funcs = cython.set()
return super(AnalyseDeclarationsTransform, self).__call__(root) return super(AnalyseDeclarationsTransform, self).__call__(root)
def visit_NameNode(self, node): def visit_NameNode(self, node):
...@@ -1399,9 +1402,12 @@ if VALUE is not None: ...@@ -1399,9 +1402,12 @@ if VALUE is not None:
analyse its children (which are in turn normal functions). If we're a analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function. normal function, just analyse the body of the function.
""" """
env = self.env_stack[-1]
self.seen_vars_stack.append(cython.set()) self.seen_vars_stack.append(cython.set())
lenv = node.local_scope lenv = node.local_scope
node.declare_arguments(lenv) node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items(): for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv) type = type_node.analyse_as_type(lenv)
...@@ -1411,10 +1417,27 @@ if VALUE is not None: ...@@ -1411,10 +1417,27 @@ if VALUE is not None:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
if node.has_fused_arguments: if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1]) if self.fused_function:
if self.fused_function not in self.fused_error_funcs:
error(node.pos, "Cannot nest fused functions")
self.fused_error_funcs.add(self.fused_function)
# env.declare_var(node.name, PyrexTypes.py_object_type, node.pos)
node = Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos, name=node.name),
rhs=ExprNodes.NoneNode(node.pos))
node.analyse_declarations(env)
return node
node = Nodes.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node) self.visitchildren(node)
self.fused_function = None
if node.py_func: if node.py_func:
node.stats.append(node.py_func) node.stats.insert(0, node.py_func)
else: else:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
...@@ -2082,6 +2105,10 @@ class CreateClosureClasses(CythonTransform): ...@@ -2082,6 +2105,10 @@ class CreateClosureClasses(CythonTransform):
target_module_scope.check_c_class(func_scope.scope_class) target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
if not isinstance(node.def_node, Nodes.DefNode):
# fused function, an error has been previously issued
return node
was_in_lambda = self.in_lambda was_in_lambda = self.in_lambda
self.in_lambda = True self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node) self.create_class_from_scope(node.def_node, self.module_scope, node)
...@@ -2408,7 +2435,7 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -2408,7 +2435,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
error(node.operand2.pos, error(node.operand2.pos,
"Can only use 'in' or 'not in' on a fused type") "Can only use 'in' or 'not in' on a fused type")
else: else:
types = PyrexTypes.get_specific_types(type2) types = PyrexTypes.get_specialized_types(type2)
for specific_type in types: for specific_type in types:
if type1.same_as(specific_type): if type1.same_as(specific_type):
......
...@@ -235,9 +235,16 @@ def public_decl(base_code, dll_linkage): ...@@ -235,9 +235,16 @@ def public_decl(base_code, dll_linkage):
return base_code return base_code
def create_typedef_type(name, base_type, cname, is_external=0): def create_typedef_type(name, base_type, cname, is_external=0):
if base_type.is_complex: is_fused = base_type.is_fused
if base_type.is_complex or is_fused:
if is_external: if is_external:
raise ValueError("Complex external typedefs not supported") if is_fused:
msg = "Fused"
else:
msg = "Complex"
raise ValueError("%s external typedefs not supported" % msg)
return base_type return base_type
else: else:
return CTypedefType(name, base_type, cname, is_external) return CTypedefType(name, base_type, cname, is_external)
...@@ -2123,6 +2130,7 @@ class CFuncType(CType): ...@@ -2123,6 +2130,7 @@ class CFuncType(CType):
result = [] result = []
permutations = self.get_all_specific_permutations() permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific) new_func_type = self.entry.type.specialize(fused_to_specific)
...@@ -2150,20 +2158,25 @@ class CFuncType(CType): ...@@ -2150,20 +2158,25 @@ class CFuncType(CType):
def specialize_entry(self, entry, cname): def specialize_entry(self, entry, cname):
assert not self.is_fused assert not self.is_fused
specialize_entry(entry, cname)
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod: def specialize_entry(entry, cname):
entry.cname = entry.name """
if entry.is_inherited: Specialize an entry of a copied fused function or method
entry.cname = StringEncoding.EncodedString( """
"%s.%s" % (Naming.obj_base_cname, entry.cname)) entry.name = get_fused_cname(cname, entry.name)
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname: if entry.is_cmethod:
entry.func_cname = get_fused_cname(cname, entry.func_cname) entry.cname = entry.name
if entry.is_inherited:
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
def get_fused_cname(fused_cname, orig_cname): def get_fused_cname(fused_cname, orig_cname):
""" """
...@@ -2177,7 +2190,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()): ...@@ -2177,7 +2190,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0] fused_type = fused_types[0]
result = [] result = []
for newid, specific_type in enumerate(fused_type.types): for newid, specific_type in enumerate(sorted(fused_type.types)):
# f2s = dict(f2s, **{ fused_type: specific_type }) # f2s = dict(f2s, **{ fused_type: specific_type })
f2s = dict(f2s) f2s = dict(f2s)
f2s.update({ fused_type: specific_type }) f2s.update({ fused_type: specific_type })
...@@ -2195,17 +2208,21 @@ def get_all_specific_permutations(fused_types, id="", f2s=()): ...@@ -2195,17 +2208,21 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
return result return result
def get_specific_types(type): def get_specialized_types(type):
"""
Return a list of specialized types sorted in reverse order in accordance
with their preference in runtime fused-type dispatch
"""
assert type.is_fused assert type.is_fused
if isinstance(type, FusedType): if isinstance(type, FusedType):
return type.types result = type.types
else:
result = [] result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()): for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s)) result.append(type.specialize(f2s))
return result return sorted(result)
class CFuncTypeArg(BaseType): class CFuncTypeArg(BaseType):
......
...@@ -1509,6 +1509,7 @@ class ClosureScope(LocalScope): ...@@ -1509,6 +1509,7 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private') return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
......
...@@ -272,7 +272,7 @@ try: ...@@ -272,7 +272,7 @@ try:
except NameError: # Py3 except NameError: # Py3
py_long = typedef(int, "long") py_long = typedef(int, "long")
py_float = typedef(float, "float") py_float = typedef(float, "float")
py_complex = typedef(complex, "complex") py_complex = typedef(complex, "double complex")
try: try:
......
# mode: error
cimport cython
def closure(cython.integral i):
def inner(cython.floating f):
pass
def closure2(cython.integral i):
return lambda cython.integral i: i
def closure3(cython.integral i):
def inner():
return lambda cython.floating f: f
_ERRORS = u"""
e_fused_closure.pyx:6:4: Cannot nest fused functions
e_fused_closure.pyx:10:11: Cannot nest fused functions
e_fused_closure.pyx:14:15: Cannot nest fused functions
"""
# mode: run
"""
Test Python def functions without extern types
"""
cy = __import__("cython")
cimport cython
cdef class Base(object):
def __repr__(self):
return type(self).__name__
cdef class ExtClassA(Base):
pass
cdef class ExtClassB(Base):
pass
cdef enum MyEnum:
entry0
entry1
entry2
entry3
entry4
ctypedef fused fused_t:
str
int
long
complex
ExtClassA
ExtClassB
MyEnum
f = 5.6
i = 9
def opt_func(fused_t obj, cython.floating myf = 1.2, cython.integral myi = 7):
"""
Test runtime dispatch, indexing of various kinds and optional arguments
>>> opt_func("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func["str, double, long"]("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, cy.int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func(ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func[ExtClassA, float, int](ExtClassA(), f, i)
ExtClassA float int
ExtClassA 5.60 9 5.60 9
>>> opt_func["ExtClassA, double, long"](ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func(ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func[ExtClassB, cy.double, cy.long](ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func(10, f)
long double long
10 5.60 7 5.60 9
>>> opt_func[int, float, int](10, f)
int float int
10 5.60 7 5.60 9
>>> opt_func(10 + 2j, myf = 2.6)
double complex double long
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.py_complex, float, int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.doublecomplex, cy.float, cy.int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func(object(), f)
Traceback (most recent call last):
...
TypeError: Function call with ambiguous argument types
>>> opt_func[ExtClassA, cy.float, long](object(), f)
Traceback (most recent call last):
...
TypeError: Argument 'obj' has incorrect type (expected fused_def.ExtClassA, got object)
"""
print cython.typeof(obj), cython.typeof(myf), cython.typeof(myi)
print obj, "%.2f" % myf, myi, "%.2f" % f, i
def test_opt_func():
"""
>>> test_opt_func()
str object double long
ham 5.60 4 5.60 9
"""
cdef char *s = "ham"
opt_func(s, f, entry4)
def args_kwargs(fused_t obj, cython.floating myf = 1.2, *args, **kwargs):
"""
>>> args_kwargs("foo")
str object double
foo 1.20 5.60 () {}
>>> args_kwargs("eggs", f, 1, 2, [], d={})
str object double
eggs 5.60 5.60 (1, 2, []) {'d': {}}
>>> args_kwargs[str, float]("eggs", f, 1, 2, [], d={})
str object float
eggs 5.60 5.60 (1, 2, []) {'d': {}}
"""
print cython.typeof(obj), cython.typeof(myf)
print obj, "%.2f" % myf, "%.2f" % f, args, kwargs
...@@ -35,6 +35,21 @@ less_simple_t = cython.fused_type(int, float, string_t) ...@@ -35,6 +35,21 @@ less_simple_t = cython.fused_type(int, float, string_t)
struct_t = cython.fused_type(mystruct_t, myunion_t, MyExt) struct_t = cython.fused_type(mystruct_t, myunion_t, MyExt)
builtin_t = cython.fused_type(str, unicode, bytes) builtin_t = cython.fused_type(str, unicode, bytes)
ctypedef fused fusedbunch:
int
long
complex
string_t
ctypedef fused fused1:
short
string_t
cdef fused fused2:
float
double
string_t
cdef struct_t add_simple(struct_t obj, simple_t simple) cdef struct_t add_simple(struct_t obj, simple_t simple)
cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple) cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple)
cdef public_optional_args(struct_t obj, simple_t simple = *) cdef public_optional_args(struct_t obj, simple_t simple = *)
...@@ -79,7 +94,17 @@ cdef class TestFusedExtMethods(object): ...@@ -79,7 +94,17 @@ cdef class TestFusedExtMethods(object):
cpdef cpdef_method(self, cython.integral x, cython.floating y): cpdef cpdef_method(self, cython.integral x, cython.floating y):
return cython.typeof(x), cython.typeof(y) return cython.typeof(x), cython.typeof(y)
def def_method(self, fused1 x, fused2 y):
if (fused1 is string_t and fused2 is not string_t or
not fused1 is string_t and fused2 is string_t):
return x, y
else:
return <fused1> x + y
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z): cpdef public_cpdef(cython.integral x, cython.floating y, object_t z):
if cython.integral is int:
pass
return cython.typeof(x), cython.typeof(y), cython.typeof(z) return cython.typeof(x), cython.typeof(y), cython.typeof(z)
...@@ -131,7 +156,9 @@ cdef double b = 7.0 ...@@ -131,7 +156,9 @@ cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double) cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method func = obj.method
assert func(obj, a, b) == 15.0
result = func(obj, a, b)
assert result == 15.0, result
func = <double (*)(TestFusedExtMethods, long, double)> obj.method func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0 assert func(obj, x, y) == 11.0
...@@ -200,5 +227,11 @@ ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0)) ...@@ -200,5 +227,11 @@ ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0))
""" """
d = {'obj': obj, 'myobj': myobj, 'ae': ae} d = {'obj': obj, 'myobj': myobj, 'ae': ae}
# FIXME: uncomment after subclassing CyFunction
#exec s in d #exec s in d
# Test def methods
# ae(obj.def_method(12, 14.9), 26)
# ae(obj.def_method(13, "spam"), (13, "spam"))
# ae(obj.def_method[cy.short, cy.float](13, 16.3), 29)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment