Commit 2848d242 authored by Robert Bradshaw's avatar Robert Bradshaw

decorators for cdef functions, remove strange pxd locals syntax

parent 8469d73b
...@@ -771,11 +771,12 @@ class CVarDefNode(StatNode): ...@@ -771,11 +771,12 @@ class CVarDefNode(StatNode):
# in_pxd boolean # in_pxd boolean
# api boolean # api boolean
# need_properties [entry] # need_properties [entry]
# pxd_locals [CVarDefNode] (used for functions declared in pxd)
# directive_locals { string : NameNode } locals defined by cython.locals(...)
child_attrs = ["base_type", "declarators"] child_attrs = ["base_type", "declarators"]
need_properties = () need_properties = ()
pxd_locals = [] directive_locals = {}
def analyse_declarations(self, env, dest_scope = None): def analyse_declarations(self, env, dest_scope = None):
if not dest_scope: if not dest_scope:
...@@ -812,8 +813,10 @@ class CVarDefNode(StatNode): ...@@ -812,8 +813,10 @@ class CVarDefNode(StatNode):
cname = cname, visibility = self.visibility, in_pxd = self.in_pxd, cname = cname, visibility = self.visibility, in_pxd = self.in_pxd,
api = self.api) api = self.api)
if entry is not None: if entry is not None:
entry.pxd_locals = self.pxd_locals entry.directive_locals = self.directive_locals
else: else:
if self.directive_locals:
s.error("Decorators can only be followed by functions")
if self.in_pxd and self.visibility != 'extern': if self.in_pxd and self.visibility != 'extern':
error(self.pos, error(self.pos,
"Only 'extern' C variable declaration allowed in .pxd file") "Only 'extern' C variable declaration allowed in .pxd file")
...@@ -969,12 +972,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -969,12 +972,11 @@ class FuncDefNode(StatNode, BlockNode):
# #filename string C name of filename string const # #filename string C name of filename string const
# entry Symtab.Entry # entry Symtab.Entry
# needs_closure boolean Whether or not this function has inner functions/classes/yield # needs_closure boolean Whether or not this function has inner functions/classes/yield
# pxd_locals [CVarDefNode] locals defined in the pxd # directive_locals { string : NameNode } locals defined by cython.locals(...)
py_func = None py_func = None
assmt = None assmt = None
needs_closure = False needs_closure = False
pxd_locals = []
def analyse_default_values(self, env): def analyse_default_values(self, env):
genv = env.global_scope() genv = env.global_scope()
...@@ -1280,6 +1282,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -1280,6 +1282,7 @@ class CFuncDefNode(FuncDefNode):
# declarator CDeclaratorNode # declarator CDeclaratorNode
# body StatListNode # body StatListNode
# api boolean # api boolean
# decorators [DecoratorNode] list of decorators
# #
# with_gil boolean Acquire GIL around body # with_gil boolean Acquire GIL around body
# type CFuncType # type CFuncType
...@@ -1290,16 +1293,16 @@ class CFuncDefNode(FuncDefNode): ...@@ -1290,16 +1293,16 @@ class CFuncDefNode(FuncDefNode):
child_attrs = ["base_type", "declarator", "body", "py_func"] child_attrs = ["base_type", "declarator", "body", "py_func"]
inline_in_pxd = False inline_in_pxd = False
decorators = None
directive_locals = {}
def unqualified_name(self): def unqualified_name(self):
return self.entry.name return self.entry.name
def analyse_declarations(self, env): def analyse_declarations(self, env):
if 'locals' in env.directives: if 'locals' in env.directives and env.directives['locals']:
directive_locals = env.directives['locals'] self.directive_locals = env.directives['locals']
else: directive_locals = self.directive_locals
directive_locals = {}
self.directive_locals = directive_locals
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
# The 2 here is because we need both function and argument names. # The 2 here is because we need both function and argument names.
name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None)) name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
...@@ -1595,7 +1598,7 @@ class DefNode(FuncDefNode): ...@@ -1595,7 +1598,7 @@ class DefNode(FuncDefNode):
nogil = False, nogil = False,
with_gil = False, with_gil = False,
is_overridable = True) is_overridable = True)
cfunc = CVarDefNode(self.pos, type=cfunc_type, pxd_locals=[]) cfunc = CVarDefNode(self.pos, type=cfunc_type)
else: else:
cfunc_type = cfunc.type cfunc_type = cfunc.type
if len(self.args) != len(cfunc_type.args) or cfunc_type.has_varargs: if len(self.args) != len(cfunc_type.args) or cfunc_type.has_varargs:
...@@ -1631,7 +1634,7 @@ class DefNode(FuncDefNode): ...@@ -1631,7 +1634,7 @@ class DefNode(FuncDefNode):
nogil = cfunc_type.nogil, nogil = cfunc_type.nogil,
visibility = 'private', visibility = 'private',
api = False, api = False,
pxd_locals = cfunc.pxd_locals) directive_locals = cfunc.directive_locals)
def analyse_declarations(self, env): def analyse_declarations(self, env):
if 'locals' in env.directives: if 'locals' in env.directives:
......
...@@ -294,19 +294,8 @@ class PxdPostParse(CythonTransform, SkipDeclarations): ...@@ -294,19 +294,8 @@ class PxdPostParse(CythonTransform, SkipDeclarations):
else: else:
err = None # allow inline function err = None # allow inline function
else: else:
err = None
for stat in node.body.stats:
if not isinstance(stat, CVarDefNode):
err = self.ERR_INLINE_ONLY err = self.ERR_INLINE_ONLY
break
node = CVarDefNode(node.pos,
visibility = node.visibility,
base_type = node.base_type,
declarators = [node.declarator],
in_pxd = True,
api = node.api,
overridable = node.overridable,
pxd_locals = node.body.stats)
if err: if err:
self.context.nonfatal_error(PostParseError(node.pos, err)) self.context.nonfatal_error(PostParseError(node.pos, err))
return None return None
...@@ -462,7 +451,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -462,7 +451,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return directive return directive
# Handle decorators # Handle decorators
def visit_DefNode(self, node): def visit_FuncDefNode(self, node):
options = [] options = []
if node.decorators: if node.decorators:
...@@ -474,6 +463,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -474,6 +463,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
options.append(option) options.append(option)
else: else:
realdecs.append(dec) realdecs.append(dec)
if realdecs and isinstance(node, CFuncDefNode):
raise PostParseError(realdecs[0].pos, "Cdef functions cannot take arbitrary decorators.")
else:
node.decorators = realdecs node.decorators = realdecs
if options: if options:
...@@ -487,6 +479,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -487,6 +479,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
else: else:
return self.visit_Node(node) return self.visit_Node(node)
def visit_CVarDefNode(self, node):
if node.decorators:
for dec in node.decorators:
option = self.try_to_parse_option(dec.decorator)
if option is not None and option[0] == u'locals':
node.directive_locals = option[1]
else:
raise PostParseError(dec.pos, "Cdef functions can only take cython.locals() decorator.")
return node
# Handle with statements # Handle with statements
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
option = self.try_to_parse_option(node.manager) option = self.try_to_parse_option(node.manager)
...@@ -686,8 +688,6 @@ property NAME: ...@@ -686,8 +688,6 @@ property NAME:
lenv.declare_var(var, type, type_node.pos) lenv.declare_var(var, type, type_node.pos)
else: else:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
for stat in node.pxd_locals:
stat.analyse_declarations(lenv)
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
self.env_stack.append(lenv) self.env_stack.append(lenv)
self.visitchildren(node) self.visitchildren(node)
......
...@@ -1479,6 +1479,7 @@ def p_IF_statement(s, ctx): ...@@ -1479,6 +1479,7 @@ def p_IF_statement(s, ctx):
def p_statement(s, ctx, first_statement = 0): def p_statement(s, ctx, first_statement = 0):
cdef_flag = ctx.cdef_flag cdef_flag = ctx.cdef_flag
decorators = []
if s.sy == 'ctypedef': if s.sy == 'ctypedef':
if ctx.level not in ('module', 'module_pxd'): if ctx.level not in ('module', 'module_pxd'):
s.error("ctypedef statement not allowed here") s.error("ctypedef statement not allowed here")
...@@ -1490,14 +1491,13 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1490,14 +1491,13 @@ def p_statement(s, ctx, first_statement = 0):
elif s.sy == 'IF': elif s.sy == 'IF':
return p_IF_statement(s, ctx) return p_IF_statement(s, ctx)
elif s.sy == 'DECORATOR': elif s.sy == 'DECORATOR':
if ctx.level not in ('module', 'class', 'c_class', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'property', 'module_pxd', 'class_pxd'):
s.error('decorator not allowed here') s.error('decorator not allowed here')
s.level = ctx.level s.level = ctx.level
decorators = p_decorators(s) decorators = p_decorators(s)
if s.sy != 'def': if s.sy not in ('def', 'cdef', 'cpdef'):
s.error("Decorators can only be followed by functions ") s.error("Decorators can only be followed by functions ")
return p_def_statement(s, decorators)
else:
overridable = 0 overridable = 0
if s.sy == 'cdef': if s.sy == 'cdef':
cdef_flag = 1 cdef_flag = 1
...@@ -1510,7 +1510,12 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1510,7 +1510,12 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'): if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'):
s.error('cdef statement not allowed here') s.error('cdef statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_cdef_statement(s, ctx(overridable = overridable)) node = p_cdef_statement(s, ctx(overridable = overridable))
if decorators is not None:
if not isinstance(node, (Nodes.CFuncDefNode, Nodes.CVarDefNode)):
s.error("Decorators can only be followed by functions ")
node.decorators = decorators
return node
else: else:
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
...@@ -1518,7 +1523,7 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1518,7 +1523,7 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'):
s.error('def statement not allowed here') s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s) return p_def_statement(s, decorators)
elif s.sy == 'class': elif s.sy == 'class':
if ctx.level != 'module': if ctx.level != 'module':
s.error("class definition not allowed here") s.error("class definition not allowed here")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment