Commit 23b8ea6d authored by Mark Florisson's avatar Mark Florisson

Support decorators for fused functions

parent 5a0effd0
......@@ -40,6 +40,7 @@ class FusedCFuncDefNode(StatListNode):
resulting_fused_function = None
fused_func_assignment = None
defaults_tuple = None
decorators = None
def __init__(self, node, env):
super(FusedCFuncDefNode, self).__init__(node.pos)
......@@ -49,6 +50,7 @@ class FusedCFuncDefNode(StatListNode):
is_def = isinstance(self.node, DefNode)
if is_def:
# self.node.decorators = []
self.copy_def(env)
else:
self.copy_cdef(env)
......@@ -91,6 +93,8 @@ class FusedCFuncDefNode(StatListNode):
fused_to_specific)
copied_node.analyse_declarations(env)
# copied_node.is_staticmethod = self.node.is_staticmethod
# copied_node.is_classmethod = self.node.is_classmethod
self.create_new_local_scope(copied_node, env, fused_to_specific)
self.specialize_copied_def(copied_node, cname, self.node.entry,
fused_to_specific, fused_compound_types)
......
......@@ -2524,7 +2524,8 @@ class DefNode(FuncDefNode):
sig.is_staticmethod = True
sig.has_generic_args = True
if self.is_classmethod and self.has_fused_arguments and env.is_c_class_scope:
if ((self.is_classmethod or self.is_staticmethod) and
self.has_fused_arguments and env.is_c_class_scope):
del self.decorator_indirection.stats[:]
for i in range(min(nfixed, len(self.args))):
......
......@@ -1242,12 +1242,12 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
func_node = self.visit_FuncDefNode(func_node)
if scope_type != 'cclass' or not func_node.decorators:
return func_node
return self._handle_decorators(
func_node, func_node.name)
return self.handle_decorators(func_node, func_node.decorators,
func_node.name)
def _handle_decorators(self, node, name):
def handle_decorators(self, node, decorators, name):
decorator_result = ExprNodes.NameNode(node.pos, name = name)
for decorator in node.decorators[::-1]:
for decorator in decorators[::-1]:
decorator_result = ExprNodes.SimpleCallNode(
decorator.pos,
function = decorator.decorator,
......@@ -1441,37 +1441,55 @@ if VALUE is not None:
node.body.stats += stats
return node
def visit_FuncDefNode(self, node):
def _handle_fused_def_decorators(self, old_decorators, env, node):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
Create function calls to the decorators and reassignments to
the function.
"""
env = self.env_stack[-1]
# Delete staticmethod and classmethod decorators, this is
# handled directly by the fused function object.
decorators = []
for decorator in old_decorators:
func = decorator.decorator
if (not func.is_name or
func.name not in ('staticmethod', 'classmethod') or
env.lookup_here(func.name)):
# not a static or classmethod
decorators.append(decorator)
if decorators:
transform = DecoratorTransform(self.context)
def_node = node.node
_, reassignments = transform.handle_decorators(
def_node, decorators, def_node.name)
reassignments.analyse_declarations(env)
node = [node, reassignments]
return node
def _handle_def(self, decorators, env, node):
"Handle def or cpdef fused functions"
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
# Create assignment node for our def function
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
self.seen_vars_stack.append(set())
lenv = node.local_scope
node.declare_arguments(lenv)
if decorators:
node = self._handle_fused_def_decorators(decorators, env, node)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
return node
if node.is_generator and node.has_fused_arguments:
node.has_fused_arguments = False
error(node.pos, "Fused generators not supported")
node.gbody = Nodes.StatListNode(node.pos,
stats=[],
body=Nodes.PassStatNode(node.pos))
def _create_fused_function(self, env, node):
"Create a fused function for a DefNode with fused arguments"
from Cython.Compiler import FusedNode
if node.has_fused_arguments:
if self.fused_function or self.in_lambda:
if self.fused_function not in self.fused_error_funcs:
if self.in_lambda:
......@@ -1488,29 +1506,18 @@ if VALUE is not None:
return node
from Cython.Compiler import FusedNode
decorators = getattr(node, 'decorators', None)
node = FusedNode.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node)
self.fused_function = None
if node.py_func:
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
node = self._handle_def(decorators, env, node)
# Create assignment node for our def function
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
else:
node.body.analyse_declarations(lenv)
return node
def _handle_nogil_cleanup(self, lenv, node):
"Handle cleanup for 'with gil' blocks in nogil functions."
if lenv.nogil and lenv.has_with_gil_block:
# Acquire the GIL for cleanup in 'nogil' functions, by wrapping
# the entire function body in try/finally.
......@@ -1518,9 +1525,47 @@ if VALUE is not None:
# Nodes.FuncDefNode.generate_function_definitions()
node.body = Nodes.NogilTryFinallyStatNode(
node.body.pos,
body = node.body,
finally_clause = Nodes.EnsureGILNode(node.body.pos),
)
body=node.body,
finally_clause=Nodes.EnsureGILNode(node.body.pos))
def _handle_fused(self, node):
if node.is_generator and node.has_fused_arguments:
node.has_fused_arguments = False
error(node.pos, "Fused generators not supported")
node.gbody = Nodes.StatListNode(node.pos,
stats=[],
body=Nodes.PassStatNode(node.pos))
return node.has_fused_arguments
def visit_FuncDefNode(self, node):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
env = self.env_stack[-1]
self.seen_vars_stack.append(set())
lenv = node.local_scope
node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
if self._handle_fused(node):
node = self._create_fused_function(env, node)
else:
node.body.analyse_declarations(lenv)
self._handle_nogil_cleanup(lenv, node)
self.env_stack.append(lenv)
self.visitchildren(node)
......
......@@ -304,3 +304,20 @@ def test_code_object(cython.floating dummy = 2.0):
>>> getcode(test_code_object) is getcode(test_code_object[float])
True
"""
def create_dec(value):
def dec(f):
if not hasattr(f, 'order'):
f.order = []
f.order.append(value)
return f
return dec
@create_dec(1)
@create_dec(2)
@create_dec(3)
def test_decorators(cython.floating arg):
"""
>>> test_decorators.order
[3, 2, 1]
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment