TreeFragment.py 9.12 KB
Newer Older
1 2 3 4 5 6 7 8
#
# TreeFragments - parsing of strings to trees
#

"""
Support for parsing strings into code trees.
"""

9 10 11
from __future__ import absolute_import

import re
12
from io import StringIO
13 14 15 16 17 18 19

from .Scanning import PyrexScanner, StringSourceDescriptor
from .Symtab import ModuleScope
from . import PyrexTypes
from .Visitor import VisitorTransform
from .Nodes import Node, StatListNode
from .ExprNodes import NameNode
20
from .StringEncoding import _unicode
21 22 23 24 25
from . import Parsing
from . import Main
from . import UtilNodes


26
class StringParseContext(Main.Context):
    """
    Compilation context for code that is parsed from an in-memory string.

    The context is bound to a single module name.  Looking up that module
    (or the special 'cython' module) returns a newly created, parentless
    ModuleScope; every other lookup is rejected, because cimports and
    includes are not supported for string code snippets.
    """

    def __init__(self, name, include_directories=None, compiler_directives=None):
        """
        name                - module name that the snippet is parsed as
        include_directories - forwarded to Main.Context (defaults to [])
        compiler_directives - forwarded to Main.Context (defaults to {})
        """
        # Fresh containers per instance; never share a mutable default.
        if include_directories is None:
            include_directories = []
        if compiler_directives is None:
            compiler_directives = {}
        Main.Context.__init__(self, include_directories, compiler_directives,
                              create_testscope=False)
        self.module_name = name

    def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1, absolute_fallback=True):
        """
        Return a new ModuleScope for this context's own module or 'cython'.

        The remaining parameters are accepted for signature compatibility
        and are not used here.  Raises AssertionError for any other module
        name.
        """
        if module_name not in (self.module_name, 'cython'):
            raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
        return ModuleScope(module_name, parent_module=None, context=self)
40

41

42
def parse_from_strings(name, code, pxds=None, level=None, initial_pos=None,
                       context=None, allow_struct_enum_decorator=False):
    """
    Utility method to parse a (unicode) string of code. This is mostly
    used for internal Cython compiler purposes (creating code snippets
    that transforms should emit, as well as unit testing).

    code - a unicode string containing Cython (module-level) code
    name - a descriptive name for the code source (to use in error messages etc.)
    pxds - accepted for interface compatibility; not used by this function
    level - None to parse a whole module with Parsing.p_module; any other
            value is passed on as the level argument of Parsing.p_code
    initial_pos - start position handed to the scanner and to the module
            lookup; defaults to (name, 1, 0)
    context - compilation context to use; defaults to a new
            StringParseContext(name)
    allow_struct_enum_decorator - forwarded to Parsing.Ctx

    RETURNS

    The parsed tree: the result of Parsing.p_module when level is None
    (with is_pxd set to False), otherwise the result of Parsing.p_code.
    The tree's scope attribute is set to the scope used when parsing.
    """
    if context is None:
        context = StringParseContext(name)
    # Since source files carry an encoding, it makes sense in this context
    # to use a unicode string so that code fragments don't have to bother
    # with encoding. This means that test code passed in should not have an
    # encoding header.
    assert isinstance(code, _unicode), "unicode code snippets only please"
    encoding = "UTF-8"

    module_name = name
    if initial_pos is None:
        initial_pos = (name, 1, 0)
    code_source = StringSourceDescriptor(name, code)

    scope = context.find_module(module_name, pos=initial_pos, need_pxd=False)

    buf = StringIO(code)

    scanner = PyrexScanner(buf, code_source, source_encoding=encoding,
                           scope=scope, context=context, initial_pos=initial_pos)
    ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator)

    if level is None:
        tree = Parsing.p_module(scanner, 0, module_name, ctx=ctx)
        tree.is_pxd = False
    else:
        tree = Parsing.p_code(scanner, level=level, ctx=ctx)

    # Attach the scope once, for both parse modes.
    tree.scope = scope
    return tree

Stefan Behnel's avatar
Stefan Behnel committed
89

90 91
class TreeCopier(VisitorTransform):
    """
    Transform that builds a copy of a tree out of clone_node() copies.
    """

    def visit_Node(self, node):
        """
        Return a clone of node whose children have been visited in turn.

        None is passed through unchanged.
        """
        if node is None:
            return None
        clone = node.clone_node()
        # Visit the children of the clone, not of the original node
        # (visitchildren is inherited from VisitorTransform).
        self.visitchildren(clone)
        return clone

Stefan Behnel's avatar
Stefan Behnel committed
99

100 101 102 103
class ApplyPositionAndCopy(TreeCopier):
    """
    TreeCopier variant that also stamps a fixed position on every copy.
    """

    def __init__(self, pos):
        """pos - the position assigned to each copied node."""
        super(ApplyPositionAndCopy, self).__init__()
        self.pos = pos

    def visit_Node(self, node):
        """
        Copy node via TreeCopier.visit_Node and set the copy's pos.

        None is passed through unchanged, as in TreeCopier.
        """
        copied = super(ApplyPositionAndCopy, self).visit_Node(node)
        # The parent returns None for a None node; there is nothing to stamp.
        if copied is not None:
            copied.pos = self.pos
        return copied

Stefan Behnel's avatar
Stefan Behnel committed
110

111 112 113
class TemplateTransform(VisitorTransform):
    """
    Makes a copy of a template tree while doing substitutions.

    A dictionary "substitutions" should be passed in when calling
    the transform; mapping names to replacement nodes. Then replacement
    happens like this:
     - If an ExprStatNode contains a single NameNode, whose name is
       a key in the substitutions dictionary, the ExprStatNode is
       replaced with a copy of the tree given in the dictionary.
       It is the responsibility of the caller that the replacement
       node is a valid statement.
     - If a single NameNode is otherwise encountered, it is replaced
       if its name is listed in the substitutions dictionary in the
       same way. It is the responsibility of the caller to make sure
       that the replacement node is a valid expression.

    Also a list "temps" should be passed. Any names listed will
    be transformed into anonymous, temporary names.

    Currently supported for tempnames is:
    NameNode
    (various function and class definition nodes etc. should be added to this)

    Each replacement node gets the position of the substituted node
    recursively applied to every member node.
    """

    # Incremented once per temp name handled by any instance.
    temp_name_counter = 0

    def __call__(self, node, substitutions, temps, pos):
        """
        Copy the tree rooted at node, applying substitutions and temps.

        node          - root of the template tree
        substitutions - dict mapping names to replacement nodes
        temps         - iterable of names to turn into temporaries
        pos           - position applied to all copied nodes, or None to
                        keep each node's own position

        Returns the transformed copy, wrapped in a UtilNodes.TempsBlockNode
        when at least one temp handle was created.
        """
        self.substitutions = substitutions
        self.pos = pos
        tempmap = {}
        temphandles = []
        for temp in temps:
            TemplateTransform.temp_name_counter += 1
            handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
            tempmap[temp] = handle
            temphandles.append(handle)
        self.tempmap = tempmap
        result = super(TemplateTransform, self).__call__(node)
        # Decide on the handles actually created, so that any iterable of
        # temp names behaves the same way.
        if temphandles:
            result = UtilNodes.TempsBlockNode(self.get_pos(node),
                                              temps=temphandles,
                                              body=result)
        return result

    def get_pos(self, node):
        """Return the override position if one was given, else node.pos."""
        # Use the same "is not None" test as visit_Node, so that both agree
        # on when the override position applies.
        if self.pos is not None:
            return self.pos
        return node.pos

    def visit_Node(self, node):
        """
        Return a clone of node (with the override position, if any) whose
        children have been visited in turn.  None is passed through.
        """
        if node is None:
            return None
        clone = node.clone_node()
        if self.pos is not None:
            clone.pos = self.pos
        self.visitchildren(clone)
        return clone

    def try_substitution(self, node, key):
        """
        Return a positioned copy of substitutions[key] if there is one,
        otherwise a regular copy of node.
        """
        sub = self.substitutions.get(key)
        if sub is None:
            return self.visit_Node(node)  # make copy as usual
        return ApplyPositionAndCopy(self.get_pos(node))(sub)

    def visit_NameNode(self, node):
        """Replace temp names by temp references; else try a substitution."""
        temphandle = self.tempmap.get(node.name)
        if temphandle is not None:
            # Replace name with temporary
            return temphandle.ref(self.get_pos(node))
        return self.try_substitution(node, node.name)

    def visit_ExprStatNode(self, node):
        """Substitute a whole statement when it is just a replaceable name."""
        # If an expression-as-statement consists of only a replaceable
        # NameNode, we replace the entire statement, not only the NameNode
        if isinstance(node.expr, NameNode):
            return self.try_substitution(node, node.expr.name)
        return self.visit_Node(node)
