Commit 7eca82ec authored by Craig Citro's avatar Craig Citro

Add changes to allow def statements anywhere they're legal.

parent ad9a2205
...@@ -516,6 +516,9 @@ class ExprNode(Node): ...@@ -516,6 +516,9 @@ class ExprNode(Node):
for sub in self.subexpr_nodes(): for sub in self.subexpr_nodes():
sub.free_temps(code) sub.free_temps(code)
def generate_function_definitions(self, env, code):
pass
# ---------------- Annotation --------------------- # ---------------- Annotation ---------------------
def annotate(self, code): def annotate(self, code):
......
...@@ -1607,6 +1607,8 @@ class PyArgDeclNode(Node): ...@@ -1607,6 +1607,8 @@ class PyArgDeclNode(Node):
# entry Symtab.Entry # entry Symtab.Entry
child_attrs = [] child_attrs = []
def generate_function_definitions(self, env, code):
self.entry.generate_function_definitions(env, code)
class DecoratorNode(Node): class DecoratorNode(Node):
# A decorator # A decorator
...@@ -2918,6 +2920,9 @@ class ExprStatNode(StatNode): ...@@ -2918,6 +2920,9 @@ class ExprStatNode(StatNode):
self.expr.generate_disposal_code(code) self.expr.generate_disposal_code(code)
self.expr.free_temps(code) self.expr.free_temps(code)
def generate_function_definitions(self, env, code):
self.expr.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.expr.annotate(code) self.expr.annotate(code)
...@@ -3036,6 +3041,9 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -3036,6 +3041,9 @@ class SingleAssignmentNode(AssignmentNode):
def generate_assignment_code(self, code): def generate_assignment_code(self, code):
self.lhs.generate_assignment_code(self.rhs, code) self.lhs.generate_assignment_code(self.rhs, code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
self.rhs.annotate(code) self.rhs.annotate(code)
...@@ -3088,6 +3096,9 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -3088,6 +3096,9 @@ class CascadedAssignmentNode(AssignmentNode):
self.rhs.generate_disposal_code(code) self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code) self.rhs.free_temps(code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
for i in range(len(self.lhs_list)): for i in range(len(self.lhs_list)):
lhs = self.lhs_list[i].annotate(code) lhs = self.lhs_list[i].annotate(code)
...@@ -3131,13 +3142,17 @@ class ParallelAssignmentNode(AssignmentNode): ...@@ -3131,13 +3142,17 @@ class ParallelAssignmentNode(AssignmentNode):
for stat in self.stats: for stat in self.stats:
stat.generate_assignment_code(code) stat.generate_assignment_code(code)
def generate_function_definitions(self, env, code):
for stat in self.stats:
stat.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
for stat in self.stats: for stat in self.stats:
stat.annotate(code) stat.annotate(code)
class InPlaceAssignmentNode(AssignmentNode): class InPlaceAssignmentNode(AssignmentNode):
# An in place arithmatic operand: # An in place arithmetic operand:
# #
# a += b # a += b
# a -= b # a -= b
...@@ -3327,6 +3342,10 @@ class PrintStatNode(StatNode): ...@@ -3327,6 +3342,10 @@ class PrintStatNode(StatNode):
self.arg_tuple.generate_disposal_code(code) self.arg_tuple.generate_disposal_code(code)
self.arg_tuple.free_temps(code) self.arg_tuple.free_temps(code)
def generate_function_definitions(self, env, code):
for item in self.arg_tuple:
item.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.arg_tuple.annotate(code) self.arg_tuple.annotate(code)
...@@ -3511,6 +3530,10 @@ class ReturnStatNode(StatNode): ...@@ -3511,6 +3530,10 @@ class ReturnStatNode(StatNode):
code.put_decref_clear(cname, type) code.put_decref_clear(cname, type)
code.put_goto(code.return_label) code.put_goto(code.return_label)
def generate_function_definitions(self, env, code):
if self.value is not None:
self.value.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
if self.value: if self.value:
self.value.annotate(code) self.value.annotate(code)
...@@ -3568,6 +3591,14 @@ class RaiseStatNode(StatNode): ...@@ -3568,6 +3591,14 @@ class RaiseStatNode(StatNode):
code.putln( code.putln(
code.error_goto(self.pos)) code.error_goto(self.pos))
def generate_function_definitions(self, env, code):
if self.exc_type is not None:
self.exc_type.generate_function_definitions(env, code)
if self.exc_value is not None:
self.exc_value.generate_function_definitions(env, code)
if self.exc_tb is not None:
self.exc_tb.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
if self.exc_type: if self.exc_type:
self.exc_type.annotate(code) self.exc_type.annotate(code)
...@@ -3642,6 +3673,11 @@ class AssertStatNode(StatNode): ...@@ -3642,6 +3673,11 @@ class AssertStatNode(StatNode):
self.cond.free_temps(code) self.cond.free_temps(code)
code.putln("#endif") code.putln("#endif")
def generate_function_definitions(self, env, code):
self.cond.generate_function_definitions(env, code)
if self.value is not None:
self.value.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.cond.annotate(code) self.cond.annotate(code)
if self.value: if self.value:
...@@ -3688,6 +3724,12 @@ class IfStatNode(StatNode): ...@@ -3688,6 +3724,12 @@ class IfStatNode(StatNode):
code.putln("}") code.putln("}")
code.put_label(end_label) code.put_label(end_label)
def generate_function_definitions(self, env, code):
for clause in self.if_clauses:
clause.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
for if_clause in self.if_clauses: for if_clause in self.if_clauses:
if_clause.annotate(code) if_clause.annotate(code)
...@@ -3729,6 +3771,10 @@ class IfClauseNode(Node): ...@@ -3729,6 +3771,10 @@ class IfClauseNode(Node):
code.put_goto(end_label) code.put_goto(end_label)
code.putln("}") code.putln("}")
def generate_function_definitions(self, env, code):
self.condition.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.condition.annotate(code) self.condition.annotate(code)
self.body.annotate(code) self.body.annotate(code)
...@@ -3750,6 +3796,11 @@ class SwitchCaseNode(StatNode): ...@@ -3750,6 +3796,11 @@ class SwitchCaseNode(StatNode):
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.putln("break;") code.putln("break;")
def generate_function_definitions(self, env, code):
for cond in self.conditions:
cond.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
for cond in self.conditions: for cond in self.conditions:
cond.annotate(code) cond.annotate(code)
...@@ -3774,6 +3825,13 @@ class SwitchStatNode(StatNode): ...@@ -3774,6 +3825,13 @@ class SwitchStatNode(StatNode):
code.putln("break;") code.putln("break;")
code.putln("}") code.putln("}")
def generate_function_definitions(self, env, code):
self.test.generate_function_definitions(env, code)
for case in self.cases:
case.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.test.annotate(code) self.test.annotate(code)
for case in self.cases: for case in self.cases:
...@@ -3834,6 +3892,12 @@ class WhileStatNode(LoopNode, StatNode): ...@@ -3834,6 +3892,12 @@ class WhileStatNode(LoopNode, StatNode):
code.putln("}") code.putln("}")
code.put_label(break_label) code.put_label(break_label)
def generate_function_definitions(self, env, code):
self.condition.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.condition.annotate(code) self.condition.annotate(code)
self.body.annotate(code) self.body.annotate(code)
...@@ -3898,6 +3962,13 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -3898,6 +3962,13 @@ class ForInStatNode(LoopNode, StatNode):
self.iterator.generate_disposal_code(code) self.iterator.generate_disposal_code(code)
self.iterator.free_temps(code) self.iterator.free_temps(code)
def generate_function_definitions(self, env, code):
self.target.generate_function_definitions(env, code)
self.iterator.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.target.annotate(code) self.target.annotate(code)
self.iterator.annotate(code) self.iterator.annotate(code)
...@@ -4088,12 +4159,22 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -4088,12 +4159,22 @@ class ForFromStatNode(LoopNode, StatNode):
'>' : ("-1", "--") '>' : ("-1", "--")
} }
def generate_function_definitions(self, env, code):
self.target.generate_function_definitions(env, code)
self.bound1.generate_function_definitions(env, code)
self.bound2.generate_function_definitions(env, code)
if self.step is not None:
self.step.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.target.annotate(code) self.target.annotate(code)
self.bound1.annotate(code) self.bound1.annotate(code)
self.bound2.annotate(code) self.bound2.annotate(code)
if self.step: if self.step:
self.bound2.annotate(code) self.step.annotate(code)
self.body.annotate(code) self.body.annotate(code)
if self.else_clause: if self.else_clause:
self.else_clause.annotate(code) self.else_clause.annotate(code)
...@@ -4248,6 +4329,13 @@ class TryExceptStatNode(StatNode): ...@@ -4248,6 +4329,13 @@ class TryExceptStatNode(StatNode):
code.continue_label = old_continue_label code.continue_label = old_continue_label
code.error_label = old_error_label code.error_label = old_error_label
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
for except_clause in self.except_clauses:
except_clause.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.body.annotate(code) self.body.annotate(code)
for except_node in self.except_clauses: for except_node in self.except_clauses:
...@@ -4386,6 +4474,11 @@ class ExceptClauseNode(Node): ...@@ -4386,6 +4474,11 @@ class ExceptClauseNode(Node):
code.putln( code.putln(
"}") "}")
def generate_function_definitions(self, env, code):
if self.target is not None:
self.target.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
if self.pattern: if self.pattern:
self.pattern.annotate(code) self.pattern.annotate(code)
...@@ -4533,6 +4626,10 @@ class TryFinallyStatNode(StatNode): ...@@ -4533,6 +4626,10 @@ class TryFinallyStatNode(StatNode):
code.putln( code.putln(
"}") "}")
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
self.finally_clause.generate_function_definitions(env, code)
def put_error_catcher(self, code, error_label, i, catch_label, temps_to_clean_up): def put_error_catcher(self, code, error_label, i, catch_label, temps_to_clean_up):
code.globalstate.use_utility_code(restore_exception_utility_code) code.globalstate.use_utility_code(restore_exception_utility_code)
code.putln( code.putln(
......
...@@ -915,7 +915,6 @@ class CreateClosureClasses(CythonTransform): ...@@ -915,7 +915,6 @@ class CreateClosureClasses(CythonTransform):
return node return node
def create_class_from_scope(self, node, target_module_scope): def create_class_from_scope(self, node, target_module_scope):
as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname) as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname)
func_scope = node.local_scope func_scope = node.local_scope
...@@ -931,7 +930,8 @@ class CreateClosureClasses(CythonTransform): ...@@ -931,7 +930,8 @@ class CreateClosureClasses(CythonTransform):
type=node.entry.scope.scope_class.type, type=node.entry.scope.scope_class.type,
is_cdef=True) is_cdef=True)
for entry in func_scope.entries.values(): for entry in func_scope.entries.values():
# This is wasteful--we should do this later when we know which vars are actually being used inside... # This is wasteful--we should do this later when we know
# which vars are actually being used inside...
cname = entry.cname cname = entry.cname
class_scope.declare_var(pos=entry.pos, class_scope.declare_var(pos=entry.pos,
name=entry.name, name=entry.name,
......
...@@ -1658,8 +1658,8 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1658,8 +1658,8 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def': elif s.sy == 'def':
if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'): #if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'):
s.error('def statement not allowed here') # s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s, decorators) return p_def_statement(s, decorators)
elif s.sy == 'class': elif s.sy == 'class':
......
...@@ -249,7 +249,7 @@ class VisitorTransform(TreeVisitor): ...@@ -249,7 +249,7 @@ class VisitorTransform(TreeVisitor):
class CythonTransform(VisitorTransform): class CythonTransform(VisitorTransform):
""" """
Certain common conventions and utilitues for Cython transforms. Certain common conventions and utilities for Cython transforms.
- Sets up the context of the pipeline in self.context - Sets up the context of the pipeline in self.context
- Tracks directives in effect in self.current_directives - Tracks directives in effect in self.current_directives
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment