Commit edf04816 authored by Mark Florisson's avatar Mark Florisson

Support fused def functions + lambda + better runtime dispatch

parent d80c2c4c
......@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform):
return node
def visit_FuncDefNode(self, node):
if not node.doc:
if not node.doc or node.fused_py_func:
return node
if not self.cdef_docstrings:
if isinstance(node, CFuncDefNode) and not node.py_func:
......
......@@ -8969,12 +8969,17 @@ static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject
static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure);
static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure);
static PyObject *%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure);
static PyGetSetDef %(binding_cfunc)s_getsets[] = {
{(char *)"__name__",
(getter) %(binding_cfunc)s_get__name__,
(setter) %(binding_cfunc)s_set__name__,
NULL},
{(char *)"__doc__",
(getter) %(binding_cfunc)s_get__doc__,
NULL,
NULL},
{NULL},
};
......@@ -9139,6 +9144,12 @@ static int
return PyDict_SetItemString(func->__dict__, "__name__", value);
}
static PyObject *
%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure)
{
return PyUnicode_FromString(func->func.m_ml->ml_doc);
}
static PyObject *
%(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type)
{
......@@ -9290,11 +9301,13 @@ static PyObject *
binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth;
func = new_func = meth(binding_func->__signatures__, args);
*/
PyObject *tup = PyTuple_Pack(2, binding_func->__signatures__, args);
PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw);
if (!tup)
goto __pyx_err;
func = new_func = PyCFunction_Call(func, tup, NULL);
Py_DECREF(tup);
if (!new_func)
goto __pyx_err;
......
......@@ -936,6 +936,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# Generate struct declaration for an extension type's vtable.
type = entry.type
scope = type.scope
self.specialize_fused_types(scope)
if type.vtabstruct_cname:
code.putln("")
code.putln(
......@@ -1942,7 +1945,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/")
for module in imported_modules:
self.specialize_fused_types(module, env)
self.specialize_fused_types(module)
self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/")
......@@ -2160,7 +2163,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env, impl_env):
def specialize_fused_types(self, pxd_env):
"""
If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and
......
This diff is collapsed.
......@@ -1348,10 +1348,13 @@ if VALUE is not None:
count += 1
""")
fused_function = None
def __call__(self, root):
self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = []
self.fused_error_funcs = cython.set()
return super(AnalyseDeclarationsTransform, self).__call__(root)
def visit_NameNode(self, node):
......@@ -1399,9 +1402,12 @@ if VALUE is not None:
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
env = self.env_stack[-1]
self.seen_vars_stack.append(cython.set())
lenv = node.local_scope
node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
......@@ -1411,10 +1417,27 @@ if VALUE is not None:
error(type_node.pos, "Not a type")
if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1])
if self.fused_function:
if self.fused_function not in self.fused_error_funcs:
error(node.pos, "Cannot nest fused functions")
self.fused_error_funcs.add(self.fused_function)
# env.declare_var(node.name, PyrexTypes.py_object_type, node.pos)
node = Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos, name=node.name),
rhs=ExprNodes.NoneNode(node.pos))
node.analyse_declarations(env)
return node
node = Nodes.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node)
self.fused_function = None
if node.py_func:
node.stats.append(node.py_func)
node.stats.insert(0, node.py_func)
else:
node.body.analyse_declarations(lenv)
......@@ -2082,6 +2105,10 @@ class CreateClosureClasses(CythonTransform):
target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node):
if not isinstance(node.def_node, Nodes.DefNode):
# fused function, an error has been previously issued
return node
was_in_lambda = self.in_lambda
self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node)
......@@ -2408,7 +2435,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
error(node.operand2.pos,
"Can only use 'in' or 'not in' on a fused type")
else:
types = PyrexTypes.get_specific_types(type2)
types = PyrexTypes.get_specialized_types(type2)
for specific_type in types:
if type1.same_as(specific_type):
......
......@@ -235,9 +235,16 @@ def public_decl(base_code, dll_linkage):
return base_code
def create_typedef_type(name, base_type, cname, is_external=0):
if base_type.is_complex:
is_fused = base_type.is_fused
if base_type.is_complex or is_fused:
if is_external:
raise ValueError("Complex external typedefs not supported")
if is_fused:
msg = "Fused"
else:
msg = "Complex"
raise ValueError("%s external typedefs not supported" % msg)
return base_type
else:
return CTypedefType(name, base_type, cname, is_external)
......@@ -2123,6 +2130,7 @@ class CFuncType(CType):
result = []
permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific)
......@@ -2150,20 +2158,25 @@ class CFuncType(CType):
def specialize_entry(self, entry, cname):
assert not self.is_fused
specialize_entry(entry, cname)
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
def specialize_entry(entry, cname):
"""
Specialize an entry of a copied fused function or method
"""
entry.name = get_fused_cname(cname, entry.name)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
def get_fused_cname(fused_cname, orig_cname):
"""
......@@ -2177,7 +2190,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0]
result = []
for newid, specific_type in enumerate(fused_type.types):
for newid, specific_type in enumerate(sorted(fused_type.types)):
# f2s = dict(f2s, **{ fused_type: specific_type })
f2s = dict(f2s)
f2s.update({ fused_type: specific_type })
......@@ -2195,17 +2208,21 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
return result
def get_specific_types(type):
def get_specialized_types(type):
"""
Return a list of specialized types sorted in reverse order in accordance
with their preference in runtime fused-type dispatch
"""
assert type.is_fused
if isinstance(type, FusedType):
return type.types
result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
result = type.types
else:
result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
return result
return sorted(result)
class CFuncTypeArg(BaseType):
......
......@@ -1509,6 +1509,7 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
class StructOrUnionScope(Scope):
# Namespace of a C struct or union.
......
......@@ -272,7 +272,7 @@ try:
except NameError: # Py3
py_long = typedef(int, "long")
py_float = typedef(float, "float")
py_complex = typedef(complex, "complex")
py_complex = typedef(complex, "double complex")
try:
......
# mode: error
cimport cython
def closure(cython.integral i):
def inner(cython.floating f):
pass
def closure2(cython.integral i):
return lambda cython.integral i: i
def closure3(cython.integral i):
def inner():
return lambda cython.floating f: f
_ERRORS = u"""
e_fused_closure.pyx:6:4: Cannot nest fused functions
e_fused_closure.pyx:10:11: Cannot nest fused functions
e_fused_closure.pyx:14:15: Cannot nest fused functions
"""
# mode: run
"""
Test Python def functions without extern types
"""
cy = __import__("cython")
cimport cython
cdef class Base(object):
def __repr__(self):
return type(self).__name__
cdef class ExtClassA(Base):
pass
cdef class ExtClassB(Base):
pass
cdef enum MyEnum:
entry0
entry1
entry2
entry3
entry4
ctypedef fused fused_t:
str
int
long
complex
ExtClassA
ExtClassB
MyEnum
f = 5.6
i = 9
def opt_func(fused_t obj, cython.floating myf = 1.2, cython.integral myi = 7):
"""
Test runtime dispatch, indexing of various kinds and optional arguments
>>> opt_func("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func["str, double, long"]("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, cy.int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func(ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func[ExtClassA, float, int](ExtClassA(), f, i)
ExtClassA float int
ExtClassA 5.60 9 5.60 9
>>> opt_func["ExtClassA, double, long"](ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func(ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func[ExtClassB, cy.double, cy.long](ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func(10, f)
long double long
10 5.60 7 5.60 9
>>> opt_func[int, float, int](10, f)
int float int
10 5.60 7 5.60 9
>>> opt_func(10 + 2j, myf = 2.6)
double complex double long
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.py_complex, float, int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.doublecomplex, cy.float, cy.int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func(object(), f)
Traceback (most recent call last):
...
TypeError: Function call with ambiguous argument types
>>> opt_func[ExtClassA, cy.float, long](object(), f)
Traceback (most recent call last):
...
TypeError: Argument 'obj' has incorrect type (expected fused_def.ExtClassA, got object)
"""
print cython.typeof(obj), cython.typeof(myf), cython.typeof(myi)
print obj, "%.2f" % myf, myi, "%.2f" % f, i
def test_opt_func():
"""
>>> test_opt_func()
str object double long
ham 5.60 4 5.60 9
"""
cdef char *s = "ham"
opt_func(s, f, entry4)
def args_kwargs(fused_t obj, cython.floating myf = 1.2, *args, **kwargs):
"""
>>> args_kwargs("foo")
str object double
foo 1.20 5.60 () {}
>>> args_kwargs("eggs", f, 1, 2, [], d={})
str object double
eggs 5.60 5.60 (1, 2, []) {'d': {}}
>>> args_kwargs[str, float]("eggs", f, 1, 2, [], d={})
str object float
eggs 5.60 5.60 (1, 2, []) {'d': {}}
"""
print cython.typeof(obj), cython.typeof(myf)
print obj, "%.2f" % myf, "%.2f" % f, args, kwargs
......@@ -35,6 +35,21 @@ less_simple_t = cython.fused_type(int, float, string_t)
struct_t = cython.fused_type(mystruct_t, myunion_t, MyExt)
builtin_t = cython.fused_type(str, unicode, bytes)
ctypedef fused fusedbunch:
int
long
complex
string_t
ctypedef fused fused1:
short
string_t
cdef fused fused2:
float
double
string_t
cdef struct_t add_simple(struct_t obj, simple_t simple)
cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple)
cdef public_optional_args(struct_t obj, simple_t simple = *)
......@@ -79,7 +94,17 @@ cdef class TestFusedExtMethods(object):
cpdef cpdef_method(self, cython.integral x, cython.floating y):
return cython.typeof(x), cython.typeof(y)
def def_method(self, fused1 x, fused2 y):
if (fused1 is string_t and fused2 is not string_t or
not fused1 is string_t and fused2 is string_t):
return x, y
else:
return <fused1> x + y
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z):
if cython.integral is int:
pass
return cython.typeof(x), cython.typeof(y), cython.typeof(z)
......@@ -131,7 +156,9 @@ cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method
assert func(obj, a, b) == 15.0
result = func(obj, a, b)
assert result == 15.0, result
func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0
......@@ -200,5 +227,11 @@ ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0))
"""
d = {'obj': obj, 'myobj': myobj, 'ae': ae}
# FIXME: uncomment after subclassing CyFunction
#exec s in d
# Test def methods
# ae(obj.def_method(12, 14.9), 26)
# ae(obj.def_method(13, "spam"), (13, "spam"))
# ae(obj.def_method[cy.short, cy.float](13, 16.3), 29)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment