Commit a67aaf75 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Introduce TempsBlockNode utility, improve TreeFragment-generated temps

parent a5e1aea2
from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc from Cython.Compiler.Visitor import TreeVisitor
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
...@@ -37,6 +37,7 @@ class CodeWriter(TreeVisitor): ...@@ -37,6 +37,7 @@ class CodeWriter(TreeVisitor):
self.result = result self.result = result
self.numindents = 0 self.numindents = 0
self.tempnames = {} self.tempnames = {}
self.tempblockindex = 0
def write(self, tree): def write(self, tree):
self.visit(tree) self.visit(tree)
...@@ -60,12 +61,6 @@ class CodeWriter(TreeVisitor): ...@@ -60,12 +61,6 @@ class CodeWriter(TreeVisitor):
self.startline(s) self.startline(s)
self.endline() self.endline()
def putname(self, name):
tmpdesc = get_temp_name_handle_desc(name)
if tmpdesc is not None:
name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False): def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0: if len(items) > 0:
for item in items[:-1]: for item in items[:-1]:
...@@ -132,7 +127,7 @@ class CodeWriter(TreeVisitor): ...@@ -132,7 +127,7 @@ class CodeWriter(TreeVisitor):
self.endline() self.endline()
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.putname(node.name) self.put(node.name)
def visit_IntNode(self, node): def visit_IntNode(self, node):
self.put(node.value) self.put(node.value)
...@@ -312,3 +307,18 @@ class CodeWriter(TreeVisitor): ...@@ -312,3 +307,18 @@ class CodeWriter(TreeVisitor):
self.visit(node.operand) self.visit(node.operand)
self.put(u")") self.put(u")")
def visit_TempsBlockNode(self, node):
"""
Temporaries are output like $1_1', where the first number is
an index of the TempsBlockNode and the second number is an index
of the temporary which that block allocates.
"""
idx = 0
for handle in node.handles:
self.tempnames[handle] = "$%d_%d" % (self.tempblockindex, idx)
idx += 1
self.tempblockindex += 1
self.visit(node.body)
def visit_TempRefNode(self, node):
self.put(self.tempnames[node.handle])
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Compiler.StringEncoding import EncodedString from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError from Cython.Compiler.Errors import CompileError
import Interpreter import Interpreter
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
......
...@@ -4188,6 +4188,7 @@ class FromImportStatNode(StatNode): ...@@ -4188,6 +4188,7 @@ class FromImportStatNode(StatNode):
self.module.generate_disposal_code(code) self.module.generate_disposal_code(code)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.UtilNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Compiler.StringEncoding import EncodedString from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError from Cython.Compiler.Errors import CompileError
...@@ -409,7 +410,7 @@ class WithTransform(CythonTransform): ...@@ -409,7 +410,7 @@ class WithTransform(CythonTransform):
finally: finally:
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"], """, temps=[u'MGR', u'EXC', u"EXIT"],
pipeline=[NormalizeTree(None)]) pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u""" template_with_target = TreeFragment(u"""
...@@ -428,32 +429,33 @@ class WithTransform(CythonTransform): ...@@ -428,32 +429,33 @@ class WithTransform(CythonTransform):
finally: finally:
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"], """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
pipeline=[NormalizeTree(None)]) pipeline=[NormalizeTree(None)])
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO') excinfo_tempblock = TempsBlockNode(node.pos, [PyrexTypes.py_object_type], None)
excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
if node.target is not None: if node.target is not None:
result = self.template_with_target.substitute({ result = self.template_with_target.substitute({
u'EXPR' : node.manager, u'EXPR' : node.manager,
u'BODY' : node.body, u'BODY' : node.body,
u'TARGET' : node.target, u'TARGET' : node.target,
u'EXCINFO' : excinfo_namenode u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
}, pos=node.pos) }, pos=node.pos)
# Set except excinfo target to EXCINFO # Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
excinfo_tempblock.get_ref_node(0, node.pos))
else: else:
result = self.template_without_target.substitute({ result = self.template_without_target.substitute({
u'EXPR' : node.manager, u'EXPR' : node.manager,
u'BODY' : node.body, u'BODY' : node.body,
u'EXCINFO' : excinfo_namenode u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
}, pos=node.pos) }, pos=node.pos)
# Set except excinfo target to EXCINFO # Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
excinfo_tempblock.get_ref_node(0, node.pos))
return result.stats excinfo_tempblock.body = result
return excinfo_tempblock
class DecoratorTransform(CythonTransform): class DecoratorTransform(CythonTransform):
......
...@@ -95,20 +95,20 @@ class TestWithTransform(TransformTest): ...@@ -95,20 +95,20 @@ class TestWithTransform(TransformTest):
self.assertCode(u""" self.assertCode(u"""
$MGR = x $1_0 = x
$EXIT = $MGR.__exit__ $1_2 = $1_0.__exit__
$MGR.__enter__() $1_0.__enter__()
$EXC = True $1_1 = True
try: try:
try: try:
y = z ** 3 y = z ** 3
except: except:
$EXC = False $1_1 = False
if (not $EXIT($EXCINFO)): if (not $1_2($0_0)):
raise raise
finally: finally:
if $EXC: if $1_1:
$EXIT(None, None, None) $1_2(None, None, None)
""", t) """, t)
...@@ -119,21 +119,21 @@ class TestWithTransform(TransformTest): ...@@ -119,21 +119,21 @@ class TestWithTransform(TransformTest):
""") """)
self.assertCode(u""" self.assertCode(u"""
$MGR = x $1_0 = x
$EXIT = $MGR.__exit__ $1_2 = $1_0.__exit__
$VALUE = $MGR.__enter__() $1_3 = $1_0.__enter__()
$EXC = True $1_1 = True
try: try:
try: try:
y = $VALUE y = $1_3
y = z ** 3 y = z ** 3
except: except:
$EXC = False $1_1 = False
if (not $EXIT($EXCINFO)): if (not $1_2($0_0)):
raise raise
finally: finally:
if $EXC: if $1_1:
$EXIT(None, None, None) $1_2(None, None, None)
""", t) """, t)
......
from Cython.TestUtils import CythonTest from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import * from Cython.Compiler.TreeFragment import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.UtilNodes import *
import Cython.Compiler.Naming as Naming import Cython.Compiler.Naming as Naming
class TestTreeFragments(CythonTest): class TestTreeFragments(CythonTest):
...@@ -54,10 +55,10 @@ class TestTreeFragments(CythonTest): ...@@ -54,10 +55,10 @@ class TestTreeFragments(CythonTest):
x = TMP x = TMP
""") """)
T = F.substitute(temps=[u"TMP"]) T = F.substitute(temps=[u"TMP"])
s = T.stats s = T.body.stats
self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name) self.assert_(isinstance(s[0].expr, TempRefNode))
self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP") self.assert_(isinstance(s[1].rhs, TempRefNode))
self.assert_(s[0].expr.name != u"TMP") self.assert_(s[0].expr.handle is s[1].rhs.handle)
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
...@@ -8,11 +8,12 @@ from Scanning import PyrexScanner, StringSourceDescriptor ...@@ -8,11 +8,12 @@ from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
import Symtab import Symtab
import PyrexTypes import PyrexTypes
from Visitor import VisitorTransform, temp_name_handle from Visitor import VisitorTransform
from Nodes import Node, StatListNode from Nodes import Node, StatListNode
from ExprNodes import NameNode from ExprNodes import NameNode
import Parsing import Parsing
import Main import Main
import UtilNodes
""" """
Support for parsing strings into code trees. Support for parsing strings into code trees.
...@@ -111,11 +112,17 @@ class TemplateTransform(VisitorTransform): ...@@ -111,11 +112,17 @@ class TemplateTransform(VisitorTransform):
def __call__(self, node, substitutions, temps, pos): def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
self.temp_key_to_entries = tempdict
self.pos = pos self.pos = pos
self.temps = temps
if len(temps) > 0:
self.tempblock = UtilNodes.TempsBlockNode(self.get_pos(node),
[PyrexTypes.py_object_type for x in temps],
body=None)
self.tempblock.body = super(TemplateTransform, self).__call__(node)
return self.tempblock
else:
return super(TemplateTransform, self).__call__(node) return super(TemplateTransform, self).__call__(node)
def get_pos(self, node): def get_pos(self, node):
...@@ -145,13 +152,13 @@ class TemplateTransform(VisitorTransform): ...@@ -145,13 +152,13 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node): def visit_NameNode(self, node):
tempentry = self.temp_key_to_entries.get(node.name) try:
if tempentry is not None: tmpidx = self.temps.index(node.name)
# Replace name with temporary except:
return NameNode(self.get_pos(node), name=tempentry)
# Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
else:
return self.try_substitution(node, node.name) return self.try_substitution(node, node.name)
else:
# Replace name with temporary
return self.tempblock.get_ref_node(tmpidx, self.get_pos(node))
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable # If an expression-as-statement consists of only a replaceable
......
#
# Nodes used as utilities and support for transforms etc.
# These often make up sets including both Nodes and ExprNodes
# so it is convenient to have them in a seperate module.
#
import Nodes
import ExprNodes
from Nodes import Node
from ExprNodes import ExprNode
class TempHandle(object):
temp = None
def __init__(self, type):
self.type = type
class TempRefNode(ExprNode):
# handle TempHandle
subexprs = []
def analyse_types(self, env):
assert self.type == self.handle.type
def analyse_target_types(self, env):
assert self.type == self.handle.type
def analyse_target_declaration(self, env):
pass
def calculate_result_code(self):
result = self.handle.temp
if result is None: result = "<error>" # might be called and overwritten
return result
def generate_result_code(self, code):
pass
def generate_assignment_code(self, rhs, code):
if self.type.is_pyobject:
rhs.make_owned_reference(code)
code.put_xdecref(self.result(), self.ctype())
code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
rhs.generate_post_assignment_code(code)
class TempsBlockNode(Node):
"""
Creates a block which allocates temporary variables.
This is used by transforms to output constructs that need
to make use of a temporary variable. Simply pass the types
of the needed temporaries to the constructor.
The variables can be referred to using a TempRefNode
(which can be constructed by calling get_ref_node).
"""
child_attrs = ["body"]
def __init__(self, pos, types, body):
self.handles = [TempHandle(t) for t in types]
Node.__init__(self, pos, body=body)
def get_ref_node(self, index, pos):
handle = self.handles[index]
return TempRefNode(pos, handle=handle, type=handle.type)
def append_temp(self, type, pos):
"""
Appends a new temporary which this block manages, and returns
its index.
"""
self.handle.append(TempHandle(type))
return len(self.handle) - 1
def generate_execution_code(self, code):
for handle in self.handles:
handle.temp = code.funcstate.allocate_temp(handle.type)
self.body.generate_execution_code(code)
for handle in self.handles:
code.funcstate.release_temp(handle.temp)
def analyse_control_flow(self, env):
self.body.analyse_control_flow(env)
def analyse_declarations(self, env):
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
def annotate(self, code):
self.body.annotate(code)
...@@ -199,23 +199,6 @@ def replace_node(ptr, value): ...@@ -199,23 +199,6 @@ def replace_node(ptr, value):
else: else:
getattr(parent, attrname)[listidx] = value getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description=None):
global tmpnamectr
tmpnamectr += 1
if description is not None:
name = u"%d_%s" % (tmpnamectr, description)
else:
name = u"%d" % tmpnamectr
return EncodedString(Naming.temp_prefix + name)
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
return None
else:
idx = handle.find(u"_", 6)
return handle[idx+1:]
class PrintTree(TreeVisitor): class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output. """Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information Subclass and override repr_of to provide more information
......
...@@ -47,10 +47,16 @@ class CythonTest(unittest.TestCase): ...@@ -47,10 +47,16 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(expected), len(result), self.assertEqual(len(expected), len(result),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result))) "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
def assertCode(self, expected, result_tree): def codeToLines(self, tree):
writer = CodeWriter() writer = CodeWriter()
writer.write(result_tree) writer.write(tree)
result_lines = writer.result.lines return writer.result.lines
def codeToString(self, tree):
return "\n".join(self.codeToLines(tree))
def assertCode(self, expected, result_tree):
result_lines = self.codeToLines(result_tree)
expected_lines = strip_common_indent(expected.split("\n")) expected_lines = strip_common_indent(expected.split("\n"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment