ParseTreeTransforms.py 97 KB
Newer Older
1
import cython
2 3 4 5 6 7 8 9 10 11 12
cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
               Options=object, UtilNodes=object, ModuleNode=object,
               LetNode=object, LetRefNode=object, TreeFragment=object,
               TemplateTransform=object, EncodedString=object,
               error=object, warning=object, copy=object)

import PyrexTypes
import Naming
import ExprNodes
import Nodes
import Options
13
import Builtin
14

15
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
16
from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
17
from Cython.Compiler.ModuleNode import ModuleNode
18
from Cython.Compiler.UtilNodes import LetNode, LetRefNode, ResultRefNode
19
from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
20
from Cython.Compiler.StringEncoding import EncodedString
21
from Cython.Compiler.Errors import error, warning, CompileError, InternalError
22

23
import copy
24

25 26 27 28 29 30 31 32 33 34 35 36

class NameNodeCollector(TreeVisitor):
    """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
    attribute.

    ``name_nodes`` starts out empty; every NameNode reached while
    visiting is appended to it in the order it is visited.
    """
    def __init__(self):
        super(NameNodeCollector, self).__init__()
        self.name_nodes = []

    def visit_NameNode(self, node):
        # Record the name; its children are not visited.
        self.name_nodes.append(node)

    def visit_Node(self, node):
        # Any other node: keep walking its children so that nested
        # NameNodes are found as well.
        self._visitchildren(node, None)


class SkipDeclarations(object):
    """
    Variable and function declarations can often have a deep tree structure,
    and yet most transformations don't need to descend to this depth.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so there
    is no need to use this for transformations after that point.

    Used as a mixin for transforms: each visitor below returns the
    declaration node unchanged without visiting its children.
    """
    def visit_CTypeDefNode(self, node):
        return node

    def visit_CVarDefNode(self, node):
        return node

    def visit_CDeclaratorNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return node

    def visit_CEnumDefNode(self, node):
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return node


class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means fewer special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the direct children of a StatListNode (or of
        # a statement treated as a list container); those are not wrapped.
        self.is_in_statlist = False
        # True anywhere below an expression node; statements found there
        # are never wrapped.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        """Visit a statement and wrap it in a StatListNode unless it
        already sits in a statement list or inside an expression.

        ``is_listcontainer`` marks statements whose children should be
        treated as if they were directly inside a statement list.
        """
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        if not self.is_in_statlist and not self.is_in_expr:
            return Nodes.StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Save and restore the flag like the other visitors do.  Resetting
        # it unconditionally to False would make the siblings that follow a
        # StatListNode nested directly in another list look as if they were
        # outside any list.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        # Outside a statement list, "pass" becomes an empty StatListNode.
        # Inside one, an empty list is returned (presumably spliced out of
        # the parent's list by the visitor machinery).
        if not self.is_in_statlist:
            return Nodes.StatListNode(pos=node.pos, stats=[])
        else:
            return []

    def visit_CDeclaratorNode(self, node):
        # Declarators are returned untouched; their children are not visited.
        return node


class PostParseError(CompileError):
    """Compile error raised or reported by the post-parse transforms."""


# error strings checked by unit tests, so define them
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'


class PostParse(ScopeTrackingTransform):
    """
    Basic interpretation of the parse tree, as well as validity
    checking that can be done on a very basic level on the parse
    tree (while still not being a problem with the basic syntax,
    as such).

    Specifically:
    - Default values to cdef assignments are turned into single
    assignments following the declaration (everywhere but in class
    bodies, where they raise a compile error)

    - Interpret some node structures into Python runtime values.
    Some nodes take compile-time arguments (currently:
    TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
    which should be interpreted. This happens in a general way
    and other steps should be taken to ensure validity.

    Type arguments cannot be interpreted in this way.

    - For __cythonbufferdefaults__ the arguments are checked for
    validity.

    TemplatedTypeNode has its directives interpreted:
    Any first positional argument goes into the "dtype" attribute,
    any "ndim" keyword argument goes into the "ndim" attribute and
    so on. Also it is checked that the directive combination is valid.
    - __cythonbufferdefaults__ attributes are parsed and put into the
    type information.

    - Lambda and generator expressions get a numbered name and a
    corresponding DefNode; parallel assignments are split into
    separate (cascaded) assignments; nested targets of "del" are
    flattened.

    Note: Currently Parsing.py does a lot of interpretation and
    reorganization that can be refactored into this transform
    if a more pure Abstract Syntax Tree is wanted.
    """

    def __init__(self, context):
        super(PostParse, self).__init__(context)
        # Special cdef class attribute names -> method consuming their
        # default value (used by visit_CVarDefNode).
        self.specialattribute_handlers = {
            '__cythonbufferdefaults__' : self.handle_bufferdefaults
        }

    def visit_ModuleNode(self, node):
        # Per-module counters used to number lambdas and genexprs.
        self.lambda_counter = 1
        self.genexpr_counter = 1
        return super(PostParse, self).visit_ModuleNode(node)

    def visit_LambdaNode(self, node):
        # unpack a lambda expression into the corresponding DefNode
        lambda_id = self.lambda_counter
        self.lambda_counter += 1
        node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
        collector = YieldNodeCollector()
        collector.visitchildren(node.result_expr)
        if collector.yields or isinstance(node.result_expr, ExprNodes.YieldExprNode):
            # the expression yields: emit it as an expression statement
            # instead of returning its value
            body = Nodes.ExprStatNode(
                node.result_expr.pos, expr=node.result_expr)
        else:
            body = Nodes.ReturnStatNode(
                node.result_expr.pos, value=node.result_expr)
        node.def_node = Nodes.DefNode(
            node.pos, name=node.name, lambda_name=node.lambda_name,
            args=node.args, star_arg=node.star_arg,
            starstar_arg=node.starstar_arg,
            body=body, doc=None)
        self.visitchildren(node)
        return node

    def visit_GeneratorExpressionNode(self, node):
        # unpack a generator expression into the corresponding DefNode
        genexpr_id = self.genexpr_counter
        self.genexpr_counter += 1
        node.genexpr_name = EncodedString(u'genexpr%d' % genexpr_id)

        node.def_node = Nodes.DefNode(node.pos, name=node.name,
                                      doc=None,
                                      args=[], star_arg=None,
                                      starstar_arg=None,
                                      body=node.loop)
        self.visitchildren(node)
        return node

    # cdef variables
    def handle_bufferdefaults(self, decl):
        """Store a ``__cythonbufferdefaults__ = {...}`` default on the
        enclosing scope node.

        Raises PostParseError(ERR_BUF_DEFAULTS) unless the default is a
        dict literal.
        """
        if not isinstance(decl.default, ExprNodes.DictNode):
            raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
        self.scope_node.buffer_defaults_node = decl.default
        self.scope_node.buffer_defaults_pos = decl.pos

    def visit_CVarDefNode(self, node):
        # This assumes only plain names and pointers are assignable on
        # declaration. Also, it makes use of the fact that a cdef decl
        # must appear before the first use, so we don't have to deal with
        # "i = 3; cdef int i = i" and can simply move the nodes around.
        try:
            self.visitchildren(node)
            stats = [node]
            newdecls = []
            for decl in node.declarators:
                declbase = decl
                while isinstance(declbase, Nodes.CPtrDeclaratorNode):
                    declbase = declbase.base
                if isinstance(declbase, Nodes.CNameDeclaratorNode):
                    if declbase.default is not None:
                        if self.scope_type in ('cclass', 'pyclass', 'struct'):
                            if isinstance(self.scope_node, Nodes.CClassDefNode):
                                # Look the name up on the unwrapped name
                                # declarator (as the assignment below does).
                                # Only then can a special attribute declared
                                # through a pointer declarator reach the
                                # ERR_INVALID_SPECIALATTR_TYPE check.
                                handler = self.specialattribute_handlers.get(declbase.name)
                                if handler:
                                    if decl is not declbase:
                                        raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
                                    handler(decl)
                                    continue # Remove declaration
                            raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
                        first_assignment = self.scope_type != 'module'
                        stats.append(Nodes.SingleAssignmentNode(node.pos,
                            lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
                            rhs=declbase.default, first=first_assignment))
                        declbase.default = None
                newdecls.append(decl)
            node.declarators = newdecls
            return stats
        except PostParseError:
            # An error in a cdef clause is ok, simply remove the declaration
            # and try to move on to report more errors.
            # (sys.exc_info() works with both Python 2 and 3 syntax.)
            import sys
            self.context.nonfatal_error(sys.exc_info()[1])
            return None

    # Split parallel assignments (a,b = b,a) into separate partial
    # assignments that are executed rhs-first using temps.  This
    # restructuring must be applied before type analysis so that known
    # types on rhs and lhs can be matched directly.  It is required in
    # the case that the types cannot be coerced to a Python type in
    # order to assign from a tuple.

    def visit_SingleAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, [node.lhs, node.rhs])

    def visit_CascadedAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, node.lhs_list + [node.rhs])

    def _visit_assignment_node(self, node, expr_list):
        """Flatten parallel assignments into separate single
        assignments or cascaded assignments.

        ``expr_list`` holds all LHS expressions followed by the RHS.
        ``node`` is returned unchanged unless at least two of them are
        sequence constructors. Otherwise the replacement assignment(s)
        are returned, wrapped in LetNodes for RHS subexpressions that
        occur more than once.
        """
        if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
            # no parallel assignments => nothing to do
            return node

        expr_list_list = []
        flatten_parallel_assignments(expr_list, expr_list_list)
        temp_refs = []
        eliminate_rhs_duplicates(expr_list_list, temp_refs)

        nodes = []
        for assignment in expr_list_list:
            lhs_list = assignment[:-1]
            rhs = assignment[-1]
            if len(lhs_list) == 1:
                new_node = Nodes.SingleAssignmentNode(rhs.pos,
                    lhs = lhs_list[0], rhs = rhs)
            else:
                new_node = Nodes.CascadedAssignmentNode(rhs.pos,
                    lhs_list = lhs_list, rhs = rhs)
            nodes.append(new_node)

        if len(nodes) == 1:
            assign_node = nodes[0]
        else:
            assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)

        if temp_refs:
            duplicates_and_temps = [ (temp.expression, temp)
                                     for temp in temp_refs ]
            sort_common_subsequences(duplicates_and_temps)
            # Wrap from the back, so the first entry of the sorted list
            # becomes the outermost LetNode.
            for _, temp_ref in duplicates_and_temps[::-1]:
                assign_node = LetNode(temp_ref, assign_node)

        return assign_node

    def _flatten_sequence(self, seq, result):
        """Append the non-sequence leaves of ``seq.args`` (recursing into
        nested sequence constructors) to ``result`` and return it."""
        for arg in seq.args:
            if arg.is_sequence_constructor:
                self._flatten_sequence(arg, result)
            else:
                result.append(arg)
        return result

    def visit_DelStatNode(self, node):
        # "del (a, b), c" is turned into deleting a, b and c individually.
        self.visitchildren(node)
        node.args = self._flatten_sequence(node, [])
        return node


def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.
    Creates a sequence of LetRefNodes that set up the required temps
    and appends them to ref_node_sequence.  The input list is modified
    in-place.

    Only the last element (the rhs) of each list in ``expr_list_list``
    is searched, recursing into sequence constructors. Literals and
    plain names are never replaced.
    """
    seen_nodes = cython.set()
    ref_nodes = {}
    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node in seen_nodes:
            if node not in ref_nodes:
                ref_node = LetRefNode(node)
                ref_nodes[node] = ref_node
                ref_node_sequence.append(ref_node)
        else:
            seen_nodes.add(node)
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)

    for expr_list in expr_list_list:
        rhs = expr_list[-1]
        find_duplicates(rhs)
    if not ref_nodes:
        return

    def substitute_nodes(node):
        # Return the replacement for ``node``, rewriting the items of
        # sequence constructors in place on the way down.
        if node in ref_nodes:
            return ref_nodes[node]
        elif node.is_sequence_constructor:
            node.args = list(map(substitute_nodes, node.args))
        return node

    # replace nodes inside of the common subexpressions
    for node in ref_nodes:
        if node.is_sequence_constructor:
            node.args = list(map(substitute_nodes, node.args))

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])


def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.  This is needed
    because each rhs item must only be evaluated once, so its value
    must be evaluated first and then reused when packing sequences
    that contain it.

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort (which is very fast for short sequences, the
    normal case in practice).

    ``items`` is a list of ``(expression, ref_node)`` pairs and is
    sorted in place.
    """
    def contains(seq, x):
        # identity search through seq, recursing into sequence constructors
        for item in seq:
            if item is x:
                return True
            elif item.is_sequence_constructor and contains(item.args, x):
                return True
        return False
    def lower_than(a, b):
        return b.is_sequence_constructor and contains(b.args, a)

    for pos, item in enumerate(items):
        key = item[1] # the ResultRefNode which has already been injected into the sequences
        # find the earliest preceding item whose expression contains key
        new_pos = pos
        for i in range(pos-1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos != pos:
            # shift the items in between one slot to the right
            for i in range(pos, new_pos, -1):
                items[i] = items[i-1]
            items[new_pos] = item


def flatten_parallel_assignments(input, output):
    #  The input is a list of expression nodes, representing the LHSs
    #  and RHS of one (possibly cascaded) assignment statement.  For
    #  sequence constructors, rearranges the matching parts of both
    #  sides into a list of equivalent assignments between the
    #  individual elements.  This transformation is applied
    #  recursively, so that nested structures get matched as well.
    #
    #  Every resulting assignment is appended to ``output`` as a list
    #  [lhs1, ..., lhsN, rhs].  Size mismatches and misplaced starred
    #  targets are reported via error(); the affected assignment is
    #  still emitted without being split, so compilation can continue.
    rhs = input[-1]
    if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
        output.append(input)
        return

    complete_assignments = []

    rhs_size = len(rhs.args)
    # one list of targets per rhs item
    lhs_targets = [ [] for _ in range(rhs_size) ]
    starred_assignments = []
    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue
        lhs_size = len(lhs.args)
        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
            output.append([lhs,rhs])
            continue
        elif lhs_size - starred_targets > rhs_size:
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, (rhs_size != 1) and 's' or ''))
            output.append([lhs,rhs])
            continue
        elif starred_targets:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs.args)
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
            output.append([lhs,rhs])
            continue
        else:
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, item_rhs in zip(lhs_targets, rhs.args):
        if cascade:
            cascade.append(item_rhs)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)


def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    # Appends the fixed-position LHS targets to the target list that
    # appear left and right of the starred argument.
    #
    # The starred_assignments list receives a new tuple
    # (lhs_target, rhs_values_list) that maps the remaining arguments
    # (those that match the starred target) to a list.
    #
    # lhs_targets holds one target list per RHS item.
    # flatten_parallel_assignments only calls this after checking that
    # the non-starred LHS items fit into it.  The starred item itself may
    # still sit at index len(rhs_args), e.g. "a, *b = x,".  It is
    # therefore searched for in lhs_args directly.  Searching in
    # zip(lhs_targets, lhs_args) would stop before reaching it.

    # left side of the starred target
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            lhs_remaining = len(lhs_args) - i - 1
            break
        lhs_targets[i].append(expr)
    else:
        raise InternalError("no starred arg found when splitting starred assignment")

    # right side of the starred target: the last lhs_remaining slots
    right_start = len(lhs_targets) - lhs_remaining
    for targets, expr in zip(lhs_targets[right_start:], lhs_args[starred + 1:]):
        targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:len(rhs_args) - lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])


class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots

    - cdef functions are let through only if they are on the
    top level and are declared "inline"
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # 'pxd' marks the top level; visit_CClassDefNode switches to 'cclass'.
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        old = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = old
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        # Rejected unless one of the cases below clears the error;
        # a rejected function is reported non-fatally and removed.
        err = self.ERR_INLINE_ONLY

        if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
            and node.name in ('__getbuffer__', '__releasebuffer__')):
            err = None # allow these slots

        if isinstance(node, Nodes.CFuncDefNode):
            if u'inline' in node.modifiers and self.scope_type == 'pxd':
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None # allow inline function
            else:
                err = self.ERR_INLINE_ONLY

        if err:
            self.context.nonfatal_error(PostParseError(node.pos, err))
            return None
        else:
            return node


class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
    """
    After parsing, directives can be stored in a number of places:
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments overriding these
    - @cython.directivename decorators
    - with cython.directivename: statements

    This transform is responsible for interpreting these various sources
    and storing the directives in two ways:
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to avoid having to modify the tree
    structure, so that ModuleNode stays on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
    and which names the "cython" module may have been imported to.
    """
    unop_method_nodes = {
596
        'typeof': ExprNodes.TypeofNode,
597

598 599 600 601 602 603
        'operator.address': ExprNodes.AmpersandNode,
        'operator.dereference': ExprNodes.DereferenceNode,
        'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
        'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
        'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
        'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),
604

605
        # For backwards compatability.
606
        'address': ExprNodes.AmpersandNode,
607
    }
Robert Bradshaw's avatar
Robert Bradshaw committed
608 609

    binop_method_nodes = {
610
        'operator.comma'        : ExprNodes.c_binop_constructor(','),
Robert Bradshaw's avatar
Robert Bradshaw committed
611
    }
612

Stefan Behnel's avatar
Stefan Behnel committed
613
    special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
Mark Florisson's avatar
Mark Florisson committed
614
                                  'cast', 'pointer', 'compiled', 'NULL', 'parallel'])
Stefan Behnel's avatar
Stefan Behnel committed
615
    special_methods.update(unop_method_nodes.keys())
616

Mark Florisson's avatar
Mark Florisson committed
617 618 619 620 621 622 623
    valid_parallel_directives = cython.set([
        "parallel",
        "prange",
        "threadid",
#        "threadsavailable",
    ])

624
    def __init__(self, context, compilation_directive_defaults):
        super(InterpretCompilerDirectives, self).__init__(context)
        # Private deep copy with unicode keys, so later changes to the
        # directive values cannot leak back into the caller's dict.
        self.compilation_directive_defaults = {}
        for key, value in compilation_directive_defaults.items():
            self.compilation_directive_defaults[unicode(key)] = copy.deepcopy(value)
        # Local names under which the "cython" module is reachable.
        self.cython_module_names = cython.set()
        # "cimport cython.X as Y": maps Y -> "X".
        self.directive_names = {}
        # Local name -> fully qualified "cython.parallel..." name.
        self.parallel_directives = {}

    def check_directive_scope(self, pos, directive, scope):
        """Return False, after reporting a non-fatal error, if
        ``directive`` is restricted by Options.directive_scopes and not
        allowed in ``scope``; otherwise return True.

        Unknown directive names are reported with error() but still
        return True.
        """
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            if (directive not in Options.directive_defaults
                    and directive not in Options.directive_types):
                error(pos, "Invalid directive: '%s'." % (directive,))
            return True

    # Set up processing and handle the cython: comments.
    def visit_ModuleNode(self, node):
        # Iterate over a snapshot of the keys: entries are deleted from
        # node.directive_comments inside the loop, which is not allowed
        # while iterating a live dict (view) on Python 3.
        for key in list(node.directive_comments.keys()):
            if not self.check_directive_scope(node.pos, key, 'module'):
                # NOTE(review): wrong_scope_error is not defined in the
                # visible part of this class -- confirm it exists, otherwise
                # this line raises AttributeError (check_directive_scope has
                # already reported the error).
                self.wrong_scope_error(node.pos, key, 'module')
                del node.directive_comments[key]

        self.module_scope = node.scope

        # Precedence (lowest first): Options defaults, compilation
        # defaults, "# cython:" comments of this module.
        directives = copy.deepcopy(Options.directive_defaults)
        directives.update(copy.deepcopy(self.compilation_directive_defaults))
        directives.update(node.directive_comments)
        self.directives = directives
        node.directives = directives
        node.parallel_directives = self.parallel_directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

    # The following methods track imports and cimports that
    # begin with "cython"
    def is_cython_directive(self, name):
        """Return a true value if ``name`` (relative to the cython module)
        is a directive in Options.directive_types, one of the special
        methods, or a name PyrexTypes.parse_basic_type() accepts."""
        return (name in Options.directive_types or
                name in self.special_methods or
                PyrexTypes.parse_basic_type(name))

    def is_parallel_directive(self, full_name, pos):
        """
        Checks to see if full_name (e.g. cython.parallel.prange) is a valid
        parallel directive. If it is a star import it also updates the
        parallel_directives.

        Returns True for "cython.parallel" itself and for every name
        below it; unknown names below it are reported via error().
        Whenever True is returned, Nodes.init_threads is registered as
        utility code on the module scope.
        """
        result = (full_name + ".").startswith("cython.parallel.")

        if result:
            directive = full_name.split('.')
            if full_name == u"cython.parallel":
                # The parallel module itself (e.g. "from cython cimport
                # parallel") is valid.  It used to fall into the length
                # check below and was reported as "No such directive".
                # Callers such as visit_FromCImportStatNode record the
                # local name themselves.
                pass
            elif full_name == u"cython.parallel.*":
                for name in self.valid_parallel_directives:
                    self.parallel_directives[name] = u"cython.parallel.%s" % name
            elif (len(directive) != 3 or
                  directive[-1] not in self.valid_parallel_directives):
                error(pos, "No such directive: %s" % full_name)

            self.module_scope.use_utility_code(Nodes.init_threads)

        return result

    def visit_CImportStatNode(self, node):
        """Record "cimport cython[.X] [as Y]" statements.

        cimports of cython submodules only select directives, so they are
        removed from the tree (None is returned).  All other cimports,
        including a plain "cimport cython", are returned unchanged.
        """
        module_name = node.module_name
        if module_name == u"cython":
            self.cython_module_names.add(node.as_name or u"cython")
        elif module_name.startswith(u"cython."):
            if module_name.startswith(u"cython.parallel."):
                error(node.pos, module_name + " is not a module")
            if module_name == u"cython.parallel":
                if node.as_name and node.as_name != u"cython":
                    self.parallel_directives[node.as_name] = module_name
                else:
                    self.cython_module_names.add(u"cython")
                    self.parallel_directives[
                                    u"cython.parallel"] = module_name
                self.module_scope.use_utility_code(Nodes.init_threads)
            elif node.as_name:
                # strip the leading "cython." to get the directive name
                self.directive_names[node.as_name] = module_name[7:]
            else:
                self.cython_module_names.add(u"cython")
            # if this cimport was a compiler directive, we don't
            # want to leave the cimport node sitting in the tree
            return None
        return node

    def visit_FromCImportStatNode(self, node):
        """
        Handle ``from cython[.sub] cimport name [as alias], ...``.

        Names that are parallel directives or compiler directives are
        recorded and removed from node.imported_names.  If nothing is left,
        the whole statement is dropped (None is returned).  Cimports from
        other modules are returned unchanged.
        """
        if not (node.module_name + u".").startswith(u"cython."):
            return node

        # "cython" -> u"", "cython.parallel" -> u"parallel."
        prefix = (node.module_name + u".")[len(u"cython."):]
        remaining = []
        for pos, name, as_name, kind in node.imported_names:
            full_name = prefix + name
            qualified_name = u"cython." + full_name

            if self.is_parallel_directive(qualified_name, node.pos):
                # from cython cimport parallel, or
                # from cython.parallel cimport parallel, prange, ...
                self.parallel_directives[as_name or name] = qualified_name
                continue

            if not self.is_cython_directive(full_name):
                remaining.append((pos, name, as_name, kind))
                continue

            if as_name is None:
                bound_name = full_name
            else:
                bound_name = as_name
            self.directive_names[bound_name] = full_name
            if kind is not None:
                self.context.nonfatal_error(PostParseError(pos,
                    "Compiler directive imports must be plain imports"))

        if not remaining:
            return None
        node.imported_names = remaining
        return node
745

Robert Bradshaw's avatar
Robert Bradshaw committed
746
    def visit_FromImportStatNode(self, node):
        """
        Handle ``from cython[.sub] import name, ...`` (pure Python mode).

        Parallel directives and compiler directives are recorded under the
        name they are bound to and removed from node.items.  If no items
        remain, the statement is dropped (None is returned).  Imports from
        other modules are returned unchanged.
        """
        module_name = node.module.module_name.value
        if not (module_name + u".").startswith(u"cython."):
            return node

        # "cython" -> u"", "cython.parallel" -> u"parallel."
        prefix = (module_name + u".")[len(u"cython."):]
        kept_items = []
        for name, name_node in node.items:
            full_name = prefix + name
            qualified_name = u"cython." + full_name
            if self.is_parallel_directive(qualified_name, node.pos):
                self.parallel_directives[name_node.name] = qualified_name
            elif self.is_cython_directive(full_name):
                self.directive_names[name_node.name] = full_name
            else:
                kept_items.append((name, name_node))

        if not kept_items:
            return None
        node.items = kept_items
        return node

765
    def visit_SingleAssignmentNode(self, node):
        """
        Treat ``import cython`` / ``import cython.parallel...`` like a cimport.

        An assignment whose right-hand side is an ImportNode of "cython" or
        of a name inside cython.parallel is rewritten into the equivalent
        CImportStatNode (bound to the assignment target's name) and handed
        to visit_CImportStatNode.  Other imports are returned unchanged;
        any other assignment just has its children visited.
        """
        if not isinstance(node.rhs, ExprNodes.ImportNode):
            self.visitchildren(node)
            return node

        module_name = node.rhs.module_name.value
        is_parallel = (module_name + u".").startswith(u"cython.parallel.")
        if module_name != u"cython" and not is_parallel:
            # A regular Python import: leave it alone.
            return node

        cimport_node = Nodes.CImportStatNode(node.pos,
                                             module_name=module_name,
                                             as_name=node.lhs.name)
        return self.visit_CImportStatNode(cimport_node)
784

785 786 787
    def visit_NameNode(self, node):
        """
        Tag a name that refers to the cython module.  Otherwise, attach the
        directive it was imported as (None if it is not a directive name).
        """
        if node.name not in self.cython_module_names:
            node.cython_attribute = self.directive_names.get(node.name)
        else:
            node.is_cython_module = True
        return node
791

792
    def try_to_parse_directives(self, node):
        """
        Interpret a with-statement manager or decorator expression as
        compiler directives.

        Returns a list of (directive_name, value) pairs when *node* is a
        call to, or a bare reference to, a Cython directive, and None
        otherwise (including calls to cython attributes that are not
        registered in Options.directive_types).

        For a call, each keyword argument whose "<optname>.<key>" form is
        itself a registered directive is parsed as a separate directive and
        removed from the keyword node (kwds.key_value_pairs is modified in
        place).  The remaining arguments are parsed for *optname* itself.

        Raises PostParseError for a bare reference to a directive that is
        neither boolean nor untyped, and propagates PostParseError raised by
        try_to_parse_directive.
        """
        # If node is the contents of a directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None
        if isinstance(node, ExprNodes.CallNode):
            # Visit the callee first so visit_NameNode can tag it; presumably
            # as_cython_attribute() reads those tags.
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # Split "optname.key=value" sub-directives out of the
                        # keywords; keep the rest for optname itself.
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            kwds.key_value_pairs = key_value_pairs
                        # Only sub-directives were given: don't parse optname.
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives
        elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
            # A directive used without a call.
            self.visit(node)
            optname = node.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype is bool:
                    # e.g. a bare boolean directive switches it on
                    return [(optname, True)]
                elif directivetype is None:
                    # Not a registered option (visit_WithStatNode handles
                    # 'gil'/'nogil' received this way); value is None.
                    return [(optname, None)]
                else:
                    raise PostParseError(
                        node.pos, "The '%s' directive should be used as a function call." % optname)
        return None
834

835 836
    def try_to_parse_directive(self, optname, args, kwds, pos):
        """
        Convert the arguments of a directive call into a directive value.

        optname -- directive name (a key of Options.directive_types)
        args    -- list of positional argument nodes
        kwds    -- keyword-argument node (with .key_value_pairs) or None
        pos     -- source position used for error reporting

        Returns an (optname, value) pair.  A single ``None`` argument resets
        the directive to its value in Options.directive_defaults.

        Raises PostParseError if the arguments do not fit the directive's
        type (bool, str, dict or list).
        """
        directivetype = Options.directive_types.get(optname)
        if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
            return optname, Options.directive_defaults[optname]
        elif directivetype is bool:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time boolean argument' % optname)
            return (optname, args[0].value)
        elif directivetype is str:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], (ExprNodes.StringNode,
                                                                              ExprNodes.UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, str(args[0].value))
        elif directivetype is dict:
            if len(args) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no positional arguments' % optname)
            if kwds is None:
                # Called without any arguments, e.g. cython.locals():
                # there is nothing to record.
                return optname, {}
            return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
        elif directivetype is list:
            # kwds is a node, not a container: check its key/value pairs
            # rather than calling len() on the node itself.
            if kwds is not None and len(kwds.key_value_pairs) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no keyword arguments' % optname)
            return optname, [ str(arg.value) for arg in args ]
        else:
            assert False, "unsupported directive type for %s" % optname

863 864 865 866 867
    def visit_with_directives(self, body, directives):
        """
        Visit *body* with *directives* layered on top of the current
        directive set.  Returns the visited body wrapped in a
        CompilerDirectivesNode that carries the merged set.  The previous
        set is reinstated before returning.
        """
        saved_directives = self.directives
        merged = copy.copy(saved_directives)
        merged.update(directives)
        self.directives = merged

        assert isinstance(body, Nodes.StatListNode), body
        visited_body = self.visit_Node(body)
        wrapper = Nodes.CompilerDirectivesNode(pos=visited_body.pos,
                                               body=visited_body,
                                               directives=merged)

        self.directives = saved_directives
        return wrapper
874

875
    # Handle decorators
876
    def visit_FuncDefNode(self, node):
        """
        Strip directive decorators from a function.  If any were present,
        visit the function under them (wrapped in a statement list);
        otherwise visit it normally.
        """
        directives = self._extract_directives(node, 'function')
        if directives:
            wrapper = Nodes.StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(wrapper, directives)
        return self.visit_Node(node)

    def visit_CVarDefNode(self, node):
        """
        Apply cython.locals() decorators to a cdef declaration.

        Any other directive used as a decorator here is reported as a
        non-fatal error.  Decorators that do not parse as directives at all
        (try_to_parse_directives returns None) are not examined here.
        """
        if not node.decorators:
            return node
        for decorator in node.decorators:
            parsed = self.try_to_parse_directives(decorator.decorator) or ()
            for directive in parsed:
                is_locals = directive is not None and directive[0] == u'locals'
                if not is_locals:
                    self.context.nonfatal_error(PostParseError(decorator.pos,
                        "Cdef functions can only take cython.locals() decorator."))
                    continue
                node.directive_locals = directive[1]
        return node

    def visit_CClassDefNode(self, node):
        """
        Strip directive decorators from a cdef class.  If any were present,
        visit the class under them (wrapped in a statement list); otherwise
        visit it normally.
        """
        directives = self._extract_directives(node, 'cclass')
        if directives:
            wrapper = Nodes.StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(wrapper, directives)
        return self.visit_Node(node)

902 903 904 905
    def visit_PyClassDefNode(self, node):
        """
        Strip directive decorators from a Python class.  If any were
        present, visit the class under them (wrapped in a statement list);
        otherwise visit it normally.
        """
        directives = self._extract_directives(node, 'class')
        if directives:
            wrapper = Nodes.StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(wrapper, directives)
        return self.visit_Node(node)

909 910 911 912 913 914 915 916 917 918 919 920
    def _extract_directives(self, node, scope_name):
        if not node.decorators:
            return {}
        # Split the decorators into two lists -- real decorators and directives
        directives = []
        realdecs = []
        for dec in node.decorators:
            new_directives = self.try_to_parse_directives(dec.decorator)
            if new_directives is not None:
                for directive in new_directives:
                    if self.check_directive_scope(node.pos, directive[0], scope_name):
                        directives.append(directive)
921
            else:
922
                realdecs.append(dec)
923
        if realdecs and isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode)):
924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939
            raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
        else:
            node.decorators = realdecs
        # merge or override repeated directives
        optdict = {}
        directives.reverse() # Decorators coming first take precedence
        for directive in directives:
            name, value = directive
            if name in optdict:
                old_value = optdict[name]
                # keywords and arg lists can be merged, everything
                # else overrides completely
                if isinstance(old_value, dict):
                    old_value.update(value)
                elif isinstance(old_value, list):
                    old_value.extend(value)
940 941
                else:
                    optdict[name] = value
942 943 944 945
            else:
                optdict[name] = value
        return optdict

946 947
    # Handle with statements
    def visit_WithStatNode(self, node):
        """
        Interpret ``with <directive>:`` statements.

        - A 'nogil' or 'gil' directive turns the statement into a
          GILStatNode, which is then visited.
        - Other directives accepted by check_directive_scope are applied to
          the body.  The body is returned wrapped by visit_with_directives.
        - Directive with-statements may not use 'as'; this is reported as a
          non-fatal error.

        Statements without directives are visited normally.
        """
        collected = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is None:
                continue
            if node.target is not None:
                self.context.nonfatal_error(
                    PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                continue
            name, value = directive
            if name in ('nogil', 'gil'):
                # special case: in pure mode, "with nogil" spells "with cython.nogil"
                gil_node = Nodes.GILStatNode(node.pos, state=name, body=node.body)
                return self.visit_Node(gil_node)
            if self.check_directive_scope(node.pos, name, 'with statement'):
                collected[name] = value

        if collected:
            return self.visit_with_directives(node.body, collected)
        return self.visit_Node(node)
965

Mark Florisson's avatar
Mark Florisson committed
966 967 968 969 970 971 972

class ParallelRangeTransform(CythonTransform, SkipDeclarations):
    """
    Transform cython.parallel stuff. The parallel_directives come from the
    module node, set there by InterpretCompilerDirectives.

        with nogil, cython.parallel.parallel(): -> ParallelWithBlockNode
            print cython.parallel.threadid()    -> ParallelThreadIdNode
            for i in cython.parallel.prange(...):  -> ParallelRangeNode
                ...

    cython.parallel.threadsavailable() is not supported yet (its entry in
    directive_to_node is commented out).

    Every name bound to the cython module is tracked, so other cython
    attributes (e.g. cython.declare(...)) pass through here as well.  They
    are not parallel directives and are left untouched.
    """

    # a list of names, maps 'cython.parallel.prange' in the code to
    # ['cython', 'parallel', 'prange']
    parallel_directive = None

    # Indicates whether a namenode in an expression is the cython module
    namenode_is_cython_module = False

    # Keep track of whether we are the context manager of a 'with' statement
    in_context_manager_section = False

    # One of 'prange' or 'with parallel'. This is used to disallow closely
    # nested 'with parallel:' blocks
    state = None

    directive_to_node = {
        u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
        # u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
        u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
        u"cython.parallel.prange": Nodes.ParallelRangeNode,
    }

    def node_is_parallel_directive(self, node):
        return node.name in self.parallel_directives or node.is_cython_module

    def _reset_directive_tracking(self):
        "Forget any partially collected directive name."
        self.parallel_directive = None
        self.namenode_is_cython_module = False

    def _is_other_cython_attribute(self):
        """
        True if the collected name starts at the cython module but does not
        continue with 'parallel' (e.g. cython.declare).  Such names are not
        parallel directives and must neither be rewritten nor reported.
        """
        return (self.namenode_is_cython_module and
                self.parallel_directive[1:2] != [u'parallel'])

    def get_directive_class_node(self, node):
        """
        Figure out which parallel directive was used and return the associated
        Node class (None, with an error reported for parallel names, if there
        is no such directive).  Resets the collected directive name.

        E.g. for a cython.parallel.prange() call we return ParallelRangeNode
        """
        if self.namenode_is_cython_module:
            # The first component is whatever name the cython module is
            # bound to (possibly an alias such as 'cy'); normalise it.
            directive = '.'.join([u'cython'] + self.parallel_directive[1:])
        else:
            directive = self.parallel_directives[self.parallel_directive[0]]
            directive = '%s.%s' % (directive,
                                   '.'.join(self.parallel_directive[1:]))
            directive = directive.rstrip('.')

        cls = self.directive_to_node.get(directive)
        if cls is None and not self._is_other_cython_attribute():
            error(node.pos, "Invalid directive: %s" % directive)

        self._reset_directive_tracking()
        return cls

    def visit_ModuleNode(self, node):
        """
        If any parallel directives were imported, copy them over and visit
        the AST
        """
        if node.parallel_directives:
            self.parallel_directives = node.parallel_directives
            return self.visit_Node(node)

        # No parallel directives were imported, so they can't be used :)
        return node

    def visit_NameNode(self, node):
        if self.node_is_parallel_directive(node):
            self.parallel_directive = [node.name]
            self.namenode_is_cython_module = node.is_cython_module
        return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        if self.parallel_directive:
            self.parallel_directive.append(node.attribute)
        return node

    def visit_CallNode(self, node):
        # Discard a name collected from an earlier, uncalled reference
        # (e.g. "f = prange") so it is not mistaken for this callee.
        self._reset_directive_tracking()
        self.visit(node.function)
        if not self.parallel_directive:
            return node

        if self._is_other_cython_attribute():
            # e.g. cython.declare(...): not a parallel directive.
            self._reset_directive_tracking()
            return node

        # We are a parallel directive, replace this node with the
        # corresponding ParallelSomethingSomething node

        if isinstance(node, ExprNodes.GeneralCallNode):
            args = node.positional_args.args
            kwargs = node.keyword_args
        else:
            args = node.args
            kwargs = {}

        parallel_directive_class = self.get_directive_class_node(node)
        if parallel_directive_class:
            # Note: in case of a parallel() the body is set by
            # visit_WithStatNode
            node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)

        return node

    def visit_WithStatNode(self, node):
        "Rewrite with cython.parallel.parallel() blocks"
        self._reset_directive_tracking()
        newnode = self.visit(node.manager)

        if isinstance(newnode, Nodes.ParallelWithBlockNode):
            if self.state == 'parallel with':
                error(node.manager.pos,
                      "Closely nested parallel with blocks are disallowed")

            # Restore (rather than clear) the state afterwards, as
            # visit_ForInStatNode does, so that sibling blocks nested in the
            # same outer 'with parallel' are each checked.
            previous_state = self.state
            self.state = 'parallel with'
            body = self.visit(node.body)
            self.state = previous_state

            newnode.body = body
            return newnode
        elif self.parallel_directive:
            if self._is_other_cython_attribute():
                # e.g. "with cython.something:" -- not ours to rewrite.
                self._reset_directive_tracking()
            else:
                parallel_directive_class = self.get_directive_class_node(node)

                if not parallel_directive_class:
                    # There was an error, stop here and now
                    return None

                if parallel_directive_class is Nodes.ParallelWithBlockNode:
                    error(node.pos, "The parallel directive must be called")
                    return None

        node.body = self.visit(node.body)
        return node

    def visit_ForInStatNode(self, node):
        "Rewrite 'for i in cython.parallel.prange(...):'"
        self.visit(node.iterator)
        self.visit(node.target)

        in_prange = isinstance(node.iterator.sequence,
                               Nodes.ParallelRangeNode)
        previous_state = self.state

        if in_prange:
            # This will replace the entire ForInStatNode, so copy the
            # attributes
            parallel_range_node = node.iterator.sequence

            parallel_range_node.target = node.target
            parallel_range_node.body = node.body
            parallel_range_node.else_clause = node.else_clause

            node = parallel_range_node

            if not isinstance(node.target, ExprNodes.NameNode):
                error(node.target.pos,
                      "Can only iterate over an iteration variable")

            self.state = 'prange'

        self.visit(node.body)
        self.state = previous_state
        self.visit(node.else_clause)
        return node

    def visit(self, node):
        "Visit a node that may be None"
        if node is not None:
            return super(ParallelRangeTransform, self).visit(node)
Mark Florisson's avatar
Mark Florisson committed
1138 1139


1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194
class WithTransform(CythonTransform, SkipDeclarations):
    """
    Expand ``with`` statements into explicit exception handling.

    The WithStatNode itself is kept, but its body is replaced by a tree of
    roughly this shape:

        try:                                  # TryFinallyStatNode
            try:                              # TryExceptStatNode
                TARGET = <target_temp>        # only if 'as TARGET' was given
                BODY
            except:                           # info bound to excinfo_target
                if not <exit call with excinfo_target>:
                    raise
        finally:                              # handle_error_case=False
            <exit call with (None, None, None)>

    NOTE(review): the runtime behaviour of WithExitCallNode and of
    handle_error_case=False (e.g. whether the finally clause also runs on
    the exceptional path) is defined in ExprNodes/Nodes, not here; confirm
    there before relying on this sketch.
    """
    def visit_WithStatNode(self, node):
        # Expand with statements nested in the body first.
        self.visitchildren(node, 'body')
        pos = node.pos
        body, target, manager = node.body, node.target, node.manager
        # The 'as' target (if any) is assigned from this temp at the start
        # of the body; presumably the with-statement code stores the
        # __enter__() result in it -- confirm in Nodes.WithStatNode.
        node.target_temp = ExprNodes.TempNode(pos, type=PyrexTypes.py_object_type)
        if target is not None:
            node.has_target = True
            body = Nodes.StatListNode(
                pos, stats = [
                    Nodes.WithTargetAssignmentStatNode(
                        pos, lhs = target, rhs = node.target_temp),
                    body
                    ])
            # The assignment is now part of the body.
            node.target = None

        # Holds the exception info tuple caught by the except clause; it is
        # passed as the arguments of the exit call below.
        excinfo_target = ResultRefNode(
            pos=pos, type=Builtin.tuple_type, may_hold_none=False)
        # except: if not <exit(*excinfo)>: raise
        except_clause = Nodes.ExceptClauseNode(
            pos, body = Nodes.IfStatNode(
                pos, if_clauses = [
                    Nodes.IfClauseNode(
                        pos, condition = ExprNodes.NotNode(
                            pos, operand = ExprNodes.WithExitCallNode(
                                pos, with_stat = node,
                                args = excinfo_target)),
                        body = Nodes.ReraiseStatNode(pos),
                        ),
                    ],
                else_clause = None),
            pattern = None,
            target = None,
            excinfo_target = excinfo_target,
            )

        # try: (try: body / except: ...) finally: <exit(None, None, None)>
        node.body = Nodes.TryFinallyStatNode(
            pos, body = Nodes.TryExceptStatNode(
                pos, body = body,
                except_clauses = [except_clause],
                else_clause = None,
                ),
            finally_clause = Nodes.ExprStatNode(
                pos, expr = ExprNodes.WithExitCallNode(
                    pos, with_stat = node,
                    args = ExprNodes.TupleNode(
                        pos, args = [ExprNodes.NoneNode(pos) for _ in range(3)]
                        ))),
            handle_error_case = False,
            )
        return node

    def visit_ExprNode(self, node):
        # With statements are never inside expressions, so expression
        # subtrees are not descended into.
        return node

1195

1196
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """
    Replace decorators on Python functions and classes with an explicit
    rebinding after the definition:

        @a
        @b
        def f(): ...

    becomes ``def f(): ...`` followed by the assignment ``f = a(b(f))``.
    Decorators on cdef classes are reported as an error.
    """

    def visit_DefNode(self, func_node):
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def visit_CClassDefNode(self, class_node):
        # Decorating cdef classes is disabled.  Assignments to cdef class
        # names do not work, and they would need an extra check anyway: an
        # extension type must not change its C type, so a decorator could
        # only alter the type and return it, never replace it.
        self.visitchildren(class_node)
        if class_node.decorators:
            error(class_node.pos,
                  "Decorators not allowed on cdef classes (used on type '%s')" % class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        """
        Return [node, assignment].  The assignment rebinds *name* to the
        result of applying node.decorators, the last listed one first.
        """
        decorated = ExprNodes.NameNode(node.pos, name=name)
        for decorator in reversed(node.decorators):
            decorated = ExprNodes.SimpleCallNode(
                decorator.pos,
                function=decorator.decorator,
                args=[decorated])

        reassignment = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=ExprNodes.NameNode(node.pos, name=name),
            rhs=decorated)
        return [node, reassignment]
1243

1244 1245 1246 1247 1248 1249 1250 1251 1252
class CnameDirectivesTransform(CythonTransform, SkipDeclarations):
    """
    Only part of the CythonUtilityCode pipeline. Must be run before
    DecoratorTransform in case this is a decorator for a cdef class.
    It filters out @cname('my_cname') decorators and rewrites them to
    CnameDecoratorNodes.
    """

    def handle_function(self, node):
        """
        Look for a @cname("...") decorator on *node*.  If one is found,
        remove it and wrap the node in a CnameDecoratorNode.  Then visit the
        (possibly wrapped) node.  Only the first cname decorator is used.

        Raises AssertionError unless the cname decorator is called with
        exactly one string-literal positional argument.
        """
        if not node.decorators:
            return self.visit_Node(node)

        for index, wrapper in enumerate(node.decorators):
            decorator = wrapper.decorator
            is_cname_call = (isinstance(decorator, ExprNodes.CallNode) and
                             decorator.function.is_name and
                             decorator.function.name == 'cname')
            if not is_cname_call:
                continue

            args, kwargs = decorator.explicit_args_kwds()
            if kwargs:
                raise AssertionError(
                        "cname decorator does not take keyword arguments")
            if len(args) != 1:
                raise AssertionError(
                        "cname decorator takes exactly one argument")

            literal = args[0]
            if not (literal.is_literal and literal.type == Builtin.str_type):
                raise AssertionError(
                        "argument to cname decorator must be a string literal")

            cname = literal.compile_time_value(None).decode('UTF-8')
            del node.decorators[index]
            node = Nodes.CnameDecoratorNode(pos=node.pos, node=node,
                                            cname=cname)
            break

        return self.visit_Node(node)

    visit_FuncDefNode = handle_function
    visit_CClassDefNode = handle_function
1287 1288


1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325
class ForwardDeclareTypes(CythonTransform):
    """
    Declare enums, struct/union types and cdef classes in the module scope.

    While walking the tree, the module scope's directives and in_cinclude
    flag are switched to the values in effect for the current subtree.
    They are switched back afterwards.
    """

    def visit_ModuleNode(self, node):
        scope = node.scope
        self.module_scope = scope
        scope.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        scope = self.module_scope
        saved_directives = scope.directives
        scope.directives = node.directives
        self.visitchildren(node)
        scope.directives = saved_directives
        return node

    def visit_CDefExternNode(self, node):
        # Declarations inside 'cdef extern' blocks are made with
        # in_cinclude set.
        saved_flag = self.module_scope.in_cinclude
        self.module_scope.in_cinclude = 1
        self.visitchildren(node)
        self.module_scope.in_cinclude = saved_flag
        return node

    def visit_CEnumDefNode(self, node):
        node.declare(self.module_scope)
        return node

    def visit_CStructOrUnionDefNode(self, node):
        # Skip names that already have an entry in the module scope.
        if node.name not in self.module_scope.entries:
            node.declare(self.module_scope)
        return node

    def visit_CClassDefNode(self, node):
        # Skip names that already have an entry in the module scope.
        if node.class_name not in self.module_scope.entries:
            node.declare(self.module_scope)
        return node

1326

1327
class AnalyseDeclarationsTransform(CythonTransform):
    """
    Run declaration analysis over the tree and drop declaration-only nodes.

    Keeps a stack of scopes (``env_stack``) and, per module/function, the
    set of names referenced so far (``seen_vars_stack``); the latter is used
    to warn when a cdef variable is declared after it is used.

    Also generates property nodes for cdef class attributes whose entries
    have ``needs_property`` set.
    """

    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')
    basic_pyobject_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
    """, level='c_class')
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    """, level='c_class')

    struct_or_union_wrapper = TreeFragment(u"""
cdef class NAME:
    cdef TYPE value
    def __init__(self, MEMBER=None):
        cdef int count
        count = 0
        INIT_ASSIGNMENTS
        if IS_UNION and count > 1:
            raise ValueError, "At most one union member should be specified."
    def __str__(self):
        return STR_FORMAT % MEMBER_TUPLE
    def __repr__(self):
        return REPR_FORMAT % MEMBER_TUPLE
    """)

    init_assignment = TreeFragment(u"""
if VALUE is not None:
    ATTR = VALUE
    count += 1
    """)

    def __call__(self, root):
        self.env_stack = [root.scope]
        # needed to determine if a cdef var is declared after it's used.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        # Record every referenced name in the innermost tracking set.
        self.seen_vars_stack[-1].add(node.name)
        return node

    def visit_ModuleNode(self, node):
        self.seen_vars_stack.append(cython.set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_LambdaNode(self, node):
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ClassDefNode(self, node):
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_CClassDefNode(self, node):
        """Visit the class body, then append generated property nodes for
        attributes whose entries have ``needs_property`` set (only for
        classes that are implemented in this module)."""
        node = self.visit_ClassDefNode(node)
        if node.scope and node.scope.implemented:
            stats = []
            for entry in node.scope.var_entries:
                if entry.needs_property:
                    prop = self.create_Property(entry)
                    prop.analyse_declarations(node.scope)
                    self.visit(prop)
                    stats.append(prop)
            if stats:
                node.body.stats += stats
        return node

    def visit_FuncDefNode(self, node):
        """Declare arguments and ``directive_locals`` in the function's local
        scope, analyse the body's declarations, then visit the children with
        the local scope pushed."""
        self.seen_vars_stack.append(cython.set())
        lenv = node.local_scope
        node.declare_arguments(lenv)
        for var, type_node in node.directive_locals.items():
            if not lenv.lookup_here(var):   # don't redeclare args
                var_type = type_node.analyse_as_type(lenv)
                if var_type:
                    lenv.declare_var(var, var_type, type_node.pos)
                else:
                    error(type_node.pos, "Not a type")
        node.body.analyse_declarations(lenv)

        if lenv.nogil and lenv.has_with_gil_block:
            # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
            # the entire function body in try/finally.
            # The corresponding release will be taken care of by
            # Nodes.FuncDefNode.generate_function_definitions()
            node.body = Nodes.NogilTryFinallyStatNode(
                node.body.pos,
                body=node.body,
                finally_clause=Nodes.EnsureGILNode(node.body.pos),
            )

        self.env_stack.append(lenv)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ScopedExprNode(self, node):
        env = self.env_stack[-1]
        node.analyse_declarations(env)
        # the node may or may not have a local scope
        if node.has_local_scope:
            # Start from a copy of the names seen in the enclosing scope.
            self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
            self.env_stack.append(node.expr_scope)
            node.analyse_scoped_declarations(node.expr_scope)
            self.visitchildren(node)
            self.env_stack.pop()
            self.seen_vars_stack.pop()
        else:
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
        return node

    def visit_TempResultFromStatNode(self, node):
        # Children first, then the node's own declarations.
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    def visit_CStructOrUnionDefNode(self, node):
        # Create a wrapper node if needed.
        # We want to use the struct type information (so it can't happen
        # before this phase) but also create new objects to be declared
        # (so it can't happen later).
        # Note that we don't return the original node, as it is
        # never used after this phase.
        if True: # private (default)
            return None

        # NOTE: the wrapper-class generation below is currently unreachable
        # because of the unconditional return above; it is kept for when
        # struct/union visibility is honoured here.
        self_value = ExprNodes.AttributeNode(
            pos = node.pos,
            obj = ExprNodes.NameNode(pos=node.pos, name=u"self"),
            attribute = EncodedString(u"value"))
        var_entries = node.entry.type.scope.var_entries
        attributes = []
        for entry in var_entries:
            attributes.append(ExprNodes.AttributeNode(pos = entry.pos,
                                                      obj = self_value,
                                                      attribute = entry.name))
        # __init__ assignments
        init_assignments = []
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            init_assignments.append(self.init_assignment.substitute({
                    u"VALUE": ExprNodes.NameNode(entry.pos, name = entry.name),
                    u"ATTR": attr,
                }, pos = entry.pos))

        # create the class
        str_format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
        wrapper_class = self.struct_or_union_wrapper.substitute({
            u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats = init_assignments),
            u"IS_UNION": ExprNodes.BoolNode(node.pos, value = not node.entry.type.is_struct),
            u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
            u"STR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format)),
            u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format.replace("%s", "%r"))),
        }, pos = node.pos).stats[0]
        wrapper_class.class_name = node.name
        wrapper_class.shadow = True
        class_body = wrapper_class.body.stats

        # fix value type
        assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
        class_body[0].base_type.name = node.name

        # fix __init__ arguments
        init_method = class_body[1]
        assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
        arg_template = init_method.args[1]
        if not node.entry.type.is_struct:
            arg_template.kw_only = True
        del init_method.args[1]
        for entry, attr in zip(var_entries, attributes):
            arg = copy.deepcopy(arg_template)
            arg.declarator.name = entry.name
            init_method.args.append(arg)

        # setters/getters
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
            prop = template.substitute({
                    u"ATTR": attr,
                }, pos = entry.pos).stats[0]
            prop.name = entry.name
            wrapper_class.body.stats.append(prop)

        wrapper_class.analyse_declarations(self.env_stack[-1])
        return self.visit_CClassDefNode(wrapper_class)

    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        # necessary to ensure that all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return None

    def visit_CEnumDefNode(self, node):
        # Only public enums are kept in the tree.
        if node.visibility == 'public':
            return node
        else:
            return None

    def visit_CNameDeclaratorNode(self, node):
        """Warn when a name is declared after it has already been referenced
        in the current module/function, unless it resolves to an extern
        entry or to an entry of a cdef class scope."""
        if node.name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(node.name)
            if (entry is None or
                    (entry.visibility != 'extern' and
                     not entry.scope.is_c_class_scope)):
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        # to ensure all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return None

    def create_Property(self, entry):
        """Build a property node for the cdef class attribute ``entry``.

        Public attributes get a read/write property (with ``__del__`` for
        Python object types); readonly attributes get a getter only.

        Raises:
            InternalError: if ``entry.visibility`` is neither 'public' nor
                'readonly' (previously this surfaced as an UnboundLocalError).
        """
        if entry.visibility == 'public':
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
        elif entry.visibility == 'readonly':
            template = self.basic_property_ro
        else:
            raise InternalError(
                "cannot create property for '%s' with visibility '%s'"
                % (entry.name, entry.visibility))
        prop = template.substitute({
                u"ATTR": ExprNodes.AttributeNode(pos=entry.pos,
                                                 obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
                                                 attribute=entry.name),
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        # ---------------------------------------
        # XXX This should go to AutoDocTransforms
        # ---------------------------------------
        if (Options.docstrings and
            self.current_directives['embedsignature']):
            attr_name = entry.name
            type_name = entry.type.declaration_code("", for_display=1)
            default_value = ''
            if not entry.type.is_pyobject:
                type_name = "'%s'" % type_name
            elif entry.type.is_extension_type:
                type_name = entry.type.module_name + '.' + type_name
            if entry.init is not None:
                default_value = ' = ' + entry.init
            docstring = attr_name + ': ' + type_name + default_value
            prop.doc = EncodedString(docstring)
        # ---------------------------------------
        return prop
1606

1607

1608
class AnalyseExpressionsTransform(CythonTransform):
    """
    Run type inference and expression analysis for the module body, each
    function body, and each scoped expression that has its own scope.
    """

    def visit_ModuleNode(self, node):
        node.scope.infer_types()
        node.body.analyse_expressions(node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        node.local_scope.infer_types()
        node.body.analyse_expressions(node.local_scope)
        self.visitchildren(node)
        return node

    def visit_ScopedExprNode(self, node):
        # Only nodes with their own expression scope are analysed here.
        if node.has_local_scope:
            node.expr_scope.infer_types()
            node.analyse_scoped_expressions(node.expr_scope)
        self.visitchildren(node)
        return node
1628

1629

1630
class ExpandInplaceOperators(EnvTransform):
    """
    Rewrite in-place assignments (``a op= b``) into ``a = a op b``.

    Sub-expressions of the target are bound in LetRefNodes/LetNodes to
    avoid duplicating their side effects. C++ class targets and buffer
    indexing are left unchanged.
    """

    def visit_InPlaceAssignmentNode(self, node):
        lhs = node.lhs
        rhs = node.rhs
        if lhs.type.is_cpp_class:
            # No getting around this exact operator here.
            return node
        if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access:
            # There is code to handle this case.
            return node

        env = self.current_env()

        def side_effect_free_reference(node, setting=False):
            # Return (reference, temps): a re-evaluable reference to ``node``
            # and the LetRefNodes that must wrap the final assignment.
            # Raises ValueError for a nested buffer access, which aborts
            # the expansion.
            if isinstance(node, ExprNodes.NameNode):
                return node, []
            elif node.type.is_pyobject and not setting:
                node = LetRefNode(node)
                return node, [node]
            elif isinstance(node, ExprNodes.IndexNode):
                if node.is_buffer_access:
                    raise ValueError("Buffer access")
                base, temps = side_effect_free_reference(node.base)
                index = LetRefNode(node.index)
                return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index]
            elif isinstance(node, ExprNodes.AttributeNode):
                obj, temps = side_effect_free_reference(node.obj)
                return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
            else:
                node = LetRefNode(node)
                return node, [node]

        try:
            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
        except ValueError:
            return node
        # Shallow copy of the target for use as the binop's left operand.
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(node.pos,
                                     operator=node.operator,
                                     operand1=dup,
                                     operand2=rhs,
                                     inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        node = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=lhs,
            rhs=binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects.
        let_ref_nodes.reverse()
        for t in let_ref_nodes:
            node = LetNode(t, node)
        return node

    def visit_ExprNode(self, node):
        # In-place assignments can't happen within an expression.
        return node

Haoyu Bai's avatar
Haoyu Bai committed
1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711
class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
    """
    Adjust function and class definitions by the decorator directives:

    @cython.cfunc
    @cython.cclass
    @cython.ccall
    """

    def visit_ModuleNode(self, node):
        self.directives = node.directives
        self.in_py_class = False
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        # Make this block's directives current while visiting its children.
        old_directives = self.directives
        self.directives = node.directives
        try:
            self.visitchildren(node)
        finally:
            self.directives = old_directives
        return node

    def visit_DefNode(self, node):
        # 'ccall' -> overridable C function; 'cfunc' -> plain C function,
        # which is rejected inside a Python class.
        if 'ccall' in self.directives:
            node = node.as_cfunction(overridable=True)
            return self.visit(node)
        if 'cfunc' in self.directives:
            if self.in_py_class:
                error(node.pos, "cfunc directive is not allowed here")
            else:
                node = node.as_cfunction(overridable=False)
                return self.visit(node)
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        if 'cclass' in self.directives:
            node = node.as_cclass()
            return self.visit(node)
        old_in_pyclass = self.in_py_class
        self.in_py_class = True
        try:
            self.visitchildren(node)
        finally:
            self.in_py_class = old_in_pyclass
        return node

    def visit_CClassDefNode(self, node):
        old_in_pyclass = self.in_py_class
        self.in_py_class = False
        try:
            self.visitchildren(node)
        finally:
            self.in_py_class = old_in_pyclass
        return node
1741

1742 1743
class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods in a .py file.
    """

    def visit_ModuleNode(self, node):
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        # A Python class with a matching cdef class entry becomes a cdef
        # class; any other existing entry is a redeclaration error.
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        error(node.pos, "'%s' redeclared" % node.name)
        if pxd_def.pos:
            error(pxd_def.pos, "previous declaration here")
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        # Look methods up in the declared class scope while visiting the
        # body; restore the outer scope even if visiting raises.
        outer_scope = self.scope
        self.scope = pxd_def.type.scope
        try:
            self.visitchildren(node)
        finally:
            self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if pxd_def and (not pxd_def.scope or not pxd_def.scope.is_builtin_scope):
            if not pxd_def.is_cfunction:
                error(node.pos, "'%s' redeclared" % node.name)
                if pxd_def.pos:
                    error(pxd_def.pos, "previous declaration here")
                return None
            node = node.as_cfunction(pxd_def)
        elif (self.scope.is_module_scope and self.directives['auto_cpdef']
              and node.is_cdef_func_compatible()):
            node = node.as_cfunction(scope=self.scope)
        # Enable this when nested cdef functions are allowed.
        # self.visitchildren(node)
        return node
1793

1794

1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831
class RemoveUnreachableCode(CythonTransform):
    """
    Mark statements that always terminate and, when the
    'remove_unreachable' directive is on, cut statement lists after the
    first terminating statement (optionally warning about the dropped code).
    """

    def visit_Node(self, node):
        self.visitchildren(node)
        return node

    def visit_StatListNode(self, node):
        if not self.current_directives['remove_unreachable']:
            return node
        self.visitchildren(node)
        for cut_at, stat in enumerate(node.stats, 1):
            if not stat.is_terminator:
                continue
            # Everything from index ``cut_at`` onwards can never run.
            if cut_at < len(node.stats):
                if self.current_directives['warn.unreachable']:
                    warning(node.stats[cut_at].pos, "Unreachable code", 2)
                node.stats = node.stats[:cut_at]
            node.is_terminator = True
            break
        return node

    def visit_IfClauseNode(self, node):
        self.visitchildren(node)
        if node.body.is_terminator:
            node.is_terminator = True
        return node

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        # The whole if-statement terminates only when the else branch and
        # every if/elif branch terminate.
        else_clause = node.else_clause
        if (else_clause and else_clause.is_terminator and
                all(clause.is_terminator for clause in node.if_clauses)):
            node.is_terminator = True
        return node


1832 1833 1834 1835
class YieldNodeCollector(TreeVisitor):
    """
    Collect the yield expressions and return statements of a function body
    into ``yields`` and ``returns``.

    Nested classes, functions, lambdas and generator expressions are not
    descended into. Once a yield is found the function is a generator, in
    which 'return' with a value is an error; every such return is reported
    exactly once, at its own position, whether it occurs before or after
    the first yield.
    """

    def __init__(self):
        super(YieldNodeCollector, self).__init__()
        self.yields = []
        self.returns = []
        self.has_return_value = False

    def visit_Node(self, node):
        return self.visitchildren(node)

    def visit_YieldExprNode(self, node):
        if not self.yields:
            # First yield: value-returning 'return' statements collected so
            # far are now illegal. (Later ones are reported in
            # visit_ReturnStatNode.) Previously this case produced the
            # misleading "'yield' outside function" at every yield.
            for return_node in self.returns:
                if return_node.value:
                    error(return_node.pos,
                          "'return' with argument inside generator")
        self.yields.append(node)
        self.visitchildren(node)

    def visit_ReturnStatNode(self, node):
        if node.value:
            self.has_return_value = True
            if self.yields:
                error(node.pos, "'return' with argument inside generator")
        self.returns.append(node)

    def visit_ClassDefNode(self, node):
        pass

    def visit_FuncDefNode(self, node):
        pass

    def visit_LambdaNode(self, node):
        pass

    def visit_GeneratorExpressionNode(self, node):
        pass

1868
class MarkClosureVisitor(CythonTransform):
    """
    Set ``needs_closure`` on functions and lambdas that directly contain a
    nested function, lambda or class.

    Any function whose body contains a yield expression is replaced by a
    GeneratorDefNode whose ``gbody`` holds the original body.
    """

    def visit_ModuleNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        # Tell the enclosing function that it contains a nested function.
        self.needs_closure = True

        collector = YieldNodeCollector()
        collector.visitchildren(node)

        if collector.yields:
            # Number the yields so each can be resumed by its label.
            for i, yield_expr in enumerate(collector.yields):
                yield_expr.label_num = i + 1

            gbody = Nodes.GeneratorBodyDefNode(pos=node.pos,
                                               name=node.name,
                                               body=node.body)
            generator = Nodes.GeneratorDefNode(pos=node.pos,
                                               name=node.name,
                                               args=node.args,
                                               star_arg=node.star_arg,
                                               starstar_arg=node.starstar_arg,
                                               doc=node.doc,
                                               decorators=node.decorators,
                                               gbody=gbody,
                                               lambda_name=node.lambda_name)
            return generator
        return node

    def visit_CFuncDefNode(self, node):
        self.visit_FuncDefNode(node)
        if node.needs_closure:
            error(node.pos, "closures inside cdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node
Stefan Behnel's avatar
Stefan Behnel committed
1920

1921
class CreateClosureClasses(CythonTransform):
    # Output closure classes in module scope for all functions
    # that really need it.

    def __init__(self, context):
        super(CreateClosureClasses, self).__init__(context)
        # Stack of enclosing FuncDefNodes that got a closure class.
        self.path = []
        self.in_lambda = False
        # Cached generator base class type (created at most once).
        self.generator_class = None

    def visit_ModuleNode(self, node):
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def create_generator_class(self, target_module_scope, pos):
        """Declare (once) and return the internal '__pyx_Generator' cdef
        class type used as the base type of generator closures."""
        if self.generator_class:
            return self.generator_class
        # XXX: make generator class creation cleaner
        entry = target_module_scope.declare_c_class(name='__pyx_Generator',
                    objstruct_cname='__pyx_Generator_object',
                    typeobj_cname='__pyx_Generator_type',
                    pos=pos, defining=True, implementing=True)
        klass = entry.type.scope
        klass.is_internal = True
        klass.directives = {'final': True}

        body_type = PyrexTypes.create_typedef_type('generator_body',
                                                   PyrexTypes.c_void_ptr_type,
                                                   '__pyx_generator_body_t')
        klass.declare_var(pos=pos, name='body', cname='body',
                          type=body_type, is_cdef=True)
        klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
                          is_cdef=True)
        klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
                          is_cdef=True)
        klass.declare_var(pos=pos, name='exc_type', cname='exc_type',
                          type=PyrexTypes.py_object_type, is_cdef=True)
        klass.declare_var(pos=pos, name='exc_value', cname='exc_value',
                          type=PyrexTypes.py_object_type, is_cdef=True)
        klass.declare_var(pos=pos, name='exc_traceback', cname='exc_traceback',
                          type=PyrexTypes.py_object_type, is_cdef=True)

        import TypeSlots
        e = klass.declare_pyfunction('send', pos)
        e.func_cname = '__Pyx_Generator_Send'
        e.signature = TypeSlots.binaryfunc

        e = klass.declare_pyfunction('close', pos)
        e.func_cname = '__Pyx_Generator_Close'
        e.signature = TypeSlots.unaryfunc

        e = klass.declare_pyfunction('throw', pos)
        e.func_cname = '__Pyx_Generator_Throw'
        e.signature = TypeSlots.pyfunction_signature

        e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
        e.func_cname = 'PyObject_SelfIter'

        e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
        e.func_cname = '__Pyx_Generator_Next'

        self.generator_class = entry.type
        return self.generator_class

    def find_entries_used_in_closures(self, node):
        """Split the local scope entries of ``node`` into (from_closure,
        in_closure) lists of (name, entry) pairs."""
        from_closure = []
        in_closure = []
        for name, entry in node.local_scope.entries.items():
            if entry.from_closure:
                from_closure.append((name, entry))
            elif entry.in_closure:
                in_closure.append((name, entry))
        return from_closure, in_closure

    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
        """Declare a closure cdef class in ``target_module_scope`` holding
        the variables of ``node`` that live in a closure, and attach it to
        the function's local scope as ``scope_class``.

        Functions that only read outer closure variables reuse the outer
        scope class (``is_passthrough``); functions without closure
        variables get no class at all.
        """
        # skip generator body
        if node.is_generator_body:
            return
        # move local variables into closure
        if node.is_generator:
            for entry in node.local_scope.entries.values():
                if not entry.from_closure:
                    entry.in_closure = True

        from_closure, in_closure = self.find_entries_used_in_closures(node)
        in_closure.sort()

        # Now from the beginning
        node.needs_closure = False
        node.needs_outer_scope = False

        func_scope = node.local_scope
        # Closest enclosing scope that is not a class scope.
        cscope = node.entry.scope
        while cscope.is_py_class_scope or cscope.is_c_class_scope:
            cscope = cscope.outer_scope

        if not from_closure and (self.path or inner_node):
            if not inner_node:
                if not node.assmt:
                    raise InternalError("DefNode does not have assignment node")
                inner_node = node.assmt.rhs
            inner_node.needs_self_code = False
            node.needs_outer_scope = False

        base_type = None
        if node.is_generator:
            base_type = self.create_generator_class(target_module_scope, node.pos)
        elif not in_closure and not from_closure:
            return
        elif not in_closure:
            func_scope.is_passthrough = True
            func_scope.scope_class = cscope.scope_class
            node.needs_outer_scope = True
            return

        as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)

        entry = target_module_scope.declare_c_class(
            name=as_name, pos=node.pos, defining=True,
            implementing=True, base_type=base_type)

        func_scope.scope_class = entry
        class_scope = entry.type.scope
        class_scope.is_internal = True
        class_scope.directives = {'final': True}

        if from_closure:
            assert cscope.is_closure_scope
            # Link to the enclosing closure object.
            class_scope.declare_var(pos=node.pos,
                                    name=Naming.outer_scope_cname,
                                    cname=Naming.outer_scope_cname,
                                    type=cscope.scope_class.type,
                                    is_cdef=True)
            node.needs_outer_scope = True
        for name, entry in in_closure:
            closure_entry = class_scope.declare_var(pos=entry.pos,
                                    name=entry.name,
                                    cname=entry.cname,
                                    type=entry.type,
                                    is_cdef=True)
            if entry.is_declared_generic:
                closure_entry.is_declared_generic = 1
        node.needs_closure = True
        # Do it here because other classes are already checked
        target_module_scope.check_c_class(func_scope.scope_class)

    def visit_LambdaNode(self, node):
        was_in_lambda = self.in_lambda
        self.in_lambda = True
        self.create_class_from_scope(node.def_node, self.module_scope, node)
        self.visitchildren(node)
        self.in_lambda = was_in_lambda
        return node

    def visit_FuncDefNode(self, node):
        if self.in_lambda:
            self.visitchildren(node)
            return node
        if node.needs_closure or self.path:
            self.create_class_from_scope(node, self.module_scope)
            self.path.append(node)
            self.visitchildren(node)
            self.path.pop()
        return node
2086 2087 2088 2089 2090 2091 2092


class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on nodes inside a `nogil` environment to
    report operations that require the GIL (such as Python operations).

    Additionally, report errors for redundant 'with gil' / 'with nogil'
    statements, i.e. acquiring the GIL while it is already held or releasing
    it while it was already released.  The latter would abort Python.
    """

    def __call__(self, root):
        # Stack of scopes handed to nogil_check(); the innermost one is last.
        self.env_stack = [root.scope]
        # Whether the code currently being visited runs without the GIL.
        self.nogil = False

        # True for 'cdef func() nogil:' functions, as the GIL may be held while
        # calling this function (thus contained 'nogil' blocks may be valid).
        self.nogil_declarator_only = False

        # Set to True by visit_GILStatNode for every 'with gil' block; read by
        # visit_TryFinallyStatNode to validate try/finally in nogil sections.
        self.seen_with_gil_statement = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        self.nogil = node.local_scope.nogil

        if self.nogil:
            self.nogil_declarator_only = True

        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)

        self.visitchildren(node)

        # This cannot be nested, so it doesn't need backup/restore
        self.nogil_declarator_only = False

        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        if self.nogil and node.nogil_check:
            node.nogil_check()

        was_nogil = self.nogil
        self.nogil = (node.state == 'nogil')

        # Entering the GIL state we are already in is an error, except
        # directly inside 'nogil' functions, where the caller may hold the GIL.
        if was_nogil == self.nogil and not self.nogil_declarator_only:
            if not was_nogil:
                error(node.pos, "Trying to acquire the GIL while it is "
                                "already held.")
            else:
                error(node.pos, "Trying to release the GIL while it was "
                                "previously released.")

        if isinstance(node.finally_clause, Nodes.StatListNode):
            # The finally clause of the GILStatNode is a GILExitNode,
            # which is wrapped in a StatListNode. Just unpack that.
            node.finally_clause, = node.finally_clause.stats

        if node.state == 'gil':
            self.seen_with_gil_statement = True

        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_ParallelRangeNode(self, node):
        # A prange(..., nogil=True) is handled as if it were wrapped in a
        # 'with nogil' block.
        if node.nogil:
            node.nogil = False
            node = Nodes.GILStatNode(node.pos, state='nogil', body=node)
            return self.visit_GILStatNode(node)

        if not self.nogil:
            error(node.pos, "prange() can only be used without the GIL")
            # Forget about any GIL-related errors that may occur in the body
            return None

        node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ParallelWithBlockNode(self, node):
        if not self.nogil:
            error(node.pos, "The parallel section may only be used without "
                            "the GIL")
            return None

        if node.nogil_check:
            # It does not currently implement this, but test for it anyway to
            # avoid potential future surprises
            node.nogil_check(self.env_stack[-1])

        self.visitchildren(node)
        return node

    def visit_TryFinallyStatNode(self, node):
        """
        Take care of try/finally statements in nogil code sections. The
        'try' must contain a 'with gil:' statement somewhere.
        """
        if not self.nogil or isinstance(node, Nodes.GILStatNode):
            return self.visit_Node(node)

        node.nogil_check = None
        node.is_try_finally_in_nogil = True

        # Save the state of an enclosing try/finally (if any): checking this
        # nested statement must not clobber what the outer one has seen.
        outer_seen_with_gil = self.seen_with_gil_statement

        # First, visit the body and check for errors
        self.seen_with_gil_statement = False
        self.visitchildren(node.body)

        if not self.seen_with_gil_statement:
            error(node.pos, "Cannot use try/finally in nogil sections unless "
                            "it contains a 'with gil' statement.")

        self.visitchildren(node.finally_clause)

        # Any 'with gil' found in this statement also lies inside the body of
        # an enclosing try/finally, so merge the result into the outer state.
        self.seen_with_gil_statement = (outer_seen_with_gil or
                                        self.seen_with_gil_statement)
        return node

    def visit_Node(self, node):
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node


class TransformBuiltinMethods(EnvTransform):
    """
    Resolve ``cython.*`` compile-time attributes and functions (e.g.
    ``cython.compiled``, ``cython.NULL``, ``cython.cast()``,
    ``cython.sizeof()``) into the corresponding expression nodes.

    Also expand the ``locals()``, ``vars()`` and ``dir()`` builtins into
    literal dict/list nodes built from the entries of the current scope.
    """

    def visit_SingleAssignmentNode(self, node):
        # Declaration-only assignments are discarded (None is returned).
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """
        Replace ``node`` when it refers to a ``cython.*`` attribute that maps
        to a concrete node; report an error for unknown cython attributes.
        Returns the (possibly replaced) node.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif PyrexTypes.parse_basic_type(attribute):
                # basic type names are left as they are
                pass
            elif self.context.cython_scope.lookup_qualified_name(attribute):
                pass
            else:
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def _inject_locals(self, node, func_name):
        """
        Expand a call ``node`` to the ``locals()``, ``vars()`` or ``dir()``
        builtin (``func_name``).

        ``locals()`` and argument-less ``vars()`` become a DictNode mapping
        the name of each entry of the current scope to a NameNode for it
        (built with ``exclude_null_values=True``).  Argument-less ``dir()``
        becomes a ListNode of those names.  The call is returned unchanged
        when the name is defined in the current scope (not the builtin),
        when arguments are passed, or after reporting a wrong argument count.
        """
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry:
            # not the builtin
            return node
        # Errors are reported at the position of the offending call.
        pos = node.pos
        if func_name in ('locals', 'vars'):
            if func_name == 'locals' and len(node.args) > 0:
                error(pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d"
                      % len(node.args))
                return node
            elif func_name == 'vars':
                if len(node.args) > 1:
                    error(pos, "Builtin 'vars()' called with wrong number of args, expected 0-1, got %d"
                          % len(node.args))
                if len(node.args) > 0:
                    return node # nothing to do
            items = [ExprNodes.DictItemNode(pos,
                                            key=ExprNodes.StringNode(pos, value=var),
                                            value=ExprNodes.NameNode(pos, name=var, allow_null=True))
                     for var in lenv.entries]
            return ExprNodes.DictNode(pos, key_value_pairs=items, exclude_null_values=True)
        else: # dir()
            if len(node.args) > 1:
                error(pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d"
                      % len(node.args))
            if len(node.args) > 0:
                # optimised in Builtin.py
                return node
            items = [ExprNodes.StringNode(pos, value=var) for var in lenv.entries]
            return ExprNodes.ListNode(pos, args=items)

    def visit_SimpleCallNode(self, node):
        """
        Expand ``locals()``/``vars()``/``dir()`` builtin calls and
        ``cython.*`` function calls (operator helpers, cast, sizeof, cmod,
        cdiv, set) into the corresponding nodes.
        """
        if isinstance(node.function, ExprNodes.NameNode):
            func_name = node.function.name
            if func_name in ('dir', 'locals', 'vars'):
                # Transform the children first.  _inject_locals() may keep the
                # call as it is (e.g. 'dir(obj)', 'vars(obj)' or a local
                # function with that name).  Its arguments must not skip
                # this transform.
                self.visitchildren(node)
                return self._inject_locals(node, func_name)

        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
            elif function in InterpretCompilerDirectives.binop_method_nodes:
                if len(node.args) != 2:
                    error(node.function.pos, u"%s() takes exactly two arguments" % function)
                else:
                    node = InterpretCompilerDirectives.binop_method_nodes[function](node.function.pos, operand1=node.args[0], operand2=node.args[1])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos, u"cast() takes exactly two arguments")
                else:
                    target_type = node.args[0].analyse_as_type(self.current_env())
                    if target_type:
                        node = ExprNodes.TypecastNode(node.function.pos, type=target_type, operand=node.args[1])
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    target_type = node.args[0].analyse_as_type(self.current_env())
                    if target_type:
                        node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=target_type)
                    else:
                        # not a type: take the size of the expression instead
                        node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == u'set':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
            elif self.context.cython_scope.lookup_qualified_name(function):
                pass
            else:
                error(node.function.pos,
                      u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node


class FindUninitializedParallelVars(CythonTransform, SkipDeclarations):
    """
    Collect the variable references found while walking a tree.

    This transform is not part of the compilation pipeline.  After it has
    run, ``used_vars`` holds one ``(entry, pos)`` pair per NameNode it
    visited.  A ParallelStatNode reached during the walk is returned
    without visiting its children.
    """

    def __init__(self):
        # This stand-alone walk needs no compilation context.
        CythonTransform.__init__(self, None)
        self.used_vars = []

    def visit_ParallelStatNode(self, node):
        # Do not descend: the children of this node are skipped.
        return node

    def visit_NameNode(self, node):
        reference = (node.entry, node.pos)
        self.used_vars.append(reference)
        return node


class DebugTransform(CythonTransform):
    """
    Write debug information for this Cython module.

    The functions (with their locals, arguments and step-into functions) and
    the module-level variables are written through the tree builder found
    in ``context.gdb_debug_outputwriter``.
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # Qualified names of all function scopes seen.  visit_ModuleNode
        # keeps these out of the "Globals" section.
        self.visited = cython.set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        # name of the generated C file, recorded on the 'Module' element
        self.c_output_file = result.c_file

        # Closure support, basically treat nested functions as if the AST were
        # never nested: they are queued here and serialized later.
        self.nested_funcdefs = []

        # tells visit_NameNode whether it should register step-into functions
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        self.tb.module_name = node.full_module_name
        module_attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)
        self.tb.start('Module', module_attrs)

        # serialize functions
        self.tb.start('Functions')
        # Serialize the top-level functions first; nested ones are queued.
        self.visitchildren(node)
        # Then serialize the queued nested functions.  Functions queued while
        # this loop runs are appended to the list, so they are reached too.
        for queued_funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(queued_funcdef)
        self.register_stepinto = True
        self.serialize_modulenode_as_function(node)
        self.register_stepinto = False
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables
        self.tb.start('Globals')
        global_entries = {}
        for key, entry in node.scope.entries.iteritems():
            # skip functions already serialized, internal names, C functions
            # and extension types
            if entry.qualified_name in self.visited:
                continue
            if entry.name.startswith('__pyx_'):
                continue
            if entry.type.is_cfunction or entry.type.is_extension_type:
                continue
            global_entries[key] = entry
        self.serialize_local_variables(global_entries)
        self.tb.end('Globals')

        # 'Module' is intentionally left open here; it is ended after the
        # line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        self.visited.add(node.local_scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            # Reached while serializing another function: queue it so that
            # visit_ModuleNode serializes it on its own afterwards.
            self.nested_funcdefs.append(node)
            return node

        # node.entry.visibility = 'extern'
        if node.py_func is None:
            py_cname = ''
        else:
            py_cname = node.py_func.entry.func_cname

        self.tb.start('Function', attrs=dict(
            name=node.entry.name,
            cname=node.entry.func_cname,
            pf_cname=py_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1])))

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg_entry in node.local_scope.arg_entries:
            self.tb.start(arg_entry.name)
            self.tb.end(arg_entry.name)
        self.tb.end('Arguments')

        # Called C functions found in the body are recorded by visit_NameNode.
        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')
        return node

    def visit_NameNode(self, node):
        is_stepinto_call = (self.register_stepinto and
                            node.type.is_cfunction and
                            getattr(node, 'is_called', False) and
                            node.entry.func_cname is not None)
        if is_stepinto_call:
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            self.tb.start('StepIntoFunction',
                          attrs=dict(name=node.entry.func_cname))
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Serialize the module-level code as a function so the debugger will know
        it's a "relevant frame" and it will know where to set the breakpoint
        for 'break modulename'.

        The module code is written twice: once under the Python 2 init
        function name ('init<name>') and once under the Python 3 one
        ('PyInit_<name>').
        """
        name = node.full_module_name.rpartition('.')[-1]

        init_attrs = dict(
            name=name,
            cname='init' + name,
            pf_cname='',
            # Ignore the qualified_name, breakpoints should be set using
            # `cy break modulename:lineno` for module-level breakpoints.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )
        pyinit_attrs = dict(init_attrs, cname='PyInit_' + name)

        for function_attrs in (init_attrs, pyinit_attrs):
            self._serialize_modulenode_as_function(node, function_attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        # Write one 'Function' element for the module code: the module-scope
        # entries as locals, no arguments, and the called C functions.
        tb = self.tb
        tb.start('Function', attrs=attrs)

        tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        tb.end('Locals')

        tb.start('Arguments')
        tb.end('Arguments')

        tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        tb.end('StepIntoFunctions')

        tb.end('Function')

    def serialize_local_variables(self, entries):
        # Write one 'LocalVar' element per entry of the ``entries`` mapping.
        for entry in entries.values():
            if entry.type.is_pyobject:
                kind = 'PythonObject'
            else:
                kind = 'CObject'

            if entry.from_closure:
                # We're dealing with a closure where a variable from an outer
                # scope is accessed, get it from the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)
                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                # stored in the current scope object
                cname = '%s->%s' % (Naming.cur_scope_cname, entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            # Entries without a position are not in the user's code
            # (e.g. the global __builtins__, __doc__, ...): use line 0.
            if entry.pos:
                lineno = str(entry.pos[1])
            else:
                lineno = '0'

            self.tb.start('LocalVar', dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=kind,
                lineno=lineno))
            self.tb.end('LocalVar')