199

200

201 202 203
def copy_code_tree(node):
    """Return a copy of the tree rooted at node, made with a new TreeCopier."""
    copier = TreeCopier()
    return copier(node)

204

205
# Matches the (possibly empty) run of leading spaces of a line.
# Only space characters count as indentation here; tabs do not.
_match_indent = re.compile(u"^ *").match


def strip_common_indent(lines):
    """Strips empty lines and common indentation from the list of strings given in lines"""
    # TODO: consider textwrap.dedent (note that it keeps blank lines and
    # also treats tabs as indentation, unlike this function)
    # Whitespace-only lines are dropped first, so they cannot pull the
    # common indentation down to zero.
    lines = [x for x in lines if x.strip() != u""]
    if lines:
        minindent = min(len(_match_indent(x).group(0)) for x in lines)
        lines = [x[minindent:] for x in lines]
    return lines
216

217

218
class TreeFragment(object):
    """
    A reusable code tree, built either from a unicode code string or from
    an existing Node.

    The tree is stored in self.root.  Use copy() to obtain a plain copy of
    it and substitute() to obtain a copy with names replaced.
    """

    def __init__(self, code, name=None, pxds=None, temps=None, pipeline=None, level=None, initial_pos=None):
        """
        code        - unicode string of code, or a Node that is used as is
        name        - name of the code source; defaults to "(tree fragment)"
        pxds        - dict of name -> unicode pxd code; only allowed (non-empty)
                      together with string code
        temps       - names that substitute() always treats as temporaries
        pipeline    - transforms applied in order to the parsed tree;
                      None entries are skipped (string code only)
        level       - passed on to parse_from_strings
        initial_pos - passed on to parse_from_strings

        Raises NotImplementedError if pxds are given together with a Node,
        and ValueError if code is neither a unicode string nor a Node.
        """
        if pxds is None:
            pxds = {}
        if temps is None:
            temps = []
        if pipeline is None:
            pipeline = []
        if not name:
            name = "(tree fragment)"

        if isinstance(code, _unicode):
            def fmt(x):
                # Drop empty lines and the indentation shared by all lines.
                return u"\n".join(strip_common_indent(x.split(u"\n")))

            fmt_code = fmt(code)
            fmt_pxds = {key: fmt(value) for key, value in pxds.items()}
            mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
            if level is None:
                t = t.body  # Make sure a StatListNode is at the top
            if not isinstance(t, StatListNode):
                t = StatListNode(pos=mod.pos, stats=[t])
            for transform in pipeline:
                if transform is None:
                    continue
                t = transform(t)
            self.root = t
        elif isinstance(code, Node):
            if pxds:
                raise NotImplementedError()
            self.root = code
        else:
            raise ValueError("Unrecognized code format (accepts unicode and Node)")
        self.temps = temps

    def copy(self):
        """Return a copy of the fragment's tree."""
        return copy_code_tree(self.root)

    def substitute(self, nodes=None, temps=None, pos=None):
        """
        Return a copy of the tree with substitutions applied.

        nodes - dict mapping names to replacement nodes (defaults to {})
        temps - additional temp names, on top of those given at construction
        pos   - position to apply to the copied nodes, or None
        """
        if nodes is None:
            nodes = {}
        if temps is None:
            temps = []
        # Normalise both to lists so that mixing list and tuple arguments
        # does not fail on concatenation.
        all_temps = list(self.temps) + list(temps)
        return TemplateTransform()(self.root,
                                   substitutions=nodes,
                                   temps=all_temps, pos=pos)
265

Stefan Behnel's avatar
Stefan Behnel committed
266

267 268 269 270
class SetPosTransform(VisitorTransform):
    """
    Transform that overwrites the position of every visited node in place.

    Unlike TreeCopier, no copy is made: the nodes passed in are modified.
    """

    def __init__(self, pos):
        """pos - the position assigned to each visited node."""
        super(SetPosTransform, self).__init__()
        self.pos = pos

    def visit_Node(self, node):
        """Set node.pos, visit the node's children and return the node."""
        node.pos = self.pos
        self.visitchildren(node)
        return node