Commit edf04816 authored by Mark Florisson's avatar Mark Florisson

Support fused def functions + lambda + better runtime dispatch

parent d80c2c4c
...@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform): ...@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform):
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
if not node.doc: if not node.doc or node.fused_py_func:
return node return node
if not self.cdef_docstrings: if not self.cdef_docstrings:
if isinstance(node, CFuncDefNode) and not node.py_func: if isinstance(node, CFuncDefNode) and not node.py_func:
......
...@@ -8969,12 +8969,17 @@ static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject ...@@ -8969,12 +8969,17 @@ static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject
static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure); static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure);
static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure); static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure);
static PyObject *%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure);
static PyGetSetDef %(binding_cfunc)s_getsets[] = { static PyGetSetDef %(binding_cfunc)s_getsets[] = {
{(char *)"__name__", {(char *)"__name__",
(getter) %(binding_cfunc)s_get__name__, (getter) %(binding_cfunc)s_get__name__,
(setter) %(binding_cfunc)s_set__name__, (setter) %(binding_cfunc)s_set__name__,
NULL}, NULL},
{(char *)"__doc__",
(getter) %(binding_cfunc)s_get__doc__,
NULL,
NULL},
{NULL}, {NULL},
}; };
...@@ -9139,6 +9144,12 @@ static int ...@@ -9139,6 +9144,12 @@ static int
return PyDict_SetItemString(func->__dict__, "__name__", value); return PyDict_SetItemString(func->__dict__, "__name__", value);
} }
static PyObject *
%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure)
{
return PyUnicode_FromString(func->func.m_ml->ml_doc);
}
static PyObject * static PyObject *
%(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type) %(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type)
{ {
...@@ -9290,11 +9301,13 @@ static PyObject * ...@@ -9290,11 +9301,13 @@ static PyObject *
binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth; binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth;
func = new_func = meth(binding_func->__signatures__, args); func = new_func = meth(binding_func->__signatures__, args);
*/ */
PyObject *tup = PyTuple_Pack(2, binding_func->__signatures__, args); PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw);
if (!tup) if (!tup)
goto __pyx_err; goto __pyx_err;
func = new_func = PyCFunction_Call(func, tup, NULL); func = new_func = PyCFunction_Call(func, tup, NULL);
Py_DECREF(tup);
if (!new_func) if (!new_func)
goto __pyx_err; goto __pyx_err;
......
...@@ -936,6 +936,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -936,6 +936,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# Generate struct declaration for an extension type's vtable. # Generate struct declaration for an extension type's vtable.
type = entry.type type = entry.type
scope = type.scope scope = type.scope
self.specialize_fused_types(scope)
if type.vtabstruct_cname: if type.vtabstruct_cname:
code.putln("") code.putln("")
code.putln( code.putln(
...@@ -1942,7 +1945,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1942,7 +1945,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/") code.putln("/*--- Function import code ---*/")
for module in imported_modules: for module in imported_modules:
self.specialize_fused_types(module, env) self.specialize_fused_types(module)
self.generate_c_function_import_code_for_module(module, env, code) self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
...@@ -2160,7 +2163,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2160,7 +2163,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd: if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code) self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env, impl_env): def specialize_fused_types(self, pxd_env):
""" """
If fused c(p)def functions are defined in an imported pxd, but not If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and used in this implementation file, we still have fused entries and
......
This diff is collapsed.
...@@ -1348,10 +1348,13 @@ if VALUE is not None: ...@@ -1348,10 +1348,13 @@ if VALUE is not None:
count += 1 count += 1
""") """)
fused_function = None
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope] self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used. # needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = [] self.seen_vars_stack = []
self.fused_error_funcs = cython.set()
return super(AnalyseDeclarationsTransform, self).__call__(root) return super(AnalyseDeclarationsTransform, self).__call__(root)
def visit_NameNode(self, node): def visit_NameNode(self, node):
...@@ -1399,9 +1402,12 @@ if VALUE is not None: ...@@ -1399,9 +1402,12 @@ if VALUE is not None:
analyse its children (which are in turn normal functions). If we're a analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function. normal function, just analyse the body of the function.
""" """
env = self.env_stack[-1]
self.seen_vars_stack.append(cython.set()) self.seen_vars_stack.append(cython.set())
lenv = node.local_scope lenv = node.local_scope
node.declare_arguments(lenv) node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items(): for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv) type = type_node.analyse_as_type(lenv)
...@@ -1411,10 +1417,27 @@ if VALUE is not None: ...@@ -1411,10 +1417,27 @@ if VALUE is not None:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
if node.has_fused_arguments: if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1]) if self.fused_function:
if self.fused_function not in self.fused_error_funcs:
error(node.pos, "Cannot nest fused functions")
self.fused_error_funcs.add(self.fused_function)
# env.declare_var(node.name, PyrexTypes.py_object_type, node.pos)
node = Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos, name=node.name),
rhs=ExprNodes.NoneNode(node.pos))
node.analyse_declarations(env)
return node
node = Nodes.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node) self.visitchildren(node)
self.fused_function = None
if node.py_func: if node.py_func:
node.stats.append(node.py_func) node.stats.insert(0, node.py_func)
else: else:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
...@@ -2082,6 +2105,10 @@ class CreateClosureClasses(CythonTransform): ...@@ -2082,6 +2105,10 @@ class CreateClosureClasses(CythonTransform):
target_module_scope.check_c_class(func_scope.scope_class) target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
if not isinstance(node.def_node, Nodes.DefNode):
# fused function, an error has been previously issued
return node
was_in_lambda = self.in_lambda was_in_lambda = self.in_lambda
self.in_lambda = True self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node) self.create_class_from_scope(node.def_node, self.module_scope, node)
...@@ -2408,7 +2435,7 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -2408,7 +2435,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
error(node.operand2.pos, error(node.operand2.pos,
"Can only use 'in' or 'not in' on a fused type") "Can only use 'in' or 'not in' on a fused type")
else: else:
types = PyrexTypes.get_specific_types(type2) types = PyrexTypes.get_specialized_types(type2)
for specific_type in types: for specific_type in types:
if type1.same_as(specific_type): if type1.same_as(specific_type):
......
...@@ -235,9 +235,16 @@ def public_decl(base_code, dll_linkage): ...@@ -235,9 +235,16 @@ def public_decl(base_code, dll_linkage):
return base_code return base_code
def create_typedef_type(name, base_type, cname, is_external=0): def create_typedef_type(name, base_type, cname, is_external=0):
if base_type.is_complex: is_fused = base_type.is_fused
if base_type.is_complex or is_fused:
if is_external: if is_external:
raise ValueError("Complex external typedefs not supported") if is_fused:
msg = "Fused"
else:
msg = "Complex"
raise ValueError("%s external typedefs not supported" % msg)
return base_type return base_type
else: else:
return CTypedefType(name, base_type, cname, is_external) return CTypedefType(name, base_type, cname, is_external)
...@@ -2123,6 +2130,7 @@ class CFuncType(CType): ...@@ -2123,6 +2130,7 @@ class CFuncType(CType):
result = [] result = []
permutations = self.get_all_specific_permutations() permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific) new_func_type = self.entry.type.specialize(fused_to_specific)
...@@ -2150,20 +2158,25 @@ class CFuncType(CType): ...@@ -2150,20 +2158,25 @@ class CFuncType(CType):
def specialize_entry(self, entry, cname): def specialize_entry(self, entry, cname):
assert not self.is_fused assert not self.is_fused
specialize_entry(entry, cname)
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod: def specialize_entry(entry, cname):
entry.cname = entry.name """
if entry.is_inherited: Specialize an entry of a copied fused function or method
entry.cname = StringEncoding.EncodedString( """
"%s.%s" % (Naming.obj_base_cname, entry.cname)) entry.name = get_fused_cname(cname, entry.name)
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname: if entry.is_cmethod:
entry.func_cname = get_fused_cname(cname, entry.func_cname) entry.cname = entry.name
if entry.is_inherited:
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
def get_fused_cname(fused_cname, orig_cname): def get_fused_cname(fused_cname, orig_cname):
""" """
...@@ -2177,7 +2190,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()): ...@@ -2177,7 +2190,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0] fused_type = fused_types[0]
result = [] result = []
for newid, specific_type in enumerate(fused_type.types): for newid, specific_type in enumerate(sorted(fused_type.types)):
# f2s = dict(f2s, **{ fused_type: specific_type }) # f2s = dict(f2s, **{ fused_type: specific_type })
f2s = dict(f2s) f2s = dict(f2s)
f2s.update({ fused_type: specific_type }) f2s.update({ fused_type: specific_type })
...@@ -2195,17 +2208,21 @@ def get_all_specific_permutations(fused_types, id="", f2s=()): ...@@ -2195,17 +2208,21 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
return result return result
def get_specific_types(type): def get_specialized_types(type):
"""
Return a list of specialized types sorted in reverse order in accordance
with their preference in runtime fused-type dispatch
"""
assert type.is_fused assert type.is_fused
if isinstance(type, FusedType): if isinstance(type, FusedType):
return type.types result = type.types
else:
result = [] result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()): for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s)) result.append(type.specialize(f2s))
return result return sorted(result)
class CFuncTypeArg(BaseType): class CFuncTypeArg(BaseType):
......
...@@ -1509,6 +1509,7 @@ class ClosureScope(LocalScope): ...@@ -1509,6 +1509,7 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private') return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
......
...@@ -272,7 +272,7 @@ try: ...@@ -272,7 +272,7 @@ try:
except NameError: # Py3 except NameError: # Py3
py_long = typedef(int, "long") py_long = typedef(int, "long")
py_float = typedef(float, "float") py_float = typedef(float, "float")
py_complex = typedef(complex, "complex") py_complex = typedef(complex, "double complex")
try: try:
......
# mode: error
cimport cython
def closure(cython.integral i):
def inner(cython.floating f):
pass
def closure2(cython.integral i):
return lambda cython.integral i: i
def closure3(cython.integral i):
def inner():
return lambda cython.floating f: f
_ERRORS = u"""
e_fused_closure.pyx:6:4: Cannot nest fused functions
e_fused_closure.pyx:10:11: Cannot nest fused functions
e_fused_closure.pyx:14:15: Cannot nest fused functions
"""
# mode: run
"""
Test Python def functions without extern types
"""
cy = __import__("cython")
cimport cython
cdef class Base(object):
def __repr__(self):
return type(self).__name__
cdef class ExtClassA(Base):
pass
cdef class ExtClassB(Base):
pass
cdef enum MyEnum:
entry0
entry1
entry2
entry3
entry4
ctypedef fused fused_t:
str
int
long
complex
ExtClassA
ExtClassB
MyEnum
f = 5.6
i = 9
def opt_func(fused_t obj, cython.floating myf = 1.2, cython.integral myi = 7):
"""
Test runtime dispatch, indexing of various kinds and optional arguments
>>> opt_func("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func["str, double, long"]("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, cy.int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func(ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func[ExtClassA, float, int](ExtClassA(), f, i)
ExtClassA float int
ExtClassA 5.60 9 5.60 9
>>> opt_func["ExtClassA, double, long"](ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func(ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func[ExtClassB, cy.double, cy.long](ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func(10, f)
long double long
10 5.60 7 5.60 9
>>> opt_func[int, float, int](10, f)
int float int
10 5.60 7 5.60 9
>>> opt_func(10 + 2j, myf = 2.6)
double complex double long
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.py_complex, float, int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.doublecomplex, cy.float, cy.int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func(object(), f)
Traceback (most recent call last):
...
TypeError: Function call with ambiguous argument types
>>> opt_func[ExtClassA, cy.float, long](object(), f)
Traceback (most recent call last):
...
TypeError: Argument 'obj' has incorrect type (expected fused_def.ExtClassA, got object)
"""
print cython.typeof(obj), cython.typeof(myf), cython.typeof(myi)
print obj, "%.2f" % myf, myi, "%.2f" % f, i
def test_opt_func():
"""
>>> test_opt_func()
str object double long
ham 5.60 4 5.60 9
"""
cdef char *s = "ham"
opt_func(s, f, entry4)
def args_kwargs(fused_t obj, cython.floating myf = 1.2, *args, **kwargs):
"""
>>> args_kwargs("foo")
str object double
foo 1.20 5.60 () {}
>>> args_kwargs("eggs", f, 1, 2, [], d={})
str object double
eggs 5.60 5.60 (1, 2, []) {'d': {}}
>>> args_kwargs[str, float]("eggs", f, 1, 2, [], d={})
str object float
eggs 5.60 5.60 (1, 2, []) {'d': {}}
"""
print cython.typeof(obj), cython.typeof(myf)
print obj, "%.2f" % myf, "%.2f" % f, args, kwargs
...@@ -35,6 +35,21 @@ less_simple_t = cython.fused_type(int, float, string_t) ...@@ -35,6 +35,21 @@ less_simple_t = cython.fused_type(int, float, string_t)
struct_t = cython.fused_type(mystruct_t, myunion_t, MyExt) struct_t = cython.fused_type(mystruct_t, myunion_t, MyExt)
builtin_t = cython.fused_type(str, unicode, bytes) builtin_t = cython.fused_type(str, unicode, bytes)
ctypedef fused fusedbunch:
int
long
complex
string_t
ctypedef fused fused1:
short
string_t
cdef fused fused2:
float
double
string_t
cdef struct_t add_simple(struct_t obj, simple_t simple) cdef struct_t add_simple(struct_t obj, simple_t simple)
cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple) cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple)
cdef public_optional_args(struct_t obj, simple_t simple = *) cdef public_optional_args(struct_t obj, simple_t simple = *)
...@@ -79,7 +94,17 @@ cdef class TestFusedExtMethods(object): ...@@ -79,7 +94,17 @@ cdef class TestFusedExtMethods(object):
cpdef cpdef_method(self, cython.integral x, cython.floating y): cpdef cpdef_method(self, cython.integral x, cython.floating y):
return cython.typeof(x), cython.typeof(y) return cython.typeof(x), cython.typeof(y)
def def_method(self, fused1 x, fused2 y):
if (fused1 is string_t and fused2 is not string_t or
not fused1 is string_t and fused2 is string_t):
return x, y
else:
return <fused1> x + y
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z): cpdef public_cpdef(cython.integral x, cython.floating y, object_t z):
if cython.integral is int:
pass
return cython.typeof(x), cython.typeof(y), cython.typeof(z) return cython.typeof(x), cython.typeof(y), cython.typeof(z)
...@@ -131,7 +156,9 @@ cdef double b = 7.0 ...@@ -131,7 +156,9 @@ cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double) cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method func = obj.method
assert func(obj, a, b) == 15.0
result = func(obj, a, b)
assert result == 15.0, result
func = <double (*)(TestFusedExtMethods, long, double)> obj.method func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0 assert func(obj, x, y) == 11.0
...@@ -200,5 +227,11 @@ ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0)) ...@@ -200,5 +227,11 @@ ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0))
""" """
d = {'obj': obj, 'myobj': myobj, 'ae': ae} d = {'obj': obj, 'myobj': myobj, 'ae': ae}
# FIXME: uncomment after subclassing CyFunction
#exec s in d #exec s in d
# Test def methods
# ae(obj.def_method(12, 14.9), 26)
# ae(obj.def_method(13, "spam"), (13, "spam"))
# ae(obj.def_method[cy.short, cy.float](13, 16.3), 29)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment