Commit 9acf70e2 authored by Stefan Behnel's avatar Stefan Behnel

deep-copy finally clauses of try-finally statements earlier to properly...

deep-copy finally clauses of try-finally statements earlier to properly support arbitrary statements in them (genexprs, lambdas, etc.)
parent 84ba7336
...@@ -8611,6 +8611,7 @@ class LambdaNode(InnerFunctionNode): ...@@ -8611,6 +8611,7 @@ class LambdaNode(InnerFunctionNode):
name = StringEncoding.EncodedString('<lambda>') name = StringEncoding.EncodedString('<lambda>')
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.lambda_name = self.def_node.lambda_name = env.next_id('lambda')
self.def_node.no_assignment_synthesis = True self.def_node.no_assignment_synthesis = True
self.def_node.pymethdef_required = True self.def_node.pymethdef_required = True
self.def_node.analyse_declarations(env) self.def_node.analyse_declarations(env)
...@@ -8639,6 +8640,7 @@ class GeneratorExpressionNode(LambdaNode): ...@@ -8639,6 +8640,7 @@ class GeneratorExpressionNode(LambdaNode):
binding = False binding = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.genexpr_name = env.next_id('genexpr')
super(GeneratorExpressionNode, self).analyse_declarations(env) super(GeneratorExpressionNode, self).analyse_declarations(env)
# No pymethdef required # No pymethdef required
self.def_node.pymethdef_required = False self.def_node.pymethdef_required = False
......
...@@ -89,6 +89,11 @@ cdef class Uninitialized: ...@@ -89,6 +89,11 @@ cdef class Uninitialized:
cdef class Unknown: cdef class Unknown:
pass pass
cdef class MessageCollection:
cdef set messages
@cython.locals(dirty=bint, block=ControlBlock, parent=ControlBlock, @cython.locals(dirty=bint, block=ControlBlock, parent=ControlBlock,
assmt=NameAssignment) assmt=NameAssignment)
cdef check_definitions(ControlFlow flow, dict compiler_directives) cdef check_definitions(ControlFlow flow, dict compiler_directives)
......
...@@ -3,10 +3,12 @@ from __future__ import absolute_import ...@@ -3,10 +3,12 @@ from __future__ import absolute_import
import cython import cython
cython.declare(PyrexTypes=object, ExprNodes=object, Nodes=object, cython.declare(PyrexTypes=object, ExprNodes=object, Nodes=object,
Builtin=object, InternalError=object, Builtin=object, InternalError=object,
error=object, warning=object, error=object, warning=object, deepcopy=object,
py_object_type=object, unspecified_type=object, py_object_type=object, unspecified_type=object,
object_expr=object, fake_rhs_expr=object, TypedExprNode=object) object_expr=object, fake_rhs_expr=object, TypedExprNode=object)
from copy import deepcopy
from . import Builtin from . import Builtin
from . import ExprNodes from . import ExprNodes
from . import Nodes from . import Nodes
...@@ -326,6 +328,17 @@ class NameAssignment(object): ...@@ -326,6 +328,17 @@ class NameAssignment(object):
self.is_deletion = False self.is_deletion = False
self.inferred_type = None self.inferred_type = None
def __deepcopy__(self, memo):
ass = NameAssignment.__new__(type(self))
ass.lhs = deepcopy(self.lhs, memo)
ass.rhs = deepcopy(self.rhs, memo)
ass.entry = self.entry
ass.refs = deepcopy(self.refs, memo)
ass.is_arg = self.is_arg
ass.is_deletion = self.is_deletion
ass.inferred_type = self.inferred_type
return ass
def __repr__(self): def __repr__(self):
return '%s(entry=%r)' % (self.__class__.__name__, self.entry) return '%s(entry=%r)' % (self.__class__.__name__, self.entry)
...@@ -400,6 +413,12 @@ class NameReference(object): ...@@ -400,6 +413,12 @@ class NameReference(object):
def __repr__(self): def __repr__(self):
return '%s(entry=%r)' % (self.__class__.__name__, self.entry) return '%s(entry=%r)' % (self.__class__.__name__, self.entry)
def __deepcopy__(self, memo):
ref = NameReference.__new__(type(self))
ref.node = deepcopy(self.node, memo)
ref.entry = self.entry
ref.pos = self.node.pos
class ControlFlowState(list): class ControlFlowState(list):
# Keeps track of Node's entry assignments # Keeps track of Node's entry assignments
...@@ -487,10 +506,11 @@ class GV(object): ...@@ -487,10 +506,11 @@ class GV(object):
if annotate_defs: if annotate_defs:
for stat in block.stats: for stat in block.stats:
if isinstance(stat, NameAssignment): if isinstance(stat, NameAssignment):
label += '\n %s [definition]' % stat.entry.name label += '\n %s [%s %s]' % (
stat.entry.name, 'deletion' if stat.is_deletion else 'definition', stat.pos[1])
elif isinstance(stat, NameReference): elif isinstance(stat, NameReference):
if stat.entry: if stat.entry:
label += '\n %s [reference]' % stat.entry.name label += '\n %s [reference %s]' % (stat.entry.name, stat.pos[1])
if not label: if not label:
label = 'empty' label = 'empty'
pid = ctx.nodeid(block) pid = ctx.nodeid(block)
...@@ -505,17 +525,16 @@ class GV(object): ...@@ -505,17 +525,16 @@ class GV(object):
class MessageCollection(object): class MessageCollection(object):
"""Collect error/warnings messages first then sort""" """Collect error/warnings messages first then sort"""
def __init__(self): def __init__(self):
self.messages = [] self.messages = set()
def error(self, pos, message): def error(self, pos, message):
self.messages.append((pos, True, message)) self.messages.add((pos, True, message))
def warning(self, pos, message): def warning(self, pos, message):
self.messages.append((pos, False, message)) self.messages.add((pos, False, message))
def report(self): def report(self):
self.messages.sort() for pos, is_error, message in sorted(self.messages):
for pos, is_error, message in self.messages:
if is_error: if is_error:
error(pos, message) error(pos, message)
else: else:
...@@ -589,8 +608,8 @@ def check_definitions(flow, compiler_directives): ...@@ -589,8 +608,8 @@ def check_definitions(flow, compiler_directives):
if not entry.from_closure and len(node.cf_state) == 1: if not entry.from_closure and len(node.cf_state) == 1:
node.cf_is_null = True node.cf_is_null = True
if (node.allow_null or entry.from_closure if (node.allow_null or entry.from_closure
or entry.is_pyclass_attr or entry.type.is_error): or entry.is_pyclass_attr or entry.type.is_error):
pass # Can be uninitialized here pass # Can be uninitialized here
elif node.cf_is_null: elif node.cf_is_null:
if entry.error_on_uninitialized or ( if entry.error_on_uninitialized or (
Options.error_on_uninitialized and ( Options.error_on_uninitialized and (
...@@ -844,8 +863,8 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -844,8 +863,8 @@ class ControlFlowAnalysis(CythonTransform):
error(arg.pos, error(arg.pos,
"can not delete variable '%s' " "can not delete variable '%s' "
"referenced in nested scope" % entry.name) "referenced in nested scope" % entry.name)
# Mark reference if not node.ignore_nonexisting:
self._visit(arg) self._visit(arg) # mark reference
self.flow.mark_deletion(arg, entry) self.flow.mark_deletion(arg, entry)
else: else:
self._visit(arg) self._visit(arg)
...@@ -1177,7 +1196,7 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -1177,7 +1196,7 @@ class ControlFlowAnalysis(CythonTransform):
# Exception entry point # Exception entry point
entry_point = self.flow.newblock() entry_point = self.flow.newblock()
self.flow.block = entry_point self.flow.block = entry_point
self._visit(node.finally_clause) self._visit(node.finally_except_clause)
if self.flow.block and self.flow.exceptions: if self.flow.block and self.flow.exceptions:
self.flow.block.add_child(self.flow.exceptions[-1].entry_point) self.flow.block.add_child(self.flow.exceptions[-1].entry_point)
......
...@@ -405,10 +405,10 @@ class StatListNode(Node): ...@@ -405,10 +405,10 @@ class StatListNode(Node):
child_attrs = ["stats"] child_attrs = ["stats"]
@staticmethod
def create_analysed(pos, env, *args, **kw): def create_analysed(pos, env, *args, **kw):
node = StatListNode(pos, *args, **kw) node = StatListNode(pos, *args, **kw)
return node # No node-specific analysis necesarry return node # No node-specific analysis needed
create_analysed = staticmethod(create_analysed)
def analyse_declarations(self, env): def analyse_declarations(self, env):
#print "StatListNode.analyse_declarations" ### #print "StatListNode.analyse_declarations" ###
...@@ -6896,6 +6896,7 @@ class TryFinallyStatNode(StatNode): ...@@ -6896,6 +6896,7 @@ class TryFinallyStatNode(StatNode):
# #
# body StatNode # body StatNode
# finally_clause StatNode # finally_clause StatNode
# finally_except_clause deep-copy of finally_clause for exception case
# #
# The plan is that we funnel all continue, break # The plan is that we funnel all continue, break
# return and error gotos into the beginning of the # return and error gotos into the beginning of the
...@@ -6906,13 +6907,14 @@ class TryFinallyStatNode(StatNode): ...@@ -6906,13 +6907,14 @@ class TryFinallyStatNode(StatNode):
# exception on entry to the finally block and restore # exception on entry to the finally block and restore
# it on exit. # it on exit.
child_attrs = ["body", "finally_clause"] child_attrs = ["body", "finally_clause", "finally_except_clause"]
preserve_exception = 1 preserve_exception = 1
# handle exception case, in addition to return/break/continue # handle exception case, in addition to return/break/continue
handle_error_case = True handle_error_case = True
func_return_type = None func_return_type = None
finally_except_clause = None
disallow_continue_in_try_finally = 0 disallow_continue_in_try_finally = 0
# There doesn't seem to be any point in disallowing # There doesn't seem to be any point in disallowing
...@@ -6921,18 +6923,21 @@ class TryFinallyStatNode(StatNode): ...@@ -6921,18 +6923,21 @@ class TryFinallyStatNode(StatNode):
is_try_finally_in_nogil = False is_try_finally_in_nogil = False
@staticmethod
def create_analysed(pos, env, body, finally_clause): def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause) node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
return node return node
create_analysed = staticmethod(create_analysed)
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
self.finally_except_clause = copy.deepcopy(self.finally_clause)
self.finally_except_clause.analyse_declarations(env)
self.finally_clause.analyse_declarations(env) self.finally_clause.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body = self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
self.finally_clause = self.finally_clause.analyse_expressions(env) self.finally_clause = self.finally_clause.analyse_expressions(env)
self.finally_except_clause = self.finally_except_clause.analyse_expressions(env)
if env.return_type and not env.return_type.is_void: if env.return_type and not env.return_type.is_void:
self.func_return_type = env.return_type self.func_return_type = env.return_type
return self return self
...@@ -7012,7 +7017,7 @@ class TryFinallyStatNode(StatNode): ...@@ -7012,7 +7017,7 @@ class TryFinallyStatNode(StatNode):
code.putln('{') code.putln('{')
old_exc_vars = code.funcstate.exc_vars old_exc_vars = code.funcstate.exc_vars
code.funcstate.exc_vars = exc_vars[:3] code.funcstate.exc_vars = exc_vars[:3]
fresh_finally_clause().generate_execution_code(code) self.finally_except_clause.generate_execution_code(code)
code.funcstate.exc_vars = old_exc_vars code.funcstate.exc_vars = old_exc_vars
code.putln('}') code.putln('}')
......
...@@ -188,16 +188,8 @@ class PostParse(ScopeTrackingTransform): ...@@ -188,16 +188,8 @@ class PostParse(ScopeTrackingTransform):
'__cythonbufferdefaults__' : self.handle_bufferdefaults '__cythonbufferdefaults__' : self.handle_bufferdefaults
} }
def visit_ModuleNode(self, node):
self.lambda_counter = 1
self.genexpr_counter = 1
return super(PostParse, self).visit_ModuleNode(node)
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
# unpack a lambda expression into the corresponding DefNode # unpack a lambda expression into the corresponding DefNode
lambda_id = self.lambda_counter
self.lambda_counter += 1
node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
collector = YieldNodeCollector() collector = YieldNodeCollector()
collector.visitchildren(node.result_expr) collector.visitchildren(node.result_expr)
if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode): if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode):
...@@ -207,7 +199,7 @@ class PostParse(ScopeTrackingTransform): ...@@ -207,7 +199,7 @@ class PostParse(ScopeTrackingTransform):
body = Nodes.ReturnStatNode( body = Nodes.ReturnStatNode(
node.result_expr.pos, value=node.result_expr) node.result_expr.pos, value=node.result_expr)
node.def_node = Nodes.DefNode( node.def_node = Nodes.DefNode(
node.pos, name=node.name, lambda_name=node.lambda_name, node.pos, name=node.name,
args=node.args, star_arg=node.star_arg, args=node.args, star_arg=node.star_arg,
starstar_arg=node.starstar_arg, starstar_arg=node.starstar_arg,
body=body, doc=None) body=body, doc=None)
...@@ -216,10 +208,6 @@ class PostParse(ScopeTrackingTransform): ...@@ -216,10 +208,6 @@ class PostParse(ScopeTrackingTransform):
def visit_GeneratorExpressionNode(self, node): def visit_GeneratorExpressionNode(self, node):
# unpack a generator expression into the corresponding DefNode # unpack a generator expression into the corresponding DefNode
genexpr_id = self.genexpr_counter
self.genexpr_counter += 1
node.genexpr_name = EncodedString(u'genexpr%d' % genexpr_id)
node.def_node = Nodes.DefNode(node.pos, name=node.name, node.def_node = Nodes.DefNode(node.pos, name=node.name,
doc=None, doc=None,
args=[], star_arg=None, args=[], star_arg=None,
...@@ -1577,7 +1565,8 @@ if VALUE is not None: ...@@ -1577,7 +1565,8 @@ if VALUE is not None:
node.body = Nodes.NogilTryFinallyStatNode( node.body = Nodes.NogilTryFinallyStatNode(
node.body.pos, node.body.pos,
body=node.body, body=node.body,
finally_clause=Nodes.EnsureGILNode(node.body.pos)) finally_clause=Nodes.EnsureGILNode(node.body.pos),
finally_except_clause=Nodes.EnsureGILNode(node.body.pos))
def _handle_fused(self, node): def _handle_fused(self, node):
if node.is_generator and node.has_fused_arguments: if node.is_generator and node.has_fused_arguments:
......
...@@ -217,6 +217,12 @@ class Entry(object): ...@@ -217,6 +217,12 @@ class Entry(object):
def all_entries(self): def all_entries(self):
return [self] + self.inner_entries return [self] + self.inner_entries
def __deepcopy__(self, memo):
return self
def __copy__(self):
return self
class InnerEntry(Entry): class InnerEntry(Entry):
""" """
......
# mode: run # mode: run
# tag: tryfinally # tag: tryfinally
import string
import sys import sys
IS_PY3 = sys.version_info[0] >= 3 IS_PY3 = sys.version_info[0] >= 3
...@@ -480,3 +481,59 @@ def finally_yield(x): ...@@ -480,3 +481,59 @@ def finally_yield(x):
return return
finally: finally:
yield 1 yield 1
def complex_finally_clause(x, obj):
"""
>>> class T(object):
... def method(self, value):
... print(value)
>>> complex_finally_clause('finish', T())
module.py
module.py
module.py
99
>>> complex_finally_clause('tryreturn', T())
module.py
module.py
module.py
2
>>> complex_finally_clause('trybreak', T())
module.py
module.py
module.py
99
>>> complex_finally_clause('tryraise', T())
Traceback (most recent call last):
TypeError
"""
name = 'module'
l = []
cdef object lobj = l
for i in range(3):
l[:] = [1, 2, 3]
try:
if i == 0:
pass
elif i == 1:
continue
elif x == 'trybreak':
break
elif x == 'tryraise':
raise TypeError()
elif x == 'tryreturn':
return 2
else:
pass
finally:
obj.method(name + '.py')
from contextlib import contextmanager
with contextmanager(lambda: (yield 1))() as y:
assert y == 1
assert name[0] in string.ascii_letters
string.Template("-- huhu $name --").substitute(**{'name': '(%s)' % name})
del l[0], lobj[0]
assert all(i == 3 for i in l), l
return 99
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment