Commit d2a99890 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Better exception info reading for with statement

parent 21c0a027
from Cython.Compiler.Visitor import TreeVisitor from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.Symtab import TempName
""" """
Serializes a Cython code tree to Cython code. This is primarily useful for Serializes a Cython code tree to Cython code. This is primarily useful for
...@@ -62,8 +61,9 @@ class CodeWriter(TreeVisitor): ...@@ -62,8 +61,9 @@ class CodeWriter(TreeVisitor):
self.endline() self.endline()
def putname(self, name): def putname(self, name):
if isinstance(name, TempName): tmpdesc = get_temp_name_handle_desc(name)
name = self.tempnames.setdefault(name, u"$" + name.description) if tmpdesc is not None:
name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
self.put(name) self.put(name)
def comma_seperated_list(self, items, output_rhs=False): def comma_seperated_list(self, items, output_rhs=False):
......
...@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode): ...@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode):
else: else:
code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name))) code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name)))
class BackquoteNode(ExprNode): class BackquoteNode(ExprNode):
# `expr` # `expr`
# #
...@@ -1212,6 +1211,9 @@ class ExcValueNode(AtomicExprNode): ...@@ -1212,6 +1211,9 @@ class ExcValueNode(AtomicExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
def analyse_types(self, env):
pass
class TempNode(AtomicExprNode): class TempNode(AtomicExprNode):
# Node created during analyse_types phase # Node created during analyse_types phase
......
...@@ -3329,18 +3329,24 @@ class ExceptClauseNode(Node): ...@@ -3329,18 +3329,24 @@ class ExceptClauseNode(Node):
# pattern ExprNode # pattern ExprNode
# target ExprNode or None # target ExprNode or None
# body StatNode # body StatNode
# excinfo_target NameNode or None optional target for exception info
# excinfo_target NameNode or None used internally
# match_flag string result of exception match # match_flag string result of exception match
# exc_value ExcValueNode used internally # exc_value ExcValueNode used internally
# function_name string qualified name of enclosing function # function_name string qualified name of enclosing function
# exc_vars (string * 3) local exception variables # exc_vars (string * 3) local exception variables
child_attrs = ["pattern", "target", "body", "exc_value"] child_attrs = ["pattern", "target", "body", "exc_value", "excinfo_target"]
exc_value = None exc_value = None
excinfo_target = None
excinfo_assignment = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.target: if self.target:
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
if self.excinfo_target is not None:
self.excinfo_target.analyse_target_declaration(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
...@@ -3358,6 +3364,17 @@ class ExceptClauseNode(Node): ...@@ -3358,6 +3364,17 @@ class ExceptClauseNode(Node):
self.exc_value = ExprNodes.ExcValueNode(self.pos, env, self.exc_vars[1]) self.exc_value = ExprNodes.ExcValueNode(self.pos, env, self.exc_vars[1])
self.exc_value.allocate_temps(env) self.exc_value.allocate_temps(env)
self.target.analyse_target_expression(env, self.exc_value) self.target.analyse_target_expression(env, self.exc_value)
if self.excinfo_target is not None:
import ExprNodes
self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[0]),
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[1]),
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[2])
])
self.excinfo_tuple.analyse_expressions(env)
self.excinfo_tuple.allocate_temps(env)
self.excinfo_target.analyse_target_expression(env, self.excinfo_tuple)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
for var in self.exc_vars: for var in self.exc_vars:
env.release_temp(var) env.release_temp(var)
...@@ -3387,6 +3404,10 @@ class ExceptClauseNode(Node): ...@@ -3387,6 +3404,10 @@ class ExceptClauseNode(Node):
if self.target: if self.target:
self.exc_value.generate_evaluation_code(code) self.exc_value.generate_evaluation_code(code)
self.target.generate_assignment_code(self.exc_value, code) self.target.generate_assignment_code(self.exc_value, code)
if self.excinfo_target is not None:
self.excinfo_tuple.generate_evaluation_code(code)
self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code)
old_exc_vars = code.exc_vars old_exc_vars = code.exc_vars
code.exc_vars = self.exc_vars code.exc_vars = self.exc_vars
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
...@@ -4497,6 +4518,7 @@ bad: ...@@ -4497,6 +4518,7 @@ bad:
Py_XDECREF(*tb); Py_XDECREF(*tb);
return -1; return -1;
} }
"""] """]
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
from Cython.Compiler.Visitor import VisitorTransform from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
...@@ -71,10 +72,13 @@ class PostParse(VisitorTransform): ...@@ -71,10 +72,13 @@ class PostParse(VisitorTransform):
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
return self.visit_StatNode(node, True) return self.visit_StatNode(node, True)
class WithTransform(VisitorTransform): class WithTransform(VisitorTransform):
# EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except
# statement.
template_without_target = TreeFragment(u""" template_without_target = TreeFragment(u"""
import sys as SYS
MGR = EXPR MGR = EXPR
EXIT = MGR.__exit__ EXIT = MGR.__exit__
MGR.__enter__() MGR.__enter__()
...@@ -84,15 +88,15 @@ class WithTransform(VisitorTransform): ...@@ -84,15 +88,15 @@ class WithTransform(VisitorTransform):
BODY BODY
except: except:
EXC = False EXC = False
if not EXIT(*SYS.exc_info()): if not EXIT(*EXCINFO):
raise raise
finally: finally:
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", u"WithTransformFragment") """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
pipeline=[PostParse()])
template_with_target = TreeFragment(u""" template_with_target = TreeFragment(u"""
import sys as SYS
MGR = EXPR MGR = EXPR
EXIT = MGR.__exit__ EXIT = MGR.__exit__
VALUE = MGR.__enter__() VALUE = MGR.__enter__()
...@@ -103,47 +107,38 @@ class WithTransform(VisitorTransform): ...@@ -103,47 +107,38 @@ class WithTransform(VisitorTransform):
BODY BODY
except: except:
EXC = False EXC = False
if not EXIT(*SYS.exc_info()): if not EXIT(*EXCINFO):
raise raise
finally: finally:
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", u"WithTransformFragment") """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
pipeline=[PostParse()])
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO')
excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
if node.target is not None: if node.target is not None:
result = self.template_with_target.substitute({ result = self.template_with_target.substitute({
u'EXPR' : node.manager, u'EXPR' : node.manager,
u'BODY' : node.body, u'BODY' : node.body,
u'TARGET' : node.target u'TARGET' : node.target,
}, temps=(u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"), u'EXCINFO' : excinfo_namenode
pos = node.pos) }, pos = node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
else: else:
result = self.template_without_target.substitute({ result = self.template_without_target.substitute({
u'EXPR' : node.manager, u'EXPR' : node.manager,
u'BODY' : node.body, u'BODY' : node.body,
}, temps=(u'MGR', u'EXC', u"EXIT", u"SYS"), u'EXCINFO' : excinfo_namenode
pos = node.pos) }, pos = node.pos)
# Set except excinfo target to EXCINFO
return result.body.stats result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
class CallExitFuncNode(Node):
def analyse_types(self, env):
pass
def analyse_expressions(self, env):
self.exc_vars = [
env.allocate_temp(PyrexTypes.py_object_type)
for x in xrange(3)
]
def generate_result(self, code): return result.stats
code.putln("""{
PyObject* type; PyObject* value; PyObject* tb;
__Pyx_GetException(
}""")
...@@ -16,29 +16,6 @@ from TypeSlots import \ ...@@ -16,29 +16,6 @@ from TypeSlots import \
import ControlFlow import ControlFlow
import __builtin__ import __builtin__
class TempName(object):
"""
Use instances of this class in order to provide a name for
anonymous, temporary functions. Each instance is considered
a seperate name, which are guaranteed not to clash with one
another or with names explicitly given as strings.
The argument to the constructor is simply a describing string
for debugging purposes and does not affect name clashes at all.
NOTE: Support for these TempNames are introduced on an as-needed
basis and will not "just work" everywhere. Places where they work:
- (none)
"""
def __init__(self, description):
self.description = description
# Spoon-feed operators for documentation purposes
def __hash__(self):
return id(self)
def __cmp__(self, other):
return cmp(id(self), id(other))
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
...@@ -1098,20 +1075,13 @@ class ModuleScope(Scope): ...@@ -1098,20 +1075,13 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
tempctr = 0
class LocalScope(Scope): class LocalScope(Scope):
def __init__(self, name, outer_scope): def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, outer_scope) Scope.__init__(self, name, outer_scope, outer_scope)
def mangle(self, prefix, name): def mangle(self, prefix, name):
if isinstance(name, TempName): return prefix + name
global tempctr
tempctr += 1
return u"%s%s%d" % (Naming.temp_prefix, name.description, tempctr)
else:
return prefix + name
def declare_arg(self, name, type, pos): def declare_arg(self, name, type, pos):
# Add an entry for an argument of a function. # Add an entry for an argument of a function.
......
...@@ -6,8 +6,8 @@ class TestPostParse(TransformTest): ...@@ -6,8 +6,8 @@ class TestPostParse(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self): def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root t = self.fragment(u"if x: y").root
self.assertLines(u""" self.assertLines(u"""
(root): ModuleNode (root): StatListNode
body: IfStatNode stats[0]: IfStatNode
if_clauses[0]: IfClauseNode if_clauses[0]: IfClauseNode
condition: NameNode condition: NameNode
body: ExprStatNode body: ExprStatNode
...@@ -17,14 +17,13 @@ class TestPostParse(TransformTest): ...@@ -17,14 +17,13 @@ class TestPostParse(TransformTest):
def test_wrap_singlestat(self): def test_wrap_singlestat(self):
t = self.run_pipeline([PostParse()], u"if x: y") t = self.run_pipeline([PostParse()], u"if x: y")
self.assertLines(u""" self.assertLines(u"""
(root): ModuleNode (root): StatListNode
body: StatListNode stats[0]: IfStatNode
stats[0]: IfStatNode if_clauses[0]: IfClauseNode
if_clauses[0]: IfClauseNode condition: NameNode
condition: NameNode body: StatListNode
body: StatListNode stats[0]: ExprStatNode
stats[0]: ExprStatNode expr: NameNode
expr: NameNode
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_multistat(self): def test_wrap_multistat(self):
...@@ -34,16 +33,15 @@ class TestPostParse(TransformTest): ...@@ -34,16 +33,15 @@ class TestPostParse(TransformTest):
y y
""") """)
self.assertLines(u""" self.assertLines(u"""
(root): ModuleNode (root): StatListNode
body: StatListNode stats[0]: IfStatNode
stats[0]: IfStatNode if_clauses[0]: IfClauseNode
if_clauses[0]: IfClauseNode condition: NameNode
condition: NameNode body: StatListNode
body: StatListNode stats[0]: ExprStatNode
stats[0]: ExprStatNode expr: NameNode
expr: NameNode stats[1]: ExprStatNode
stats[1]: ExprStatNode expr: NameNode
expr: NameNode
""", self.treetypes(t)) """, self.treetypes(t))
def test_statinexpr(self): def test_statinexpr(self):
...@@ -51,15 +49,14 @@ class TestPostParse(TransformTest): ...@@ -51,15 +49,14 @@ class TestPostParse(TransformTest):
a, b = x, y a, b = x, y
""") """)
self.assertLines(u""" self.assertLines(u"""
(root): ModuleNode (root): StatListNode
body: StatListNode stats[0]: ParallelAssignmentNode
stats[0]: ParallelAssignmentNode stats[0]: SingleAssignmentNode
stats[0]: SingleAssignmentNode lhs: NameNode
lhs: NameNode rhs: NameNode
rhs: NameNode stats[1]: SingleAssignmentNode
stats[1]: SingleAssignmentNode lhs: NameNode
lhs: NameNode rhs: NameNode
rhs: NameNode
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_offagain(self): def test_wrap_offagain(self):
...@@ -70,24 +67,23 @@ class TestPostParse(TransformTest): ...@@ -70,24 +67,23 @@ class TestPostParse(TransformTest):
x x
""") """)
self.assertLines(u""" self.assertLines(u"""
(root): ModuleNode (root): StatListNode
body: StatListNode stats[0]: ExprStatNode
stats[0]: ExprStatNode expr: NameNode
expr: NameNode stats[1]: ExprStatNode
stats[1]: ExprStatNode expr: NameNode
expr: NameNode stats[2]: IfStatNode
stats[2]: IfStatNode if_clauses[0]: IfClauseNode
if_clauses[0]: IfClauseNode condition: NameNode
condition: NameNode body: StatListNode
body: StatListNode stats[0]: ExprStatNode
stats[0]: ExprStatNode expr: NameNode
expr: NameNode
""", self.treetypes(t)) """, self.treetypes(t))
def test_pass_eliminated(self): def test_pass_eliminated(self):
t = self.run_pipeline([PostParse()], u"pass") t = self.run_pipeline([PostParse()], u"pass")
self.assert_(len(t.body.stats) == 0) self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest): class TestWithTransform(TransformTest):
...@@ -99,7 +95,6 @@ class TestWithTransform(TransformTest): ...@@ -99,7 +95,6 @@ class TestWithTransform(TransformTest):
self.assertCode(u""" self.assertCode(u"""
$SYS = (import sys)
$MGR = x $MGR = x
$EXIT = $MGR.__exit__ $EXIT = $MGR.__exit__
$MGR.__enter__() $MGR.__enter__()
...@@ -109,7 +104,7 @@ class TestWithTransform(TransformTest): ...@@ -109,7 +104,7 @@ class TestWithTransform(TransformTest):
y = z ** 3 y = z ** 3
except: except:
$EXC = False $EXC = False
if (not $EXIT($SYS.exc_info())): if (not $EXIT($EXCINFO)):
raise raise
finally: finally:
if $EXC: if $EXC:
...@@ -124,7 +119,6 @@ class TestWithTransform(TransformTest): ...@@ -124,7 +119,6 @@ class TestWithTransform(TransformTest):
""") """)
self.assertCode(u""" self.assertCode(u"""
$SYS = (import sys)
$MGR = x $MGR = x
$EXIT = $MGR.__exit__ $EXIT = $MGR.__exit__
$VALUE = $MGR.__enter__() $VALUE = $MGR.__enter__()
...@@ -135,7 +129,7 @@ class TestWithTransform(TransformTest): ...@@ -135,7 +129,7 @@ class TestWithTransform(TransformTest):
y = z ** 3 y = z ** 3
except: except:
$EXC = False $EXC = False
if (not $EXIT($SYS.exc_info())): if (not $EXIT($EXCINFO)):
raise raise
finally: finally:
if $EXC: if $EXC:
......
...@@ -6,9 +6,8 @@ import re ...@@ -6,9 +6,8 @@ import re
from cStringIO import StringIO from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
from Visitor import VisitorTransform from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node from Nodes import Node, StatListNode
from Symtab import TempName
from ExprNodes import NameNode from ExprNodes import NameNode
import Parsing import Parsing
import Main import Main
...@@ -109,7 +108,7 @@ class TemplateTransform(VisitorTransform): ...@@ -109,7 +108,7 @@ class TemplateTransform(VisitorTransform):
self.substitutions = substitutions self.substitutions = substitutions
tempdict = {} tempdict = {}
for key in temps: for key in temps:
tempdict[key] = TempName(key) tempdict[key] = temp_name_handle(key)
self.temps = tempdict self.temps = tempdict
self.pos = pos self.pos = pos
return super(TemplateTransform, self).__call__(node) return super(TemplateTransform, self).__call__(node)
...@@ -164,7 +163,7 @@ def strip_common_indent(lines): ...@@ -164,7 +163,7 @@ def strip_common_indent(lines):
return lines return lines
class TreeFragment(object): class TreeFragment(object):
def __init__(self, code, name, pxds={}): def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[]):
if isinstance(code, unicode): if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
...@@ -173,12 +172,20 @@ class TreeFragment(object): ...@@ -173,12 +172,20 @@ class TreeFragment(object):
for key, value in pxds.iteritems(): for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value) fmt_pxds[key] = fmt(value)
self.root = parse_from_strings(name, fmt_code, fmt_pxds) t = parse_from_strings(name, fmt_code, fmt_pxds)
mod = t
t = t.body # Make sure a StatListNode is at the top
if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline:
t = transform(t)
self.root = t
elif isinstance(code, Node): elif isinstance(code, Node):
if pxds != {}: raise NotImplementedError() if pxds != {}: raise NotImplementedError()
self.root = code self.root = code
else: else:
raise ValueError("Unrecognized code format (accepts unicode and Node)") raise ValueError("Unrecognized code format (accepts unicode and Node)")
self.temps = temps
def copy(self): def copy(self):
return copy_code_tree(self.root) return copy_code_tree(self.root)
...@@ -186,7 +193,7 @@ class TreeFragment(object): ...@@ -186,7 +193,7 @@ class TreeFragment(object):
def substitute(self, nodes={}, temps=[], pos = None): def substitute(self, nodes={}, temps=[], pos = None):
return TemplateTransform()(self.root, return TemplateTransform()(self.root,
substitutions = nodes, substitutions = nodes,
temps = temps, pos = pos) temps = self.temps + temps, pos = pos)
......
...@@ -166,6 +166,19 @@ def replace_node(ptr, value): ...@@ -166,6 +166,19 @@ def replace_node(ptr, value):
else: else:
getattr(parent, attrname)[listidx] = value getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description):
global tmpnamectr
tmpnamectr += 1
return u"__cyt_%d_%s" % (tmpnamectr, description)
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
return None
else:
idx = handle.find(u"_", 6)
return handle[idx+1:]
class PrintTree(TreeVisitor): class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output. """Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information Subclass and override repr_of to provide more information
......
...@@ -77,8 +77,8 @@ class TransformTest(CythonTest): ...@@ -77,8 +77,8 @@ class TransformTest(CythonTest):
To create a test case: To create a test case:
- Call run_pipeline. The pipeline should at least contain the transform you - Call run_pipeline. The pipeline should at least contain the transform you
are testing; pyx should be either a string (passed to the parser to are testing; pyx should be either a string (passed to the parser to
create a post-parse tree) or a ModuleNode representing input to pipeline. create a post-parse tree) or a node representing input to pipeline.
The result will be a transformed result (usually a ModuleNode). The result will be a transformed result.
- Check that the tree is correct. If wanted, assertCode can be used, which - Check that the tree is correct. If wanted, assertCode can be used, which
takes a code string as expected, and a ModuleNode in result_tree takes a code string as expected, and a ModuleNode in result_tree
...@@ -93,7 +93,6 @@ class TransformTest(CythonTest): ...@@ -93,7 +93,6 @@ class TransformTest(CythonTest):
def run_pipeline(self, pipeline, pyx, pxds={}): def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root tree = self.fragment(pyx, pxds).root
assert isinstance(tree, ModuleNode)
# Run pipeline # Run pipeline
for T in pipeline: for T in pipeline:
tree = T(tree) tree = T(tree)
......
from __future__ import with_statement from __future__ import with_statement
__doc__ = u""" __doc__ = u"""
>>> no_as()
enter
hello
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> basic() >>> basic()
enter enter
value value
...@@ -8,12 +12,12 @@ exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'> ...@@ -8,12 +12,12 @@ exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> with_exception(None) >>> with_exception(None)
enter enter
value value
exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'> exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
outer except outer except
>>> with_exception(True) >>> with_exception(True)
enter enter
value value
exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'> exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
>>> multitarget() >>> multitarget()
enter enter
1 2 3 4 5 1 2 3 4 5
...@@ -24,18 +28,25 @@ enter ...@@ -24,18 +28,25 @@ enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'> exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
""" """
class MyException(Exception):
pass
class ContextManager: class ContextManager:
def __init__(self, value, exit_ret = None): def __init__(self, value, exit_ret = None):
self.value = value self.value = value
self.exit_ret = exit_ret self.exit_ret = exit_ret
def __exit__(self, a, b, c): def __exit__(self, a, b, tb):
print "exit", type(a), type(b), type(c) print "exit", type(a), type(b), type(tb)
return self.exit_ret return self.exit_ret
def __enter__(self): def __enter__(self):
print "enter" print "enter"
return self.value return self.value
def no_as():
with ContextManager("value"):
print "hello"
def basic(): def basic():
with ContextManager("value") as x: with ContextManager("value") as x:
...@@ -45,7 +56,7 @@ def with_exception(exit_ret): ...@@ -45,7 +56,7 @@ def with_exception(exit_ret):
try: try:
with ContextManager("value", exit_ret=exit_ret) as value: with ContextManager("value", exit_ret=exit_ret) as value:
print value print value
raise Exception() raise MyException()
except: except:
print "outer except" print "outer except"
...@@ -56,3 +67,4 @@ def multitarget(): ...@@ -56,3 +67,4 @@ def multitarget():
def tupletarget(): def tupletarget():
with ContextManager((1, 2, (3, (4, 5)))) as t: with ContextManager((1, 2, (3, (4, 5)))) as t:
print t print t
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment