from __future__ import absolute_import

import cython
# Lets Cython type these module-level names as plain Python objects when it
# compiles this module itself (pure Python mode declarations).
cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
               Options=object, UtilNodes=object, LetNode=object,
               LetRefNode=object, TreeFragment=object, EncodedString=object,
               error=object, warning=object, copy=object, _unicode=object)

import copy

from . import PyrexTypes
from . import Naming
from . import ExprNodes
from . import Nodes
from . import Options
from . import Builtin

from .Visitor import VisitorTransform, TreeVisitor
from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from .UtilNodes import LetNode, LetRefNode
from .TreeFragment import TreeFragment
from .StringEncoding import EncodedString, _unicode
from .Errors import error, warning, CompileError, InternalError
from .Code import UtilityCode


class NameNodeCollector(TreeVisitor):
    """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
    attribute.
    """
    def __init__(self):
        super(NameNodeCollector, self).__init__()
        # NameNodes in the order in which the visitor reaches them.
        self.name_nodes = []

    def visit_NameNode(self, node):
        # Record the name; a NameNode's own children are not visited.
        self.name_nodes.append(node)

    def visit_Node(self, node):
        # Any other node: keep descending so nested NameNodes are found.
        self._visitchildren(node, None)


class SkipDeclarations(object):
    """
    Variable and function declarations can often have a deep tree structure,
    and yet most transformations don't need to descend to this depth.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so there
    is no need to use this for transformations after that point.

    Each visitor below returns the declaration node unchanged and does not
    visit its children, so a transform that mixes in this class leaves
    declarations untouched.
    """
    def visit_CTypeDefNode(self, node):
        return node

    def visit_CVarDefNode(self, node):
        return node

    def visit_CDeclaratorNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return node

    def visit_CEnumDefNode(self, node):
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return node


class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means fewer special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).

    c) Expression statements that consist only of a string literal
    are eliminated in the same way as "pass".
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the children of a node that keeps its
        # statements in a list, so those statements need no wrapping.
        self.is_in_statlist = False
        # True while anywhere inside an expression; statements nested in
        # expressions are never wrapped.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        # Wrap a lone statement unless it already sits in a statement list
        # or inside an expression.
        if not self.is_in_statlist and not self.is_in_expr:
            return Nodes.StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Save and restore the flag (as visit_StatNode and visit_ExprNode do)
        # rather than resetting it to False afterwards: a StatListNode nested
        # inside another statement list must not cause its later siblings to
        # be wrapped in redundant StatListNodes.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_PassStatNode(self, node):
        """Eliminate PassStatNode"""
        if not self.is_in_statlist:
            return Nodes.StatListNode(pos=node.pos, stats=[])
        else:
            # Returning an empty list removes the node from its statement list.
            return []

    def visit_ExprStatNode(self, node):
        """Eliminate useless string literals"""
        if node.expr.is_string_literal:
            return self.visit_PassStatNode(node)
        else:
            return self.visit_StatNode(node)

    def visit_CDeclaratorNode(self, node):
        # Declarators are left as they are; their children are not visited.
        return node


class PostParseError(CompileError):
    """Compile error for constructs rejected by the post-parse checks.

    The transforms in this module hand it to ``context.nonfatal_error``
    instead of aborting compilation.
    """

# Error strings are checked by unit tests, so they are defined once here.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'

class PostParse(ScopeTrackingTransform):
    """
    Basic interpretation of the parse tree, as well as validity
    checking that can be done on a very basic level on the parse
    tree (while still not being a problem with the basic syntax,
    as such).

    Specifically:
    - Default values to cdef assignments are turned into single
    assignments following the declaration (everywhere but in class
    bodies, where they raise a compile error)

    - Interpret some node structures into Python runtime values.
    Some nodes take compile-time arguments (currently:
    TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
    which should be interpreted. This happens in a general way
    and other steps should be taken to ensure validity.

    Type arguments cannot be interpreted in this way.

    - For __cythonbufferdefaults__ the arguments are checked for
    validity.

    TemplatedTypeNode has its directives interpreted:
    Any first positional argument goes into the "dtype" attribute,
    any "ndim" keyword argument goes into the "ndim" attribute and
    so on. Also it is checked that the directive combination is valid.
    - __cythonbufferdefaults__ attributes are parsed and put into the
    type information.

    - Lambdas and generator expressions get a synthesized DefNode
    (``def_node``); parallel assignments such as ``a, b = b, a`` are
    split into single assignments; nested ``del`` targets are flattened;
    ``except ... as name`` clauses delete ``name`` when the clause is
    left; multi-part f-strings store their parts in a ListNode.

    Note: Currently Parsing.py does a lot of interpretation and
    reorganization that can be refactored into this transform
    if a more pure Abstract Syntax Tree is wanted.
    """

    def __init__(self, context):
        super(PostParse, self).__init__(context)
        # Special class attribute names whose default value is consumed by
        # a handler instead of being turned into an assignment.
        self.specialattribute_handlers = {
            '__cythonbufferdefaults__' : self.handle_bufferdefaults
        }

    def visit_LambdaNode(self, node):
        # unpack a lambda expression into the corresponding DefNode
        collector = YieldNodeCollector()
        collector.visitchildren(node.result_expr)
        if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode):
            # A lambda body that yields or awaits is kept as a plain
            # expression statement instead of being returned.
            body = Nodes.ExprStatNode(
                node.result_expr.pos, expr=node.result_expr)
        else:
            body = Nodes.ReturnStatNode(
                node.result_expr.pos, value=node.result_expr)
        node.def_node = Nodes.DefNode(
            node.pos, name=node.name,
            args=node.args, star_arg=node.star_arg,
            starstar_arg=node.starstar_arg,
            body=body, doc=None)
        self.visitchildren(node)
        return node

    def visit_GeneratorExpressionNode(self, node):
        # unpack a generator expression into the corresponding DefNode
        node.def_node = Nodes.DefNode(node.pos, name=node.name,
                                      doc=None,
                                      args=[], star_arg=None,
                                      starstar_arg=None,
                                      body=node.loop)
        self.visitchildren(node)
        return node

    # cdef variables
    def handle_bufferdefaults(self, decl):
        """Store a ``__cythonbufferdefaults__`` dict on the enclosing class
        node; raise PostParseError if the default is not a dict literal.
        """
        if not isinstance(decl.default, ExprNodes.DictNode):
            raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
        self.scope_node.buffer_defaults_node = decl.default
        self.scope_node.buffer_defaults_pos = decl.pos

    def visit_CVarDefNode(self, node):
        # This assumes only plain names and pointers are assignable on
        # declaration. Also, it makes use of the fact that a cdef decl
        # must appear before the first use, so we don't have to deal with
        # "i = 3; cdef int i = i" and can simply move the nodes around.
        try:
            self.visitchildren(node)
            stats = [node]
            newdecls = []
            for decl in node.declarators:
                declbase = decl
                while isinstance(declbase, Nodes.CPtrDeclaratorNode):
                    declbase = declbase.base
                if isinstance(declbase, Nodes.CNameDeclaratorNode):
                    if declbase.default is not None:
                        if self.scope_type in ('cclass', 'pyclass', 'struct'):
                            if isinstance(self.scope_node, Nodes.CClassDefNode):
                                handler = self.specialattribute_handlers.get(decl.name)
                                if handler:
                                    if decl is not declbase:
                                        raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
                                    handler(decl)
                                    continue  # Remove declaration
                            raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
                        # Move the default value into an assignment that
                        # follows the declaration.
                        first_assignment = self.scope_type != 'module'
                        stats.append(Nodes.SingleAssignmentNode(node.pos,
                            lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
                            rhs=declbase.default, first=first_assignment))
                        declbase.default = None
                newdecls.append(decl)
            node.declarators = newdecls
            return stats
        except PostParseError as e:
            # An error in a cdef clause is ok, simply remove the declaration
            # and try to move on to report more errors
            self.context.nonfatal_error(e)
            return None

    # Split parallel assignments (a,b = b,a) into separate partial
    # assignments that are executed rhs-first using temps.  This
    # restructuring must be applied before type analysis so that known
    # types on rhs and lhs can be matched directly.  It is required in
    # the case that the types cannot be coerced to a Python type in
    # order to assign from a tuple.

    def visit_SingleAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, [node.lhs, node.rhs])

    def visit_CascadedAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, node.lhs_list + [node.rhs])

    def _visit_assignment_node(self, node, expr_list):
        """Flatten parallel assignments into separate single
        assignments or cascaded assignments.

        ``expr_list`` holds all assignment targets followed by the rhs.
        Nothing changes unless at least two of them are sequence
        constructors or string literals.  Otherwise the statement is
        replaced by the flattened assignments, wrapped in LetNodes for
        rhs sub-expressions that occur more than once.
        """
        if sum(1 for expr in expr_list
               if expr.is_sequence_constructor or expr.is_string_literal) < 2:
            # no parallel assignments => nothing to do
            return node

        expr_list_list = []
        flatten_parallel_assignments(expr_list, expr_list_list)
        temp_refs = []
        eliminate_rhs_duplicates(expr_list_list, temp_refs)

        nodes = []
        for assignment_exprs in expr_list_list:
            lhs_list = assignment_exprs[:-1]
            rhs = assignment_exprs[-1]
            if len(lhs_list) == 1:
                assignment = Nodes.SingleAssignmentNode(
                    rhs.pos, lhs=lhs_list[0], rhs=rhs)
            else:
                assignment = Nodes.CascadedAssignmentNode(
                    rhs.pos, lhs_list=lhs_list, rhs=rhs)
            nodes.append(assignment)

        if not nodes:
            # Unpacking an empty sequence into an empty target (e.g.
            # "[] = []") leaves nothing to assign; without this check
            # nodes[0] below would raise IndexError.
            return Nodes.StatListNode(node.pos, stats=[])
        elif len(nodes) == 1:
            assign_node = nodes[0]
        else:
            assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats=nodes)

        if temp_refs:
            duplicates_and_temps = [(temp.expression, temp)
                                    for temp in temp_refs]
            sort_common_subsequences(duplicates_and_temps)
            # Wrap innermost-first so that the first temp ends up outermost.
            for _, temp_ref in reversed(duplicates_and_temps):
                assign_node = LetNode(temp_ref, assign_node)

        return assign_node

    def _flatten_sequence(self, seq, result):
        # Append the non-sequence leaves of seq.args to result, depth-first.
        for arg in seq.args:
            if arg.is_sequence_constructor:
                self._flatten_sequence(arg, result)
            else:
                result.append(arg)
        return result

    def visit_DelStatNode(self, node):
        # "del (a, [b, c])" deletes a, b and c individually.
        self.visitchildren(node)
        node.args = self._flatten_sequence(node, [])
        return node

    def visit_ExceptClauseNode(self, node):
        if node.is_except_as:
            # except-as must delete NameNode target at the end
            del_target = Nodes.DelStatNode(
                node.pos,
                args=[ExprNodes.NameNode(
                    node.target.pos, name=node.target.name)],
                ignore_nonexisting=True)
            node.body = Nodes.StatListNode(
                node.pos,
                stats=[Nodes.TryFinallyStatNode(
                    node.pos,
                    body=node.body,
                    finally_clause=Nodes.StatListNode(
                        node.pos,
                        stats=[del_target]))])
        self.visitchildren(node)
        return node

    def visit_JoinedStrNode(self, node):
        # Visit the parts first: the single-value shortcut below returns a
        # child directly, so without this its subtree (e.g. a nested format
        # spec JoinedStrNode or a lambda) would never be processed here.
        self.visitchildren(node)
        if len(node.values) == 1:
            # this is not uncommon because f-string format specs are parsed into JoinedStrNodes
            return node.values[0]
        node.values = ExprNodes.ListNode(node.pos, args=node.values)
        return node


def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.
    Creates a sequence of LetRefNodes that set up the required temps
    and appends them to ref_node_sequence.  The input list is modified
    in-place.

    Each entry of ``expr_list_list`` is a list whose last item is the rhs.
    Literals and names are never replaced.  Returns None.
    """
    seen_nodes = set()
    ref_nodes = {}

    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node in seen_nodes:
            if node not in ref_nodes:
                ref_node = LetRefNode(node)
                ref_nodes[node] = ref_node
                ref_node_sequence.append(ref_node)
        else:
            seen_nodes.add(node)
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)

    for expr_list in expr_list_list:
        find_duplicates(expr_list[-1])
    if not ref_nodes:
        return

    def substitute_nodes(node):
        if node in ref_nodes:
            return ref_nodes[node]
        elif node.is_sequence_constructor:
            node.args = [substitute_nodes(arg) for arg in node.args]
        return node

    # replace nodes inside of the common subexpressions
    for node in ref_nodes:
        if node.is_sequence_constructor:
            node.args = [substitute_nodes(arg) for arg in node.args]

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])

def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.  This is needed
    because each rhs item must only be evaluated once, so its value
    must be evaluated first and then reused when packing sequences
    that contain it.

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort (which is very fast for short sequences, the
    normal case in practice).

    ``items`` is a list of ``(expression, ref_node)`` pairs and is sorted
    in place.
    """
    def contains(seq, x):
        # True if x occurs (by identity) anywhere in seq, including inside
        # nested sequence constructors.
        return any(item is x or (item.is_sequence_constructor and contains(item.args, x))
                   for item in seq)

    def lower_than(a, b):
        return b.is_sequence_constructor and contains(b.args, a)

    for pos, item in enumerate(items):
        key = item[1]  # the ResultRefNode which has already been injected into the sequences
        # Find the left-most earlier expression that contains this ref node.
        new_pos = pos
        for i in range(pos - 1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos != pos:
            # Shift the intervening items right and insert this one before them.
            for i in range(pos, new_pos, -1):
                items[i] = items[i - 1]
            items[new_pos] = item


def unpack_string_to_character_literals(literal):
    """Split a string literal node into one literal node per character.

    Each new node is an instance of ``literal``'s class with the same
    ``pos``.  Its ``value`` and ``constant_result`` are the character,
    converted to the type of ``literal.value``.
    """
    pos = literal.pos
    node_class = literal.__class__
    value_type = literal.value.__class__
    chars = []
    for char in literal.value:
        char_value = value_type(char)
        chars.append(node_class(pos, value=char_value, constant_result=char_value))
    return chars


def flatten_parallel_assignments(input, output):
    #  The input is a list of expression nodes, representing the LHSs
    #  and RHS of one (possibly cascaded) assignment statement.  For
    #  sequence constructors, rearranges the matching parts of both
    #  sides into a list of equivalent assignments between the
    #  individual elements.  This transformation is applied
    #  recursively, so that nested structures get matched as well.
    #  Each resulting assignment is appended to ``output`` as a list of
    #  targets followed by its rhs.  Unpacking mismatches are reported
    #  via error() and the offending pair is kept unflattened.
    rhs = input[-1]
    if (not (rhs.is_sequence_constructor or isinstance(rhs, ExprNodes.UnicodeNode))
            or not any(lhs.is_sequence_constructor for lhs in input[:-1])):
        output.append(input)
        return

    # Non-sequence targets receive the whole rhs.
    complete_assignments = []

    if rhs.is_sequence_constructor:
        rhs_args = rhs.args
    elif rhs.is_string_literal:
        rhs_args = unpack_string_to_character_literals(rhs)

    rhs_size = len(rhs_args)
    # lhs_targets[i] collects every target that receives rhs_args[i].
    lhs_targets = [[] for _ in range(rhs_size)]
    starred_assignments = []
    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue
        lhs_size = len(lhs.args)
        starred_targets = sum(1 for expr in lhs.args if expr.is_starred)
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
            output.append([lhs, rhs])
            continue
        elif lhs_size - starred_targets > rhs_size:
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, 's' if rhs_size != 1 else ''))
            output.append([lhs, rhs])
            continue
        elif starred_targets:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs_args)
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
            output.append([lhs, rhs])
            continue
        else:
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, rhs_item in zip(lhs_targets, rhs_args):
        if cascade:
            cascade.append(rhs_item)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)

def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    # Appends the fixed-position LHS targets to the target list that
    # appear left and right of the starred argument.
    #
    # The starred_assignments list receives a new pair
    # [lhs_target, ListNode(rhs_values)] that maps the remaining arguments
    # (those that match the starred target) to a list.
    #
    # The caller ensures len(lhs_args) - 1 <= len(rhs_args), so every
    # fixed target left of the star has a slot in lhs_targets.

    # left side of the starred target
    # Iterate over lhs_args alone.  Zipping with lhs_targets (one slot per
    # rhs value) would stop before reaching the star when it is the last
    # target and receives no values, e.g. "a, *b = [1]" or "*a, = ()".
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            lhs_remaining = len(lhs_args) - i - 1
            break
        lhs_targets[i].append(expr)
    else:
        raise InternalError("no starred arg found when splitting starred assignment")

    # right side of the starred target
    if lhs_remaining:
        for targets, expr in zip(lhs_targets[-lhs_remaining:],
                                 lhs_args[starred + 1:]):
            targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target  # unpack starred node
    starred_rhs = rhs_args[starred:]
    if lhs_remaining:
        starred_rhs = starred_rhs[:-lhs_remaining]
    pos = starred_rhs[0].pos if starred_rhs else target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])


class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots of a cdef class

    - cdef functions are let through only if they are declared
    "inline" and appear on the top level or directly in a cdef
    class; such inline functions must also be private and not "api"

    Rejected functions are reported as non-fatal errors and removed.
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # 'pxd' at module level, 'cclass' inside a cdef class body.
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        old = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = old
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        err = self.ERR_INLINE_ONLY

        if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
                and node.name in ('__getbuffer__', '__releasebuffer__')):
            err = None  # allow these slots

        if isinstance(node, Nodes.CFuncDefNode):
            if (u'inline' in node.modifiers and
                    self.scope_type in ('pxd', 'cclass')):
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None  # allow inline function
            else:
                err = self.ERR_INLINE_ONLY

        if err:
            # Report, then drop the function from the tree.
            self.context.nonfatal_error(PostParseError(node.pos, err))
            return None
        else:
            return node


class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
    """
    After parsing, directives can be stored in a number of places:
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments overriding these
    - @cython.directivename decorators
    - with cython.directivename: statements

    This transform is responsible for interpreting these various sources
    and storing the directives in two ways:
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to avoid having to modify the tree
    structure, so that the ModuleNode stays on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
    and which names the "cython" module may have been imported to.
    """
    unop_method_nodes = {
        'typeof': ExprNodes.TypeofNode,

        'operator.address': ExprNodes.AmpersandNode,
        'operator.dereference': ExprNodes.DereferenceNode,
        'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
        'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
        'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
        'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),

        # For backwards compatibility.
        'address': ExprNodes.AmpersandNode,
    }

    binop_method_nodes = {
        'operator.comma'        : ExprNodes.c_binop_constructor(','),
    }

    special_methods = set(['declare', 'union', 'struct', 'typedef',
                           'sizeof', 'cast', 'pointer', 'compiled',
                           'NULL', 'fused_type', 'parallel'])
    special_methods.update(unop_method_nodes)

    valid_parallel_directives = set([
        "parallel",
        "prange",
        "threadid",
        #"threadsavailable",
    ])

    def __init__(self, context, compilation_directive_defaults):
        super(InterpretCompilerDirectives, self).__init__(context)
        # Local names under which the "cython" module itself is visible.
        self.cython_module_names = set()
        # Local alias -> directive name (e.g. after "from cython cimport x as y").
        self.directive_names = {'staticmethod': 'staticmethod'}
        # Local alias -> fully qualified "cython.parallel..." name.
        self.parallel_directives = {}
        # Start from private copies of the global defaults, overridden by
        # the compilation-level defaults passed in by the caller.
        directives = copy.deepcopy(Options.directive_defaults)
        for key, value in compilation_directive_defaults.items():
            directives[_unicode(key)] = copy.deepcopy(value)
        self.directives = directives

    def check_directive_scope(self, pos, directive, scope):
        """
        Return False, after reporting a non-fatal error, if `directive` is
        not allowed in `scope` according to Options.directive_scopes.
        Otherwise return True; unknown directive names are reported as
        errors but still yield True.
        """
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            if (directive not in Options.directive_defaults
                    and directive not in Options.directive_types):
                error(pos, "Invalid directive: '%s'." % (directive,))
            return True

    # Set up processing and handle the cython: comments.
    def visit_ModuleNode(self, node):
        for key in sorted(node.directive_comments):
            if not self.check_directive_scope(node.pos, key, 'module'):
                # check_directive_scope() has already reported the error,
                # so the misplaced directive is simply dropped here.
                del node.directive_comments[key]

        self.module_scope = node.scope

        self.directives.update(node.directive_comments)
        node.directives = self.directives
        node.parallel_directives = self.parallel_directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

    # The following four functions track imports and cimports that
    # begin with "cython"
    def is_cython_directive(self, name):
        """Return a true value if `name` is a known directive, special
        method or basic C type name of the cython module."""
        return (name in Options.directive_types or
                name in self.special_methods or
                PyrexTypes.parse_basic_type(name))

    def is_parallel_directive(self, full_name, pos):
        """
        Checks to see if fullname (e.g. cython.parallel.prange) is a valid
        parallel directive. If it is a star import it also updates the
        parallel_directives.
        """
        result = (full_name + ".").startswith("cython.parallel.")

        if result:
            directive = full_name.split('.')
            if full_name == u"cython.parallel":
                self.parallel_directives[u"parallel"] = u"cython.parallel"
            elif full_name == u"cython.parallel.*":
                for name in self.valid_parallel_directives:
                    self.parallel_directives[name] = u"cython.parallel.%s" % name
            elif (len(directive) != 3 or
                  directive[-1] not in self.valid_parallel_directives):
                error(pos, "No such directive: %s" % full_name)

            # Any use of cython.parallel requires the thread init code.
            self.module_scope.use_utility_code(
                UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))

        return result

    def visit_CImportStatNode(self, node):
        if node.module_name == u"cython":
            self.cython_module_names.add(node.as_name or u"cython")
        elif node.module_name.startswith(u"cython."):
            if node.module_name.startswith(u"cython.parallel."):
                error(node.pos, node.module_name + " is not a module")
            if node.module_name == u"cython.parallel":
                if node.as_name and node.as_name != u"cython":
                    self.parallel_directives[node.as_name] = node.module_name
                else:
                    self.cython_module_names.add(u"cython")
                    self.parallel_directives[
                                    u"cython.parallel"] = node.module_name
                self.module_scope.use_utility_code(
                    UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))
            elif node.as_name:
                # Strip the leading "cython." to get the directive name.
                self.directive_names[node.as_name] = node.module_name[7:]
            else:
                self.cython_module_names.add(u"cython")
            # if this cimport was a compiler directive, we don't
            # want to leave the cimport node sitting in the tree
            return None
        return node

    def visit_FromCImportStatNode(self, node):
        if not node.relative_level and (
                node.module_name == u"cython" or node.module_name.startswith(u"cython.")):
            # Part of the module path after "cython." (with trailing dot),
            # or the empty string for the "cython" module itself.
            submodule = (node.module_name + u".")[7:]
            newimp = []

            for pos, name, as_name, kind in node.imported_names:
                full_name = submodule + name
                qualified_name = u"cython." + full_name

                if self.is_parallel_directive(qualified_name, node.pos):
                    # from cython cimport parallel, or
                    # from cython.parallel cimport parallel, prange, ...
                    self.parallel_directives[as_name or name] = qualified_name
                elif self.is_cython_directive(full_name):
                    self.directive_names[as_name or name] = full_name
                    if kind is not None:
                        self.context.nonfatal_error(PostParseError(pos,
                            "Compiler directive imports must be plain imports"))
                else:
                    newimp.append((pos, name, as_name, kind))

            if not newimp:
                return None

            node.imported_names = newimp
        return node

    def visit_FromImportStatNode(self, node):
        if (node.module.module_name.value == u"cython") or \
               node.module.module_name.value.startswith(u"cython."):
            submodule = (node.module.module_name.value + u".")[7:]
            newimp = []
            for name, name_node in node.items:
                full_name = submodule + name
                qualified_name = u"cython." + full_name
                if self.is_parallel_directive(qualified_name, node.pos):
                    self.parallel_directives[name_node.name] = qualified_name
                elif self.is_cython_directive(full_name):
                    self.directive_names[name_node.name] = full_name
                else:
                    newimp.append((name, name_node))
            if not newimp:
                return None
            node.items = newimp
        return node

    def visit_SingleAssignmentNode(self, node):
        # An assignment whose rhs is an ImportNode of "cython" or
        # "cython.parallel..." is rewritten as a CImportStatNode so that
        # visit_CImportStatNode handles all spellings uniformly.
        if isinstance(node.rhs, ExprNodes.ImportNode):
            module_name = node.rhs.module_name.value
            is_parallel = (module_name + u".").startswith(u"cython.parallel.")

            if module_name != u"cython" and not is_parallel:
                return node

            as_name = node.lhs.name

            node = Nodes.CImportStatNode(node.pos,
                                         module_name = module_name,
                                         as_name = as_name)
            node = self.visit_CImportStatNode(node)
        else:
            self.visitchildren(node)

        return node

    def visit_NameNode(self, node):
        if node.name in self.cython_module_names:
            node.is_cython_module = True
        else:
            node.cython_attribute = self.directive_names.get(node.name)
        return node

    def try_to_parse_directives(self, node):
        # If node is the contents of an directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None
        if isinstance(node, ExprNodes.CallNode):
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # Keyword arguments naming a sub-option (e.g.
                        # "optname.key") become directives of their own.
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            kwds.key_value_pairs = key_value_pairs
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives
        elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
            self.visit(node)
            optname = node.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype is bool:
                    return [(optname, True)]
                elif directivetype is None:
                    return [(optname, None)]
                else:
                    raise PostParseError(
                        node.pos, "The '%s' directive should be used as a function call." % optname)
        return None

    def try_to_parse_directive(self, optname, args, kwds, pos):
        """
        Convert the call arguments of a single directive into an
        (optname, value) pair according to its type in
        Options.directive_types.  A single None argument selects the
        directive's default value.

        Raises PostParseError if the arguments do not match the type.
        """
        directivetype = Options.directive_types.get(optname)
        if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
            return optname, Options.directive_defaults[optname]
        elif directivetype is bool:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time boolean argument' % optname)
            return (optname, args[0].value)
        elif directivetype is int:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.IntNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time integer argument' % optname)
            return (optname, int(args[0].value))
        elif directivetype is str:
            if kwds is not None or len(args) != 1 or not isinstance(
                    args[0], (ExprNodes.StringNode, ExprNodes.UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, str(args[0].value))
        elif directivetype is type:
            if kwds is not None or len(args) != 1:
                raise PostParseError(pos,
                    'The %s directive takes one type argument' % optname)
            return (optname, args[0])
        elif directivetype is dict:
            if len(args) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no prepositional arguments' % optname)
            if kwds is None:
                # Called without any keyword arguments, e.g. "@cython.locals()".
                return optname, {}
            return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
        elif directivetype is list:
            # kwds is a DictNode (or None); count its entries, not the node.
            if kwds and len(kwds.key_value_pairs) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no keyword arguments' % optname)
            return optname, [ str(arg.value) for arg in args ]
        elif callable(directivetype):
            if kwds is not None or len(args) != 1 or not isinstance(
                    args[0], (ExprNodes.StringNode, ExprNodes.UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, directivetype(optname, str(args[0].value)))
        else:
            assert False

    def visit_with_directives(self, body, directives):
        """Visit `body` with `directives` layered over the current ones and
        wrap the result in a CompilerDirectivesNode."""
        olddirectives = self.directives
        newdirectives = copy.copy(olddirectives)
        newdirectives.update(directives)
        self.directives = newdirectives
        assert isinstance(body, Nodes.StatListNode), body
        retbody = self.visit_Node(body)
        directive = Nodes.CompilerDirectivesNode(pos=retbody.pos, body=retbody,
                                                 directives=newdirectives)
        self.directives = olddirectives
        return directive

    # Handle decorators
    def visit_FuncDefNode(self, node):
        directives = self._extract_directives(node, 'function')
        if not directives:
            return self.visit_Node(node)
        body = Nodes.StatListNode(node.pos, stats=[node])
        return self.visit_with_directives(body, directives)

    def visit_CVarDefNode(self, node):
        directives = self._extract_directives(node, 'function')
        if not directives:
            return node
        for name, value in directives.items():
            if name == 'locals':
                node.directive_locals = value
            elif name not in ('final', 'staticmethod'):
                self.context.nonfatal_error(PostParseError(
                    node.pos,
                    "Cdef functions can only take cython.locals(), "
                    "staticmethod, or final decorators, got %s." % name))
        body = Nodes.StatListNode(node.pos, stats=[node])
        return self.visit_with_directives(body, directives)

    def visit_CClassDefNode(self, node):
        directives = self._extract_directives(node, 'cclass')
        if not directives:
            return self.visit_Node(node)
        body = Nodes.StatListNode(node.pos, stats=[node])
        return self.visit_with_directives(body, directives)

    def visit_CppClassNode(self, node):
        directives = self._extract_directives(node, 'cppclass')
        if not directives:
            return self.visit_Node(node)
        body = Nodes.StatListNode(node.pos, stats=[node])
        return self.visit_with_directives(body, directives)

    def visit_PyClassDefNode(self, node):
        directives = self._extract_directives(node, 'class')
        if not directives:
            return self.visit_Node(node)
        body = Nodes.StatListNode(node.pos, stats=[node])
        return self.visit_with_directives(body, directives)

    def _extract_directives(self, node, scope_name):
        """
        Split node.decorators into directive decorators and real ones.
        The real decorators are kept on the node (together with a
        'staticmethod' directive decorator), and a dict of directives whose
        value differs from the one currently in effect is returned.
        """
        if not node.decorators:
            return {}
        # Split the decorators into two lists -- real decorators and directives
        directives = []
        realdecs = []
        both = []
        for dec in node.decorators:
            new_directives = self.try_to_parse_directives(dec.decorator)
            if new_directives is not None:
                for directive in new_directives:
                    if self.check_directive_scope(node.pos, directive[0], scope_name):
                        name, value = directive
                        if self.directives.get(name, object()) != value:
                            directives.append(directive)
                        if directive[0] == 'staticmethod':
                            both.append(dec)
            else:
                realdecs.append(dec)
        if realdecs and isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode, Nodes.CVarDefNode)):
            raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
        else:
            node.decorators = realdecs + both
        # merge or override repeated directives
        optdict = {}
        directives.reverse() # Decorators coming first take precedence
        for directive in directives:
            name, value = directive
            if name in optdict:
                old_value = optdict[name]
                # keywords and arg lists can be merged, everything
                # else overrides completely
                if isinstance(old_value, dict):
                    old_value.update(value)
                elif isinstance(old_value, list):
                    old_value.extend(value)
                else:
                    optdict[name] = value
            else:
                optdict[name] = value
        return optdict

    # Handle with statements
    def visit_WithStatNode(self, node):
        directive_dict = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is not None:
                if node.target is not None:
                    self.context.nonfatal_error(
                        PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                else:
                    name, value = directive
                    if name in ('nogil', 'gil'):
                        # special case: in pure mode, "with nogil" spells "with cython.nogil"
                        node = Nodes.GILStatNode(node.pos, state = name, body = node.body)
                        return self.visit_Node(node)
                    if self.check_directive_scope(node.pos, name, 'with statement'):
                        directive_dict[name] = value
        if directive_dict:
            return self.visit_with_directives(node.body, directive_dict)
        return self.visit_Node(node)
1038

1039

Mark Florisson's avatar
Mark Florisson committed
1040 1041 1042 1043 1044 1045
class ParallelRangeTransform(CythonTransform, SkipDeclarations):
    """
    Transform cython.parallel stuff. The parallel_directives come from the
    module node, set there by InterpretCompilerDirectives.

        x = cython.parallel.threadavailable()   -> ParallelThreadAvailableNode
        with nogil, cython.parallel.parallel(): -> ParallelWithBlockNode
            print cython.parallel.threadid()    -> ParallelThreadIdNode
            for i in cython.parallel.prange(...):  -> ParallelRangeNode
                ...
    """

    # a list of names, maps 'cython.parallel.prange' in the code to
    # ['cython', 'parallel', 'prange']
    parallel_directive = None

    # Indicates whether a namenode in an expression is the cython module
    namenode_is_cython_module = False

    # Keep track of whether we are the context manager of a 'with' statement
    in_context_manager_section = False

    # One of None, 'prange' or 'parallel with'. This is used to disallow
    # closely nested 'with parallel:' blocks
    state = None

    directive_to_node = {
        u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
        # u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
        u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
        u"cython.parallel.prange": Nodes.ParallelRangeNode,
    }

    def node_is_parallel_directive(self, node):
        return node.name in self.parallel_directives or node.is_cython_module

    def get_directive_class_node(self, node):
        """
        Figure out which parallel directive was used and return the associated
        Node class.

        E.g. for a cython.parallel.prange() call we return ParallelRangeNode

        Resets the collected directive path as a side effect.
        """
        if self.namenode_is_cython_module:
            directive = '.'.join(self.parallel_directive)
        else:
            directive = self.parallel_directives[self.parallel_directive[0]]
            directive = '%s.%s' % (directive,
                                   '.'.join(self.parallel_directive[1:]))
            directive = directive.rstrip('.')

        cls = self.directive_to_node.get(directive)
        if cls is None and not (self.namenode_is_cython_module and
                                self.parallel_directive[0] != 'parallel'):
            error(node.pos, "Invalid directive: %s" % directive)

        self.namenode_is_cython_module = False
        self.parallel_directive = None

        return cls

    def visit_ModuleNode(self, node):
        """
        If any parallel directives were imported, copy them over and visit
        the AST
        """
        if node.parallel_directives:
            self.parallel_directives = node.parallel_directives
            return self.visit_Node(node)

        # No parallel directives were imported, so they can't be used :)
        return node

    def visit_NameNode(self, node):
        # Start collecting a dotted directive path at its first name.
        if self.node_is_parallel_directive(node):
            self.parallel_directive = [node.name]
            self.namenode_is_cython_module = node.is_cython_module
        return node

    def visit_AttributeNode(self, node):
        # Extend the dotted path collected by visit_NameNode, if any.
        self.visitchildren(node)
        if self.parallel_directive:
            self.parallel_directive.append(node.attribute)
        return node

    def visit_CallNode(self, node):
        self.visit(node.function)
        if not self.parallel_directive:
            return node

        # We are a parallel directive, replace this node with the
        # corresponding ParallelSomethingSomething node

        if isinstance(node, ExprNodes.GeneralCallNode):
            args = node.positional_args.args
            kwargs = node.keyword_args
        else:
            args = node.args
            kwargs = {}

        parallel_directive_class = self.get_directive_class_node(node)
        if parallel_directive_class:
            # Note: in case of a parallel() the body is set by
            # visit_WithStatNode
            node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)

        return node

    def visit_WithStatNode(self, node):
        "Rewrite with cython.parallel.parallel() blocks"
        newnode = self.visit(node.manager)

        if isinstance(newnode, Nodes.ParallelWithBlockNode):
            if self.state == 'parallel with':
                error(node.manager.pos,
                      "Nested parallel with blocks are disallowed")

            # Restore the enclosing state afterwards (instead of resetting it
            # to None), so that a later sibling block nested in the same
            # outer 'parallel with' block is still diagnosed.
            previous_state = self.state
            self.state = 'parallel with'
            body = self.visit(node.body)
            self.state = previous_state

            newnode.body = body
            return newnode
        elif self.parallel_directive:
            parallel_directive_class = self.get_directive_class_node(node)

            if not parallel_directive_class:
                # There was an error, stop here and now
                return None

            if parallel_directive_class is Nodes.ParallelWithBlockNode:
                error(node.pos, "The parallel directive must be called")
                return None

        node.body = self.visit(node.body)
        return node

    def visit_ForInStatNode(self, node):
        "Rewrite 'for i in cython.parallel.prange(...):'"
        self.visit(node.iterator)
        self.visit(node.target)

        in_prange = isinstance(node.iterator.sequence,
                               Nodes.ParallelRangeNode)
        previous_state = self.state

        if in_prange:
            # This will replace the entire ForInStatNode, so copy the
            # attributes
            parallel_range_node = node.iterator.sequence

            parallel_range_node.target = node.target
            parallel_range_node.body = node.body
            parallel_range_node.else_clause = node.else_clause

            node = parallel_range_node

            if not isinstance(node.target, ExprNodes.NameNode):
                error(node.target.pos,
                      "Can only iterate over an iteration variable")

            self.state = 'prange'

        self.visit(node.body)
        self.state = previous_state
        self.visit(node.else_clause)
        return node

    def visit(self, node):
        "Visit a node that may be None"
        if node is not None:
            return super(ParallelRangeTransform, self).visit(node)
Mark Florisson's avatar
Mark Florisson committed
1212 1213


1214
class WithTransform(CythonTransform, SkipDeclarations):
    """
    Expand 'with' and 'async with' statements in place.

    ``node.enter_call`` becomes a call of ``manager.__enter__()`` (or an
    awaited ``__aenter__()`` call for 'async with').  A target, if present,
    is assigned via a WithTargetAssignmentStatNode prepended to the body.
    The body is then wrapped in try/except/finally: the except clause
    re-raises unless the exit call (WithExitCallNode) made with the
    exception info is true, and the finally clause makes the exit call with
    (None, None, None), guarded by ``test_if_run=True``.
    """

    @staticmethod
    def _await_kwargs(pos, is_async):
        # 'await' is a reserved keyword since Python 3.7, so it cannot be
        # written as a literal keyword argument ("await=...") without a
        # SyntaxError.  Pass the same keyword through a dict instead.  A fresh
        # AwaitExprNode is created per call, as before.
        return {'await': ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None}

    def visit_WithStatNode(self, node):
        self.visitchildren(node, 'body')
        pos = node.pos
        is_async = node.is_async
        body, target, manager = node.body, node.target, node.manager
        node.enter_call = ExprNodes.SimpleCallNode(
            pos, function=ExprNodes.AttributeNode(
                pos, obj=ExprNodes.CloneNode(manager),
                attribute=EncodedString('__aenter__' if is_async else '__enter__'),
                is_special_lookup=True),
            args=[],
            is_temp=True)

        if is_async:
            node.enter_call = ExprNodes.AwaitExprNode(pos, arg=node.enter_call)

        if target is not None:
            body = Nodes.StatListNode(
                pos, stats=[
                    Nodes.WithTargetAssignmentStatNode(
                        pos, lhs=target, with_node=node),
                    body])

        excinfo_target = ExprNodes.TupleNode(pos, slow=True, args=[
            ExprNodes.ExcValueNode(pos) for _ in range(3)])
        except_clause = Nodes.ExceptClauseNode(
            pos, body=Nodes.IfStatNode(
                pos, if_clauses=[
                    Nodes.IfClauseNode(
                        pos, condition=ExprNodes.NotNode(
                            pos, operand=ExprNodes.WithExitCallNode(
                                pos, with_stat=node,
                                test_if_run=False,
                                args=excinfo_target,
                                **self._await_kwargs(pos, is_async))),
                        body=Nodes.ReraiseStatNode(pos),
                    ),
                ],
                else_clause=None),
            pattern=None,
            target=None,
            excinfo_target=excinfo_target,
        )

        node.body = Nodes.TryFinallyStatNode(
            pos, body=Nodes.TryExceptStatNode(
                pos, body=body,
                except_clauses=[except_clause],
                else_clause=None,
            ),
            finally_clause=Nodes.ExprStatNode(
                pos, expr=ExprNodes.WithExitCallNode(
                    pos, with_stat=node,
                    test_if_run=True,
                    args=ExprNodes.TupleNode(
                        pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)]),
                    **self._await_kwargs(pos, is_async))),
            handle_error_case=False,
        )
        return node

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node
1279

1280

1281
class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
    """
    Transforms method decorators in cdef classes into nested calls or properties.

    Python-style decorator properties are transformed into a PropertyNode
    with up to the three getter, setter and deleter DefNodes.
    The functional style isn't supported yet.
    """
    # Stack of {property name: PropertyNode} dicts, one pushed per cdef class
    # while it is being visited.
    _properties = None

    # Maps the attribute of '@prop.getter' / '.setter' / '.deleter' to the
    # special method name used inside the PropertyNode body; None otherwise.
    _map_property_attribute = {
        'getter': '__get__',
        'setter': '__set__',
        'deleter': '__del__',
    }.get

    def visit_CClassDefNode(self, node):
        if self._properties is None:
            self._properties = []
        self._properties.append({})
        super(DecoratorTransform, self).visit_CClassDefNode(node)
        self._properties.pop()
        return node

    def visit_PropertyNode(self, node):
        # Low-level warning for other code until we can convert all our uses over.
        level = 2 if isinstance(node.pos[0], str) else 0
        warning(node.pos, "'property %s:' syntax is deprecated, use '@property'" % node.name, level)
        return node

    def visit_DefNode(self, node):
        """
        Turn '@property' / '@name.getter|setter|deleter' methods of cdef
        classes into PropertyNodes, and chain any other decorators of cdef
        class methods into explicit reassignments.
        """
        scope_type = self.scope_type
        node = self.visit_FuncDefNode(node)
        if scope_type != 'cclass' or not node.decorators:
            return node

        # transform @property decorators
        properties = self._properties[-1]
        for decorator_node in node.decorators[::-1]:
            decorator = decorator_node.decorator
            if decorator.is_name and decorator.name == 'property':
                if len(node.decorators) > 1:
                    return self._reject_decorated_property(node, decorator_node)
                name = node.name
                node.name = '__get__'
                node.decorators.remove(decorator_node)
                stat_list = [node]
                if name in properties:
                    prop = properties[name]
                    prop.pos = node.pos
                    prop.doc = node.doc
                    prop.body.stats = stat_list
                    return []
                prop = Nodes.PropertyNode(node.pos, name=name)
                prop.doc = node.doc
                prop.body = Nodes.StatListNode(node.pos, stats=stat_list)
                properties[name] = prop
                return [prop]
            elif (decorator.is_attribute and decorator.obj.is_name
                    and decorator.obj.name in properties):
                # Only a plain name ('@x.setter') can refer to a property
                # collected above; other objects have no 'name' to look up.
                handler_name = self._map_property_attribute(decorator.attribute)
                if handler_name:
                    if decorator.obj.name != node.name:
                        # This is user input: report it as a compile error
                        # instead of asserting (which would crash the
                        # compiler, or be skipped entirely under -O).
                        error(decorator_node.pos,
                              "Mismatching property names, expected '%s', got '%s'" % (
                                  decorator.obj.name, node.name))
                    elif len(node.decorators) > 1:
                        return self._reject_decorated_property(node, decorator_node)
                    else:
                        return self._add_to_property(properties, node, handler_name, decorator_node)

        # transform normal decorators
        return self.chain_decorators(node, node.decorators, node.name)

    @staticmethod
    def _reject_decorated_property(node, decorator_node):
        # restrict transformation to outermost decorator as wrapped properties will probably not work
        for deco in node.decorators:
            if deco != decorator_node:
                error(deco.pos, "Property methods with additional decorators are not supported")
        return node

    @staticmethod
    def _add_to_property(properties, node, name, decorator):
        # Rename the method to its special name and store it in the existing
        # property body, replacing a previous method of the same name.
        prop = properties[node.name]
        node.name = name
        node.decorators.remove(decorator)
        stats = prop.body.stats
        for i, stat in enumerate(stats):
            if stat.name == name:
                stats[i] = node
                break
        else:
            stats.append(node)
        return []

    @staticmethod
    def chain_decorators(node, decorators, name):
        """
        Decorators are applied directly in DefNode and PyClassDefNode to avoid
        reassignments to the function/class name - except for cdef class methods.
        For those, the reassignment is required as methods are originally
        defined in the PyMethodDef struct.

        The IndirectionNode allows DefNode to override the decorator.
        """
        decorator_result = ExprNodes.NameNode(node.pos, name=name)
        for decorator in decorators[::-1]:
            decorator_result = ExprNodes.SimpleCallNode(
                decorator.pos,
                function=decorator.decorator,
                args=[decorator_result])

        name_node = ExprNodes.NameNode(node.pos, name=name)
        reassignment = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=name_node,
            rhs=decorator_result)

        reassignment = Nodes.IndirectionNode([reassignment])
        node.decorator_indirection = reassignment
        return [node, reassignment]
1398

1399

1400 1401 1402 1403 1404 1405 1406 1407 1408
class CnameDirectivesTransform(CythonTransform, SkipDeclarations):
    """
    Only part of the CythonUtilityCode pipeline. Must be run before
    DecoratorTransform in case this is a decorator for a cdef class.
    It filters out @cname('my_cname') decorators and rewrites them to
    CnameDecoratorNodes.
    """

    @staticmethod
    def _wrap_in_cname_node(node, index, call):
        """
        Validate the arguments of the @cname(...) call found at
        ``node.decorators[index]``, remove that decorator and return ``node``
        wrapped in a CnameDecoratorNode.
        """
        args, kwargs = call.explicit_args_kwds()

        if kwargs:
            raise AssertionError(
                    "cname decorator does not take keyword arguments")

        if len(args) != 1:
            raise AssertionError(
                    "cname decorator takes exactly one argument")

        cname_arg = args[0]
        if not (cname_arg.is_literal and
                cname_arg.type == Builtin.str_type):
            raise AssertionError(
                    "argument to cname decorator must be a string literal")

        cname = cname_arg.compile_time_value(None)
        del node.decorators[index]
        return Nodes.CnameDecoratorNode(pos=node.pos, node=node,
                                        cname=cname)

    def handle_function(self, node):
        """Replace the first @cname(...) decorator (if any), then visit children."""
        decorator_list = getattr(node, 'decorators', None)
        if decorator_list:
            for index, decorator_node in enumerate(decorator_list):
                call = decorator_node.decorator
                is_cname_call = (isinstance(call, ExprNodes.CallNode) and
                                 call.function.is_name and
                                 call.function.name == 'cname')
                if is_cname_call:
                    node = self._wrap_in_cname_node(node, index, call)
                    break
        return self.visit_Node(node)

    visit_FuncDefNode = handle_function
    visit_CClassDefNode = handle_function
    visit_CEnumDefNode = handle_function
    visit_CStructOrUnionDefNode = handle_function
1445 1446


1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483
class ForwardDeclareTypes(CythonTransform):
    """
    Declare enums, structs/unions and cdef classes in the module scope.

    Structs/unions and cdef classes whose name already has an entry in the
    module scope are not declared again.  While descending, the module
    scope's directives follow CompilerDirectivesNodes, and its
    ``in_cinclude`` flag is set inside 'cdef extern' blocks.
    """

    def visit_ModuleNode(self, node):
        scope = node.scope
        self.module_scope = scope
        scope.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        scope = self.module_scope
        saved_directives = scope.directives
        scope.directives = node.directives
        self.visitchildren(node)
        scope.directives = saved_directives
        return node

    def visit_CDefExternNode(self, node):
        saved_cinclude = self.module_scope.in_cinclude
        self.module_scope.in_cinclude = 1
        self.visitchildren(node)
        self.module_scope.in_cinclude = saved_cinclude
        return node

    def visit_CEnumDefNode(self, node):
        node.declare(self.module_scope)
        return node

    def _declare_unless_known(self, node, name):
        # Declare only names that have no module scope entry yet.
        if name not in self.module_scope.entries:
            node.declare(self.module_scope)
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return self._declare_unless_known(node, node.name)

    def visit_CClassDefNode(self, node):
        return self._declare_unless_known(node, node.class_name)

1484

1485
class AnalyseDeclarationsTransform(EnvTransform):
    """
    Run declaration analysis over the tree: modules, functions, lambdas,
    scoped expressions and classes.  Also generates Python properties for
    cdef class attributes, handles fused functions, warns about cdef
    variables declared after use, and drops declaration-only nodes.
    """

    # The FusedCFuncDefNode whose children are currently being visited, if any.
    fused_function = None
    # Nesting depth of LambdaNodes being visited.
    in_lambda = 0

    # Property templates used by create_Property(): read/write,
    # read/write/delete (Python object attributes) and read-only.
    basic_property = TreeFragment(
        u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """,
        level='c_class',
        pipeline=[NormalizeTree(None)])

    basic_pyobject_property = TreeFragment(
        u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
    """,
        level='c_class',
        pipeline=[NormalizeTree(None)])

    basic_property_ro = TreeFragment(
        u"""
property NAME:
    def __get__(self):
        return ATTR
    """,
        level='c_class',
        pipeline=[NormalizeTree(None)])

    # Templates for a Python wrapper class around a C struct/union and for
    # one member assignment in its __init__.
    struct_or_union_wrapper = TreeFragment(
        u"""
cdef class NAME:
    cdef TYPE value
    def __init__(self, MEMBER=None):
        cdef int count
        count = 0
        INIT_ASSIGNMENTS
        if IS_UNION and count > 1:
            raise ValueError, "At most one union member should be specified."
    def __str__(self):
        return STR_FORMAT % MEMBER_TUPLE
    def __repr__(self):
        return REPR_FORMAT % MEMBER_TUPLE
    """,
        pipeline=[NormalizeTree(None)])

    init_assignment = TreeFragment(
        u"""
if VALUE is not None:
    ATTR = VALUE
    count += 1
    """,
        pipeline=[NormalizeTree(None)])
1532

1533
    def __call__(self, root):
        # One set of seen names per scope, used to warn when a cdef variable
        # is declared after it has been used.
        self.seen_vars_stack = []
        self.fused_error_funcs = set()
        parent = super(AnalyseDeclarationsTransform, self)
        # Keep the base implementation so visit_FuncDefNode can delegate to it.
        self._super_visit_FuncDefNode = parent.visit_FuncDefNode
        return parent.__call__(root)
1540

1541
    def visit_NameNode(self, node):
        """Record the referenced name in the innermost seen-names set."""
        innermost = self.seen_vars_stack[-1]
        innermost.add(node.name)
        return node

1545
    def visit_ModuleNode(self, node):
        """Analyse module-level declarations, then visit the module body."""
        env = self.current_env()
        self.seen_vars_stack.append(set())
        node.analyse_declarations(env)
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node
Stefan Behnel's avatar
Stefan Behnel committed
1551 1552

    def visit_LambdaNode(self, node):
        """
        Analyse a lambda and its children; ``in_lambda`` counts the nesting
        so that fused functions inside lambdas can be rejected.
        """
        self.in_lambda += 1
        env = self.current_env()
        node.analyse_declarations(env)
        self.visitchildren(node)
        self.in_lambda -= 1
        return node

1559 1560
    def visit_CClassDefNode(self, node):
        """
        Visit a cdef class, then append generated Python properties for all
        attribute entries flagged with ``needs_property``.
        """
        node = self.visit_ClassDefNode(node)
        scope = node.scope
        if not (scope and scope.implemented and node.body):
            return node

        generated = []
        for entry in scope.var_entries:
            if not entry.needs_property:
                continue
            prop = self.create_Property(entry)
            prop.analyse_declarations(scope)
            self.visit(prop)
            generated.append(prop)

        if generated:
            node.body.stats += generated
        return node
1572

1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591
    def _handle_fused_def_decorators(self, old_decorators, env, node):
        """
        Create function calls to the decorators and reassignments to
        the function.

        'staticmethod' and 'classmethod' decorators are dropped (unless those
        names are declared in ``env`` itself), as the fused function object
        handles them directly.  Returns ``node`` if no decorators remain,
        otherwise ``[node, reassignments]``.
        """
        def is_kept(decorator):
            func = decorator.decorator
            if not func.is_name:
                return True
            if func.name not in ('staticmethod', 'classmethod'):
                return True
            # a local redefinition shadows the builtin decorator
            return bool(env.lookup_here(func.name))

        decorators = [decorator for decorator in old_decorators if is_kept(decorator)]
        if not decorators:
            return node

        def_node = node.node
        # chain_decorators is a staticmethod; no transform instance is needed.
        _, reassignments = DecoratorTransform.chain_decorators(
            def_node, decorators, def_node.name)
        reassignments.analyse_declarations(env)
        return [node, reassignments]

    def _handle_def(self, decorators, env, node):
        "Handle def or cpdef fused functions"
        # Put py_func at the front of the node's statements and analyse it.
        node.stats.insert(0, node.py_func)
        node.py_func = self.visit(node.py_func)
        node.update_fused_defnode_entry(env)

        # Wrap py_func in a binding PyCFunctionNode held in a temp behind a
        # ProxyNode, so it can be referenced again through CloneNodes.
        cfunc_node = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, True)
        pycfunc = ExprNodes.ProxyNode(cfunc_node.coerce_to_temp(env))
        node.resulting_fused_function = pycfunc

        # Create assignment node for our def function
        node.fused_func_assignment = self._create_assignment(
            node.py_func, ExprNodes.CloneNode(pycfunc), env)

        if not decorators:
            return node
        return self._handle_fused_def_decorators(decorators, env, node)

    def _create_fused_function(self, env, node):
        "Create a fused function for a DefNode with fused arguments"
        from . import FusedNode

        if self.fused_function or self.in_lambda:
            # Report at most once per enclosing fused function (None when
            # the problem is a lambda outside any fused function).
            if self.fused_function not in self.fused_error_funcs:
                if self.in_lambda:
                    message = "Fused lambdas not allowed"
                else:
                    message = "Cannot nest fused functions"
                error(node.pos, message)
            self.fused_error_funcs.add(self.fused_function)

            # Neutralise the function: drop its body and pin each fused
            # argument to the first of its fused types.
            node.body = Nodes.PassStatNode(node.pos)
            for arg in node.args:
                if arg.type.is_fused:
                    arg.type = arg.type.get_fused_types()[0]
            return node

        decorators = getattr(node, 'decorators', None)
        fused_node = FusedNode.FusedCFuncDefNode(node, env)
        self.fused_function = fused_node
        self.visitchildren(fused_node)
        self.fused_function = None
        if not fused_node.py_func:
            return fused_node
        return self._handle_def(decorators, env, fused_node)

    def _handle_nogil_cleanup(self, lenv, node):
        "Handle cleanup for 'with gil' blocks in nogil functions."
        if not (lenv.nogil and lenv.has_with_gil_block):
            return
        # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
        # the entire function body in try/finally.
        # The corresponding release will be taken care of by
        # Nodes.FuncDefNode.generate_function_definitions()
        original_body = node.body
        body_pos = original_body.pos
        node.body = Nodes.NogilTryFinallyStatNode(
            body_pos,
            body=original_body,
            finally_clause=Nodes.EnsureGILNode(body_pos),
            finally_except_clause=Nodes.EnsureGILNode(body_pos))
1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670

    def _handle_fused(self, node):
        """
        Return ``node.has_fused_arguments``.

        Fused generators are rejected first: an error is reported, the flag
        is cleared and the generator body is replaced by an empty statement
        list.
        """
        if not (node.is_generator and node.has_fused_arguments):
            return node.has_fused_arguments

        node.has_fused_arguments = False
        error(node.pos, "Fused generators not supported")
        node.gbody = Nodes.StatListNode(node.pos,
                                        stats=[],
                                        body=Nodes.PassStatNode(node.pos))
        return node.has_fused_arguments

1671
    def visit_FuncDefNode(self, node):
        """
        Analyse a function and its body, as that hasn't happened yet.  Also
        analyse the directive_locals set by @cython.locals().

        Then, if we are a function with fused arguments, replace the function
        (after it has declared itself in the symbol table!) with a
        FusedCFuncDefNode, and analyse its children (which are in turn normal
        functions). If we're a normal function, just analyse the body of the
        function.
        """
        env = self.current_env()
        self.seen_vars_stack.append(set())
        lenv = node.local_scope
        node.declare_arguments(lenv)

        # @cython.locals(...): names already in the local scope (such as the
        # arguments just declared) are not redeclared.
        for var_name, type_node in node.directive_locals.items():
            if lenv.lookup_here(var_name):
                continue
            declared_type = type_node.analyse_as_type(lenv)
            if declared_type:
                lenv.declare_var(var_name, declared_type, type_node.pos)
            else:
                error(type_node.pos, "Not a type")

        if self._handle_fused(node):
            node = self._create_fused_function(env, node)
        else:
            node.body.analyse_declarations(lenv)
            self._handle_nogil_cleanup(lenv, node)
            self._super_visit_FuncDefNode(node)

        self.seen_vars_stack.pop()
        return node
1706

1707 1708
    def visit_DefNode(self, node):
        """
        Analyse a def function.  When the result is still a plain DefNode
        that needs it, return it together with the synthesized assignment
        of the function object to its name.
        """
        node = self.visit_FuncDefNode(node)
        env = self.current_env()
        wants_assignment = (
            isinstance(node, Nodes.DefNode)
            and not node.fused_py_func
            and not node.is_generator_body
            and node.needs_assignment_synthesis(env))
        if not wants_assignment:
            return node
        return [node, self._synthesize_assignment(node, env)]

1716 1717 1718
    def visit_GeneratorBodyDefNode(self, node):
        # Analysed like any function; unlike visit_DefNode, no name
        # assignment is synthesized here.
        return self.visit_FuncDefNode(node)

1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730
    def _synthesize_assignment(self, node, env):
        """
        Build the assignment of ``node``'s function object to its name;
        visit_DefNode places it right after the def node.
        """
        # Skip enclosing Python/C class bodies to find the defining scope.
        outer_env = env
        while outer_env.is_py_class_scope or outer_env.is_c_class_scope:
            outer_env = outer_env.outer_scope

        if outer_env.is_closure_scope:
            rhs = ExprNodes.InnerFunctionNode(
                node.pos, def_node=node,
                pymethdef_cname=node.entry.pymethdef_cname,
                code_object=ExprNodes.CodeObjectNode(node))
            node.py_cfunc_node = rhs
        else:
            binding = self.current_directives.get('binding')
            rhs = ExprNodes.PyCFunctionNode.from_defnode(node, binding)
            node.code_object = rhs.code_object

        # Inside a Python class body the function object is always binding.
        if env.is_py_class_scope:
            rhs.binding = True

        node.is_cyfunction = rhs.binding
        return self._create_assignment(node, rhs, env)
1740

1741 1742 1743
    def _create_assignment(self, def_node, rhs, env):
        """
        Return an analysed ``<def_node.name> = rhs`` assignment.

        Decorators of ``def_node`` are wrapped around ``rhs`` (the last one
        in the list is applied first) and then removed from ``def_node``.
        """
        decorators = def_node.decorators
        if decorators:
            for decorator in reversed(decorators):
                rhs = ExprNodes.SimpleCallNode(
                    decorator.pos,
                    function=decorator.decorator,
                    args=[rhs])
            def_node.decorators = None

        target = ExprNodes.NameNode(def_node.pos, name=def_node.name)
        assignment = Nodes.SingleAssignmentNode(def_node.pos, lhs=target, rhs=rhs)
        assignment.analyse_declarations(env)
        return assignment

1757
    def visit_ScopedExprNode(self, node):
        """
        Analyse an expression node that may have its own local scope.

        With a local scope, the children are visited inside
        ``node.expr_scope`` with a copy of the current seen-names set;
        otherwise everything is analysed in the current environment.
        """
        env = self.current_env()
        node.analyse_declarations(env)

        if not node.has_local_scope:
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
            return node

        self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
        self.enter_scope(node, node.expr_scope)
        node.analyse_scoped_declarations(node.expr_scope)
        self.visitchildren(node)
        self.exit_scope()
        self.seen_vars_stack.pop()
        return node

1773 1774
    def visit_TempResultFromStatNode(self, node):
        # Children first, then the node's own declarations.
        self.visitchildren(node)
        env = self.current_env()
        node.analyse_declarations(env)
        return node

1778 1779 1780 1781 1782
    def visit_CppClassNode(self, node):
        # Non-extern C++ classes are analysed like other classes; extern
        # ones are dropped from the tree.
        if node.visibility != 'extern':
            return self.visit_ClassDefNode(node)
        return None
1783

1784
    def visit_CStructOrUnionDefNode(self, node):
        """
        Drop struct/union definitions from the tree.

        The original node is not returned, as it is never used after this
        phase.
        """
        # NOTE: creating a Python wrapper class for the struct/union (using
        # the 'struct_or_union_wrapper' and 'init_assignment' templates)
        # used to follow here behind an always-true
        # 'if True:  # private (default)' guard, so it could never run.
        # That unreachable code has been removed; recover it from version
        # control if wrapper generation is ever enabled.
        return None
1857

1858 1859 1860 1861
    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
1862
    def visit_CDeclaratorNode(self, node):
        # Keep the declarator, but descend so that every nested
        # CNameDeclaratorNode gets visited.
        self.visitchildren(node)
        return node
1866

1867 1868 1869 1870 1871
    def visit_CTypeDefNode(self, node):
        # Typedefs stay in the tree unchanged; their children are not visited.
        return node

    def visit_CBaseTypeNode(self, node):
        # Base type nodes are removed from the tree.
        return None
1872

1873
    def visit_CEnumDefNode(self, node):
        # Only public enums remain in the tree.
        return node if node.visibility == 'public' else None
1878

1879
    def visit_CNameDeclaratorNode(self, node):
        """
        Warn when a cdef variable is declared after its name was already
        seen in the current scope, then visit the declarator's children.
        """
        if node.name in self.seen_vars_stack[-1]:
            entry = self.current_env().lookup(node.name)
            # 'and' binds tighter than 'or': warn if there is no entry, or if
            # the entry is not extern and does not belong to a cdef class scope.
            should_warn = (entry is None or
                           (entry.visibility != 'extern' and
                            not entry.scope.is_c_class_scope))
            if should_warn:
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

1888
    def visit_CVarDefNode(self, node):
        # Descend so every CNameDeclaratorNode is checked, then drop the
        # variable definition from the tree.
        self.visitchildren(node)
        return None
1892

1893
    def visit_CnameDecoratorNode(self, node):
        """
        Visit the wrapped node and re-wrap whatever the visit produced.

        If the visit removed the child, the whole decorator node is dropped.
        If it produced a list (e.g. a def node followed by a synthesized
        assignment), the first item becomes the wrapped node and the
        remaining items are returned after this node.
        """
        child_node = self.visit(node.node)
        if not child_node:
            return None
        if type(child_node) is list:  # Assignment synthesized
            # The visit may have replaced the child (e.g. by a fused function
            # node), so the first item must become the wrapped node.  It was
            # previously stored in an unused 'child_node' attribute instead,
            # leaving the stale pre-visit node wrapped.
            node.node = child_node[0]
            return [node] + child_node[1:]
        node.node = child_node
        return node

1903
    def create_Property(self, entry):
        """
        Build the property node that exposes the cdef class attribute
        ``entry`` to Python.

        'public' attributes use the read/write template (with a deleter for
        Python object types); 'readonly' attributes use the read-only one.

        Raises InternalError for any other visibility.  Previously such an
        entry left ``template`` unbound and failed with an opaque
        UnboundLocalError.
        """
        if entry.visibility == 'public':
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
        elif entry.visibility == 'readonly':
            template = self.basic_property_ro
        else:
            raise InternalError(
                "cannot create a property for attribute '%s' with visibility '%s'" % (
                    entry.name, entry.visibility))
        prop = template.substitute({
                u"ATTR": ExprNodes.AttributeNode(pos=entry.pos,
                                                 obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
                                                 attribute=entry.name),
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        prop.doc = entry.doc
        return prop
1919

1920

1921 1922 1923 1924 1925 1926 1927 1928
class CalculateQualifiedNamesTransform(EnvTransform):
    """
    Calculate and store the '__qualname__' and the global
    module name on some nodes.

    The dotted path of the scope being visited is kept as a list of name
    components in ``self.qualified_name``.  Function and class visits save
    that list before descending and restore it afterwards.
    """
    def visit_ModuleNode(self, node):
        self.module_name = self.global_scope().qualified_name
        self.qualified_name = []
        # Keep bound references to the base-class visitors: the overrides
        # below extend the name path and then delegate to them.
        base = super(CalculateQualifiedNamesTransform, self)
        self._super_visit_FuncDefNode = base.visit_FuncDefNode
        self._super_visit_ClassDefNode = base.visit_ClassDefNode
        self.visitchildren(node)
        return node

    def _set_qualname(self, node, name=None):
        """
        Store the current path, extended by ``name`` when one is given, as
        ``node.qualname`` and record the module name on ``node``.
        """
        parts = (self.qualified_name + [name]) if name else self.qualified_name
        node.qualname = EncodedString('.'.join(parts))
        node.module_name = self.module_name

    def _append_entry(self, entry):
        """Extend the current path with the name of ``entry``."""
        if entry.is_pyglobal and not entry.is_pyclass_attr:
            # A module-level Python name starts a fresh path.
            self.qualified_name = [entry.name]
        else:
            self.qualified_name.append(entry.name)

    def _qualify_and_visit(self, node, name=None):
        """Set the qualified name on ``node``, then visit its children."""
        self._set_qualname(node, name)
        self.visitchildren(node)
        return node

    def visit_ClassNode(self, node):
        return self._qualify_and_visit(node, node.name)

    def visit_PyClassNamespaceNode(self, node):
        # class name was already added by parent node
        return self._qualify_and_visit(node)

    def visit_PyCFunctionNode(self, node):
        return self._qualify_and_visit(node, node.def_node.name)

    def visit_DefNode(self, node):
        self._set_qualname(node, node.name)
        return self.visit_FuncDefNode(node)

    def visit_FuncDefNode(self, node):
        saved_path = list(self.qualified_name)
        if getattr(node, 'name', None) == '<lambda>':
            self.qualified_name.append('<lambda>')
        else:
            self._append_entry(node.entry)
        self.qualified_name.append('<locals>')
        self._super_visit_FuncDefNode(node)
        self.qualified_name = saved_path
        return node

    def visit_ClassDefNode(self, node):
        saved_path = list(self.qualified_name)
        entry = getattr(node, 'entry', None)                     # PyClass
        if not entry:
            entry = self.current_env().lookup_here(node.name)    # CClass
        self._append_entry(entry)
        self._super_visit_ClassDefNode(node)
        self.qualified_name = saved_path
        return node


1991
class AnalyseExpressionsTransform(CythonTransform):
    """
    Run type inference and expression analysis for the module body, every
    function body and every scoped expression that owns a local scope,
    then continue into the children.
    """

    def _analyse_body(self, node, scope):
        # Shared by module and function nodes: infer types in 'scope',
        # analyse the body within it, then visit the children.
        scope.infer_types()
        node.body = node.body.analyse_expressions(scope)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        return self._analyse_body(node, node.scope)

    def visit_FuncDefNode(self, node):
        return self._analyse_body(node, node.local_scope)

    def visit_ScopedExprNode(self, node):
        if node.has_local_scope:
            scope = node.expr_scope
            scope.infer_types()
            # Analysis may return a different node; continue with that one.
            node = node.analyse_scoped_expressions(scope)
        self.visitchildren(node)
        return node

    def visit_IndexNode(self, node):
        """
        Replace index nodes used to specialize cdef functions with fused
        argument types with the Attribute- or NameNode referring to the
        function. We then need to copy over the specialization properties to
        the attribute or name node.

        Because the indexing might be a Python indexing operation on a fused
        function, or (usually) a Cython indexing operation, we need to
        re-analyse the types.
        """
        self.visit_Node(node)
        if not node.is_fused_index or node.type.is_error:
            return node
        return node.base
2027

2028

2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050
class FindInvalidUseOfFusedTypes(CythonTransform):
    """
    Report fused types that can no longer be specialized: fused return
    types of functions without fused arguments, and expressions whose
    type is still fused.
    """

    def visit_FuncDefNode(self, node):
        # Errors related to use in functions with fused args will already
        # have been detected
        if node.has_fused_arguments:
            return node
        if node.is_generator_body or not node.return_type.is_fused:
            self.visitchildren(node)
        else:
            error(node.pos, "Return type is not specified as argument type")
        return node

    def visit_ExprNode(self, node):
        expr_type = node.type
        if not (expr_type and expr_type.is_fused):
            self.visitchildren(node)
        else:
            error(node.pos, "Invalid use of fused types, type cannot be specialized")
        return node


2051
class ExpandInplaceOperators(EnvTransform):
    """
    Rewrite ``lhs op= rhs`` as the plain assignment
    ``lhs = lhs op rhs`` (with the binary operation marked in-place).

    Sub-expressions of the target are bound to LetRefNode temporaries
    wrapped in LetNodes so that they are evaluated only once.  C++ class
    targets and buffer index targets are left untouched.
    """

    def visit_InPlaceAssignmentNode(self, node):
        lhs = node.lhs
        rhs = node.rhs
        # C++ classes: no getting around this exact operator here.
        # Buffer indexing: there is code to handle this case in
        # InPlaceAssignmentNode.
        if lhs.type.is_cpp_class or isinstance(lhs, ExprNodes.BufferIndexNode):
            return node

        env = self.current_env()

        def make_reference(expr, setting=False):
            # Return (reference, temps): 'reference' stands for 'expr'
            # without repeating its side effects and 'temps' lists the
            # LetRefNodes that have to be bound around the result.
            if expr.is_name:
                return expr, []
            if expr.type.is_pyobject and not setting:
                ref = LetRefNode(expr)
                return ref, [ref]
            if expr.is_subscript:
                base, temps = make_reference(expr.base)
                index = LetRefNode(expr.index)
                return ExprNodes.IndexNode(expr.pos, base=base, index=index), temps + [index]
            if expr.is_attribute:
                obj, temps = make_reference(expr.obj)
                return ExprNodes.AttributeNode(expr.pos, obj=obj, attribute=expr.attribute), temps
            if isinstance(expr, ExprNodes.BufferIndexNode):
                raise ValueError("Don't allow things like attributes of buffer indexing operations")
            ref = LetRefNode(expr)
            return ref, [ref]

        try:
            lhs, let_ref_nodes = make_reference(lhs, setting=True)
        except ValueError:
            return node

        # Second, independent copy of the target used as the left operand.
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(node.pos,
                                     operator=node.operator,
                                     operand1=dup,
                                     operand2=rhs,
                                     inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        result = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=lhs,
            rhs=binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects: the first temporary ends up
        # as the outermost LetNode.
        for temp in reversed(let_ref_nodes):
            result = LetNode(temp, result)
        return result

    def visit_ExprNode(self, node):
        # In-place assignments can't happen within an expression.
        return node

2110

Haoyu Bai's avatar
Haoyu Bai committed
2111 2112 2113 2114 2115 2116 2117
class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
    """
    Adjust function and class definitions by the decorator directives:

    @cython.cfunc
    @cython.cclass
    @cython.ccall
    @cython.inline
    """

    def visit_ModuleNode(self, node):
        self.directives = node.directives
        self.in_py_class = False
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        outer_directives = self.directives
        self.directives = node.directives
        self.visitchildren(node)
        self.directives = outer_directives
        return node

    def visit_DefNode(self, node):
        directives = self.directives
        modifiers = ['inline'] if 'inline' in directives else []

        if 'ccall' in directives:
            cfunc_node = node.as_cfunction(
                overridable=True, returns=directives.get('returns'), modifiers=modifiers)
            return self.visit(cfunc_node)

        if 'cfunc' in directives:
            if not self.in_py_class:
                cfunc_node = node.as_cfunction(
                    overridable=False, returns=directives.get('returns'), modifiers=modifiers)
                return self.visit(cfunc_node)
            error(node.pos, "cfunc directive is not allowed here")

        # Still a Python function at this point.
        if 'inline' in modifiers:
            error(node.pos, "Python functions cannot be declared 'inline'")

        self.visitchildren(node)
        return node

    def _visit_class_body(self, node, in_py_class):
        # Visit the children with 'in_py_class' temporarily set.
        outer_in_py_class = self.in_py_class
        self.in_py_class = in_py_class
        self.visitchildren(node)
        self.in_py_class = outer_in_py_class
        return node

    def visit_PyClassDefNode(self, node):
        if 'cclass' in self.directives:
            return self.visit(node.as_cclass())
        return self._visit_class_body(node, True)

    def visit_CClassDefNode(self, node):
        return self._visit_class_body(node, False)
2171

2172

2173 2174
class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods in a .py file.
    """

    def visit_ModuleNode(self, node):
        self.scope = node.scope
        self.directives = node.directives
        self.imported_names = set()  # hack, see visit_FromImportStatNode()
        self.visitchildren(node)
        return node

    def _report_redeclaration(self, node, pxd_def):
        # Shared error reporting for names already declared in the .pxd.
        error(node.pos, "'%s' redeclared" % node.name)
        if pxd_def.pos:
            error(pxd_def.pos, "previous declaration here")

    def visit_PyClassDefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        if not pxd_def.scope or not pxd_def.scope.is_builtin_scope:
            self._report_redeclaration(node, pxd_def)
            return None
        return node

    def visit_CClassDefNode(self, node, pxd_def=None):
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        if not pxd_def.defined_in_pxd:
            return node
        # Align the class body against the scope declared in the .pxd.
        outer_scope = self.scope
        self.scope = pxd_def.type.scope
        self.visitchildren(node)
        self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if pxd_def and (not pxd_def.scope or not pxd_def.scope.is_builtin_scope):
            if not pxd_def.is_cfunction:
                self._report_redeclaration(node, pxd_def)
                return None
            node = node.as_cfunction(pxd_def)
        elif (self.scope.is_module_scope and self.directives['auto_cpdef']
              and node.name not in self.imported_names
              and node.is_cdef_func_compatible()):
            # FIXME: cpdef-ing should be done in analyse_declarations()
            node = node.as_cfunction(scope=self.scope)
        # Children are not visited here; enable that once nested cdef
        # functions are allowed.
        return node

    def visit_FromImportStatNode(self, node):
        # hack to prevent conditional import fallback functions from
        # being cpdef-ed (global Python variables currently conflict
        # with imports)
        if self.scope.is_module_scope:
            self.imported_names.update(name for name, _ in node.items)
        return node

    def visit_ExprNode(self, node):
        # ignore lambdas and everything else that appears in expressions
        return node

2242

2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274
class RemoveUnreachableCode(CythonTransform):
    """
    Propagate ``is_terminator`` upwards and drop code that can never run.

    Statement lists (only with the 'remove_unreachable' directive) are cut
    after their first terminating statement; if statements terminate when
    every branch, including the else clause, does; a try/except whose body
    terminates loses its else clause.  With 'warn.unreachable' set, a
    warning is emitted for the removed code.
    """
    def visit_StatListNode(self, node):
        if not self.current_directives['remove_unreachable']:
            return node
        self.visitchildren(node)
        stats = node.stats
        for stop, stat in enumerate(stats, 1):
            if not stat.is_terminator:
                continue
            if stop < len(stats):
                if self.current_directives['warn.unreachable']:
                    warning(stats[stop].pos, "Unreachable code", 2)
                node.stats = stats[:stop]
            node.is_terminator = True
            break
        return node

    def visit_IfClauseNode(self, node):
        self.visitchildren(node)
        if node.body.is_terminator:
            node.is_terminator = True
        return node

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        else_clause = node.else_clause
        if else_clause and else_clause.is_terminator and all(
                clause.is_terminator for clause in node.if_clauses):
            node.is_terminator = True
        return node

    def visit_TryExceptStatNode(self, node):
        self.visitchildren(node)
        else_clause = node.else_clause
        if node.body.is_terminator and else_clause:
            if self.current_directives['warn.unreachable']:
                warning(else_clause.pos, "Unreachable code", 2)
            node.else_clause = None
        return node

2283

2284 2285 2286 2287
class YieldNodeCollector(TreeVisitor):
    """
    Collect the yield, await and return nodes of a single function body.

    Nested classes, functions, lambdas, generator expressions and argument
    declarations are not entered, so only nodes that belong directly to
    the visited body are recorded.
    """

    def __init__(self):
        super(YieldNodeCollector, self).__init__()
        self.yields = []
        self.awaits = []
        self.returns = []
        self.has_return_value = False

    def visit_Node(self, node):
        self.visitchildren(node)

    def visit_YieldExprNode(self, node):
        self.yields.append(node)
        self.visitchildren(node)

    def visit_AwaitExprNode(self, node):
        self.awaits.append(node)
        self.visitchildren(node)

    def visit_ReturnStatNode(self, node):
        self.visitchildren(node)
        if node.value:
            # Remember that at least one 'return' carries a value.
            self.has_return_value = True
        self.returns.append(node)

    # The following node types open a new scope (or hold annotations):
    # their contents do not belong to the current body.

    def visit_ClassDefNode(self, node):
        return None

    def visit_FuncDefNode(self, node):
        return None

    def visit_LambdaNode(self, node):
        return None

    def visit_GeneratorExpressionNode(self, node):
        return None

    def visit_CArgDeclNode(self, node):
        # do not look into annotations
        # FIXME: support (yield) in default arguments (currently crashes)
        return None

2327

2328
class MarkClosureVisitor(CythonTransform):
    """
    Set ``needs_closure`` on functions, cdef functions and lambdas that
    contain a nested function, lambda or class, and replace generator
    functions and ``async def`` coroutines by GeneratorDefNode /
    AsyncDefNode wrappers around a GeneratorBodyDefNode.
    """

    def visit_ModuleNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def _mark_closure(self, node):
        # Visit the children of a scope-creating node and record on it
        # whether any nested scope was found.  Afterwards the enclosing
        # scope is told that it contains a nested scope itself.
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True

    def visit_FuncDefNode(self, node):
        self._mark_closure(node)

        collector = YieldNodeCollector()
        collector.visitchildren(node)

        if node.is_async_def:
            if collector.yields:
                error(collector.yields[0].pos, "'yield' not allowed in async coroutines (use 'await')")
            yields = collector.awaits
        elif collector.yields:
            if collector.awaits:
                # Report at the offending 'await', not at the first 'yield'.
                error(collector.awaits[0].pos, "'await' not allowed in generators (use 'yield')")
            yields = collector.yields
        else:
            return node

        # Number the suspension points (1-based) and flag the returns.
        for i, yield_expr in enumerate(yields, 1):
            yield_expr.label_num = i
        for retnode in collector.returns:
            retnode.in_generator = True

        gbody = Nodes.GeneratorBodyDefNode(
            pos=node.pos, name=node.name, body=node.body)
        coroutine = (Nodes.AsyncDefNode if node.is_async_def else Nodes.GeneratorDefNode)(
            pos=node.pos, name=node.name, args=node.args,
            star_arg=node.star_arg, starstar_arg=node.starstar_arg,
            doc=node.doc, decorators=node.decorators,
            gbody=gbody, lambda_name=node.lambda_name)
        return coroutine

    def visit_CFuncDefNode(self, node):
        self._mark_closure(node)
        if node.needs_closure and node.overridable:
            error(node.pos, "closures inside cpdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        self._mark_closure(node)
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node
Stefan Behnel's avatar
Stefan Behnel committed
2389

2390

2391
class CreateClosureClasses(CythonTransform):
    # Output closure classes in module scope for all functions
    # that really need it.

    def __init__(self, context):
        super(CreateClosureClasses, self).__init__(context)
        self.path = []
        self.in_lambda = False

    def visit_ModuleNode(self, node):
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def find_entries_used_in_closures(self, node):
        """
        Split the local entries of ``node`` into two lists of
        ``(name, entry)`` pairs: those taken from an outer closure and
        those that live in this function's own closure.
        """
        local_items = list(node.local_scope.entries.items())
        from_closure = [(name, entry) for name, entry in local_items
                        if entry.from_closure]
        in_closure = [(name, entry) for name, entry in local_items
                      if not entry.from_closure and entry.in_closure]
        return from_closure, in_closure

    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
        """
        Declare the closure class for function ``node`` in
        ``target_module_scope``, or make its scope a pass-through to the
        outer closure when it has no own closure variables.
        """
        func_scope = node.local_scope

        # move local variables into closure
        if node.is_generator:
            for var_entry in func_scope.entries.values():
                if not var_entry.from_closure:
                    var_entry.in_closure = True

        from_closure, in_closure = self.find_entries_used_in_closures(node)
        in_closure.sort()

        # Now from the beginning
        node.needs_closure = False
        node.needs_outer_scope = False

        # The closure scope that encloses this function, skipping class scopes.
        cscope = node.entry.scope
        while cscope.is_py_class_scope or cscope.is_c_class_scope:
            cscope = cscope.outer_scope

        if not from_closure and (self.path or inner_node):
            if not inner_node:
                if not node.py_cfunc_node:
                    raise InternalError("DefNode does not have assignment node")
                inner_node = node.py_cfunc_node
            inner_node.needs_self_code = False
            node.needs_outer_scope = False

        # Generators always get their own closure class.
        if not node.is_generator and not in_closure:
            if from_closure:
                # Reuse the outer closure class instead of creating one.
                func_scope.is_passthrough = True
                func_scope.scope_class = cscope.scope_class
                node.needs_outer_scope = True
            return

        as_name = '%s_%s' % (
            target_module_scope.next_id(Naming.closure_class_prefix),
            node.entry.cname)

        class_entry = target_module_scope.declare_c_class(
            name=as_name, pos=node.pos, defining=True,
            implementing=True)
        class_entry.type.is_final_type = True

        func_scope.scope_class = class_entry
        class_scope = class_entry.type.scope
        class_scope.is_internal = True
        if Options.closure_freelist_size:
            class_scope.directives['freelist'] = Options.closure_freelist_size

        if from_closure:
            assert cscope.is_closure_scope
            class_scope.declare_var(pos=node.pos,
                                    name=Naming.outer_scope_cname,
                                    cname=Naming.outer_scope_cname,
                                    type=cscope.scope_class.type,
                                    is_cdef=True)
            node.needs_outer_scope = True

        for _, var_entry in in_closure:
            closure_entry = class_scope.declare_var(pos=var_entry.pos,
                                                    name=var_entry.name,
                                                    cname=var_entry.cname,
                                                    type=var_entry.type,
                                                    is_cdef=True)
            if var_entry.is_declared_generic:
                closure_entry.is_declared_generic = 1

        node.needs_closure = True
        # Do it here because other classes are already checked
        target_module_scope.check_c_class(func_scope.scope_class)

    def visit_LambdaNode(self, node):
        if not isinstance(node.def_node, Nodes.DefNode):
            # fused function, an error has been previously issued
            return node

        outer_in_lambda = self.in_lambda
        self.in_lambda = True
        self.create_class_from_scope(node.def_node, self.module_scope, node)
        self.visitchildren(node)
        self.in_lambda = outer_in_lambda
        return node

    def visit_FuncDefNode(self, node):
        if self.in_lambda:
            self.visitchildren(node)
            return node
        if not (node.needs_closure or self.path):
            return node
        self.create_class_from_scope(node, self.module_scope)
        self.path.append(node)
        self.visitchildren(node)
        self.path.pop()
        return node

    def visit_GeneratorBodyDefNode(self, node):
        self.visitchildren(node)
        return node

    def visit_CFuncDefNode(self, node):
        if node.overridable:
            self.visitchildren(node)
            return node
        return self.visit_FuncDefNode(node)
2520

2521 2522 2523 2524 2525 2526

class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on nodes inside `nogil` sections to make
    sure the GIL is held where it is needed, reporting an error for
    Python operations inside a `nogil` environment.

    Additionally, report errors for closely nested `with gil` or
    `with nogil` statements. The latter would abort Python.
    """

    def __call__(self, root):
        self.env_stack = [root.scope]
        self.nogil = False

        # True for 'cdef func() nogil:' functions, as the GIL may be held while
        # calling this function (thus contained 'nogil' blocks may be valid).
        self.nogil_declarator_only = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        local_scope = node.local_scope
        self.env_stack.append(local_scope)
        outer_nogil = self.nogil
        self.nogil = local_scope.nogil

        if self.nogil:
            self.nogil_declarator_only = True
            if node.nogil_check:
                node.nogil_check(local_scope)

        self.visitchildren(node)

        # This cannot be nested, so it doesn't need backup/restore
        self.nogil_declarator_only = False

        self.env_stack.pop()
        self.nogil = outer_nogil
        return node

    def visit_GILStatNode(self, node):
        if self.nogil and node.nogil_check:
            # NOTE(review): unlike the other call sites, no env is passed
            # here -- confirm against the node's nogil_check signature.
            node.nogil_check()

        outer_nogil = self.nogil
        self.nogil = (node.state == 'nogil')

        # Entering the state that is already active is an error, unless we
        # are directly inside a function only *declared* nogil.
        if outer_nogil == self.nogil and not self.nogil_declarator_only:
            if outer_nogil:
                error(node.pos, "Trying to release the GIL while it was "
                                "previously released.")
            else:
                error(node.pos, "Trying to acquire the GIL while it is "
                                "already held.")

        if isinstance(node.finally_clause, Nodes.StatListNode):
            # The finally clause of the GILStatNode is a GILExitNode,
            # which is wrapped in a StatListNode. Just unpack that.
            [node.finally_clause] = node.finally_clause.stats

        self.visitchildren(node)
        self.nogil = outer_nogil
        return node

    def visit_ParallelRangeNode(self, node):
        if node.nogil:
            # Turn 'prange(..., nogil=True)' into an explicit nogil block.
            node.nogil = False
            gil_stat = Nodes.GILStatNode(node.pos, state='nogil', body=node)
            return self.visit_GILStatNode(gil_stat)

        if not self.nogil:
            error(node.pos, "prange() can only be used without the GIL")
            # Forget about any GIL-related errors that may occur in the body
            return None

        node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ParallelWithBlockNode(self, node):
        if not self.nogil:
            error(node.pos, "The parallel section may only be used without "
                            "the GIL")
            return None

        if node.nogil_check:
            # It does not currently implement this, but test for it anyway to
            # avoid potential future surprises
            node.nogil_check(self.env_stack[-1])

        self.visitchildren(node)
        return node

    def visit_TryFinallyStatNode(self, node):
        """
        Take care of try/finally statements in nogil code sections.
        """
        if self.nogil and not isinstance(node, Nodes.GILStatNode):
            node.nogil_check = None
            node.is_try_finally_in_nogil = True
            self.visitchildren(node)
            return node
        return self.visit_Node(node)

    def visit_Node(self, node):
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        node.in_nogil_context = self.nogil
        return node

2633

Robert Bradshaw's avatar
Robert Bradshaw committed
2634
class TransformBuiltinMethods(EnvTransform):
    """
    Replace Cython's own cython.* builtins by the corresponding tree nodes.

    In addition, calls to the builtins locals(), vars(), dir(), eval() and
    argument-less super(), as well as one-argument exec statements, receive
    the implicit arguments (scope dictionaries, class cell, ...) that they
    need at runtime.
    """

    def visit_SingleAssignmentNode(self, node):
        # Declaration-only assignments are removed from the tree; all other
        # assignments are visited normally.
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """
        Replace a reference to ``cython.<name>`` by the node it stands for.

        Basic type names and names known to the cython scope are left as
        they are; any other cython attribute is reported as an error.
        Nodes that are not cython attributes are returned unchanged.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'__version__':
                from .. import __version__ as version
                node = ExprNodes.StringNode(node.pos, value=EncodedString(version))
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset', u'staticmethod'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif PyrexTypes.parse_basic_type(attribute):
                pass
            elif self.context.cython_scope.lookup_qualified_name(attribute):
                pass
            else:
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_ExecStatNode(self, node):
        """
        Give a one-argument ``exec`` explicit globals, plus explicit locals
        when it is not executed at module level.
        """
        lenv = self.current_env()
        self.visitchildren(node)
        if len(node.args) == 1:
            node.args.append(ExprNodes.GlobalsExprNode(node.pos))
            if not lenv.is_module_scope:
                node.args.append(
                    ExprNodes.LocalsExprNode(
                        node.pos, self.current_scope_node(), lenv))
        return node

    def _inject_locals(self, node, func_name):
        """
        Replace argument-less calls to the builtins locals(), vars() and
        dir() with nodes that build the result from the current scope.

        Calls of names that are shadowed in the current scope, and calls
        with arguments, are returned unchanged.  Invalid argument counts
        are reported at the call node's position.
        """
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry:
            # not the builtin
            return node
        pos = node.pos
        if func_name in ('locals', 'vars'):
            if func_name == 'locals' and len(node.args) > 0:
                error(pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d"
                      % len(node.args))
                return node
            elif func_name == 'vars':
                if len(node.args) > 1:
                    error(pos, "Builtin 'vars()' called with wrong number of args, expected 0-1, got %d"
                          % len(node.args))
                if len(node.args) > 0:
                    return node  # nothing to do
            return ExprNodes.LocalsExprNode(pos, self.current_scope_node(), lenv)
        else:  # dir()
            if len(node.args) > 1:
                error(pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d"
                      % len(node.args))
            if len(node.args) > 0:
                # optimised in Builtin.py
                return node
            if lenv.is_py_class_scope or lenv.is_module_scope:
                if lenv.is_py_class_scope:
                    pyclass = self.current_scope_node()
                    locals_dict = ExprNodes.CloneNode(pyclass.dict)
                else:
                    locals_dict = ExprNodes.GlobalsExprNode(pos)
                return ExprNodes.SortedDictKeysNode(locals_dict)
            # Function scope: the names are known at compile time.
            local_names = sorted(var.name for var in lenv.entries.values() if var.name)
            items = [ExprNodes.IdentifierStringNode(pos, value=var)
                     for var in local_names]
            return ExprNodes.ListNode(pos, args=items)

    def visit_PrimaryCmpNode(self, node):
        """
        Special case for 'in' / 'not in' tests.

        A membership test does not need the sorted key list of a dict.
        Unwrap a SortedDictKeysNode on the right-hand side, and a
        NoneCheckNode directly inside it, so that the test runs on the
        wrapped value itself.
        """
        self.visitchildren(node)
        if node.operator in ('in', 'not_in'):
            if isinstance(node.operand2, ExprNodes.SortedDictKeysNode):
                arg = node.operand2.arg
                if isinstance(arg, ExprNodes.NoneCheckNode):
                    arg = arg.arg
                node.operand2 = arg
        return node

    def visit_CascadedCmpNode(self, node):
        return self.visit_PrimaryCmpNode(node)

    def _inject_eval(self, node, func_name):
        """
        Give a one-argument call of the builtin eval() explicit globals,
        plus explicit locals outside module scope.
        """
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry or len(node.args) != 1:
            return node
        # Inject globals and locals
        node.args.append(ExprNodes.GlobalsExprNode(node.pos))
        if not lenv.is_module_scope:
            node.args.append(
                ExprNodes.LocalsExprNode(
                    node.pos, self.current_scope_node(), lenv))
        return node

    def _inject_super(self, node, func_name):
        """
        Turn an argument-less super() call inside a method into the
        explicit two-argument form (class, first method argument).
        """
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry or node.args:
            return node
        # Inject no-args super
        def_node = self.current_scope_node()
        if (not isinstance(def_node, Nodes.DefNode) or not def_node.args or
            len(self.env_stack) < 2):
            return node
        class_node, class_scope = self.env_stack[-2]
        if class_scope.is_py_class_scope:
            # Python classes: the class object is taken from the class cell.
            def_node.requires_classobj = True
            class_node.class_cell.is_active = True
            node.args = [
                ExprNodes.ClassCellNode(
                    node.pos, is_generator=def_node.is_generator),
                ExprNodes.NameNode(node.pos, name=def_node.args[0].name)
                ]
        elif class_scope.is_c_class_scope:
            # cdef classes: refer to the extension type by name.
            node.args = [
                ExprNodes.NameNode(
                    node.pos, name=class_node.scope.name,
                    entry=class_node.entry),
                ExprNodes.NameNode(node.pos, name=def_node.args[0].name)
                ]
        return node

    def visit_SimpleCallNode(self, node):
        """
        Replace calls of cython.* functions (operators, cast, sizeof, cmod,
        cdiv, set, staticmethod) with the corresponding nodes.  Then inject
        the implicit arguments of dir()/locals()/vars(), eval() and super().
        """
        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](
                        node.function.pos, operand=node.args[0])
            elif function in InterpretCompilerDirectives.binop_method_nodes:
                if len(node.args) != 2:
                    error(node.function.pos, u"%s() takes exactly two arguments" % function)
                else:
                    node = InterpretCompilerDirectives.binop_method_nodes[function](
                        node.function.pos, operand1=node.args[0], operand2=node.args[1])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos,
                          u"cast() takes exactly two arguments and an optional typecheck keyword")
                else:
                    type = node.args[0].analyse_as_type(self.current_env())
                    if type:
                        node = ExprNodes.TypecastNode(
                            node.function.pos, type=type, operand=node.args[1], typecheck=False)
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    type = node.args[0].analyse_as_type(self.current_env())
                    if type:
                        node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=type)
                    else:
                        node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == u'set':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
            elif function == u'staticmethod':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('staticmethod'))
            elif self.context.cython_scope.lookup_qualified_name(function):
                pass
            else:
                error(node.function.pos,
                      u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)

        # The node may have been replaced above; only plain calls remain here.
        if isinstance(node, ExprNodes.SimpleCallNode) and node.function.is_name:
            func_name = node.function.name
            if func_name in ('dir', 'locals', 'vars'):
                return self._inject_locals(node, func_name)
            if func_name == 'eval':
                return self._inject_eval(node, func_name)
            if func_name == 'super':
                return self._inject_super(node, func_name)
        return node

    def visit_GeneralCallNode(self, node):
        """
        Handle ``cython.cast(type, value, typecheck=...)`` calls that use
        keyword arguments.
        """
        function = node.function.as_cython_attribute()
        if function == u'cast':
            # Only cast() needs the arguments unpacked, so other cython.*
            # calls never hit this code.
            # NOTE(review): assumes positional_args is a tuple node and
            # keyword_args a compile-time evaluable dict node.
            args = node.positional_args.args
            kwargs = node.keyword_args.compile_time_value(None)
            if (len(args) != 2 or len(kwargs) > 1 or
                    (len(kwargs) == 1 and 'typecheck' not in kwargs)):
                error(node.function.pos,
                      u"cast() takes exactly two arguments and an optional typecheck keyword")
            else:
                type = args[0].analyse_as_type(self.current_env())
                if type:
                    typecheck = kwargs.get('typecheck', False)
                    node = ExprNodes.TypecastNode(
                        node.function.pos, type=type, operand=args[1], typecheck=typecheck)
                else:
                    error(args[0].pos, "Not a type")

        self.visitchildren(node)
        return node

2871

2872 2873 2874 2875 2876 2877 2878 2879 2880 2881 2882 2883 2884 2885
class ReplaceFusedTypeChecks(VisitorTransform):
    """
    Prune compile-time type tests in one specialisation of a fused function.

    This transform is not part of the regular pipeline.  It is applied to
    each specific version of a cdef function with fused argument types, and
    replaces type comparisons such as

        if fused_t is mytype:
            ...
        elif fused_t in other_fused_type:
            ...

    with constant booleans, so that branches that cannot match are folded
    away.
    """
    def __init__(self, local_scope):
        super(ReplaceFusedTypeChecks, self).__init__()
        self.local_scope = local_scope

        # Imported here rather than at module level to avoid an import
        # cycle at module load time.
        from .Optimize import ConstantFolding
        self.transform = ConstantFolding(reevaluate=True)

    def visit_IfStatNode(self, node):
        """
        Constant-fold the ``if`` statement once its type tests have been
        replaced, dropping clauses whose test became false.
        """
        self.visitchildren(node)
        return self.transform(node)

    def visit_PrimaryCmpNode(self, node):
        left_type = node.operand1.analyse_as_type(self.local_scope)
        right_type = node.operand2.analyse_as_type(self.local_scope)

        if not (left_type and right_type):
            # Not a comparison between two types: keep it.
            return node

        false_node = ExprNodes.BoolNode(node.pos, value=False)
        true_node = ExprNodes.BoolNode(node.pos, value=True)

        left_type = self.specialize_type(left_type, node.operand1.pos)
        operator = node.operator

        if operator in ('is', 'is_not', '==', '!='):
            right_type = self.specialize_type(right_type, node.operand2.pos)
            wants_same = operator in ('is', '==')
            if bool(left_type.same_as(right_type)) == wants_same:
                return true_node

        elif operator in ('in', 'not_in'):
            # The right operand must itself be a fused type, not a type with
            # a fused subtype, so only a typedef is unwrapped here.
            if isinstance(right_type, PyrexTypes.CTypedefType):
                right_type = right_type.typedef_base_type

            if left_type.is_fused:
                error(node.operand1.pos, "Type is fused")
            elif not right_type.is_fused:
                error(node.operand2.pos,
                      "Can only use 'in' or 'not in' on a fused type")
            else:
                candidates = PyrexTypes.get_specialized_types(right_type)
                is_member = any(left_type.same_as(candidate)
                                for candidate in candidates)
                if is_member == (operator == 'in'):
                    return true_node

        return false_node

    def specialize_type(self, type, pos):
        """
        Map a (possibly fused) type to its concrete type in this
        specialisation, reporting an error if no mapping exists.
        """
        try:
            return type.specialize(self.local_scope.fused_to_specific)
        except KeyError:
            error(pos, "Type is not specific")
            return type

    def visit_Node(self, node):
        self.visitchildren(node)
        return node


Mark Florisson's avatar
Mark Florisson committed
2959
class DebugTransform(CythonTransform):
    """
    Emit the debug description of this Cython module through the context's
    gdb debug output writer.
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # qualified names of the function scopes seen so far
        self.visited = set()
        # tree builder / debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        self.c_output_file = result.c_file

        # Closures: nested functions are collected here and serialized after
        # the top-level ones, as if the tree had no nesting.
        self.nested_funcdefs = []

        # While True, visit_NameNode records called C functions as step-into
        # targets and visit_FuncDefNode defers the functions it meets.
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        module_name = node.full_module_name
        self.tb.module_name = module_name
        self.tb.start('Module', dict(
            module_name=module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file))

        # Functions: the directly visited ones first, then the deferred
        # nested ones (the list may grow while it is walked), and finally
        # the module body itself, serialized as a pseudo-function.
        self.tb.start('Functions')
        self.visitchildren(node)
        for funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(funcdef)
        self.register_stepinto = True
        self.serialize_modulenode_as_function(node)
        self.register_stepinto = False
        self.tb.end('Functions')

        # Global variables, skipping already visited scopes, internal
        # '__pyx_' names, C functions and extension types.
        self.tb.start('Globals')
        global_entries = {}
        for entry_name, entry in node.scope.entries.items():
            if entry.qualified_name in self.visited:
                continue
            if entry.name.startswith('__pyx_'):
                continue
            if entry.type.is_cfunction or entry.type.is_extension_type:
                continue
            global_entries[entry_name] = entry
        self.serialize_local_variables(global_entries)
        self.tb.end('Globals')

        # 'Module' is intentionally left open here; it is closed after the
        # line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        scope = node.local_scope
        self.visited.add(scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            # Met while serializing another function: handle it later.
            self.nested_funcdefs.append(node)
            return node

        if node.py_func is not None:
            pf_cname = node.py_func.entry.func_cname
        else:
            pf_cname = ''

        self.tb.start('Function', attrs=dict(
            name=node.entry.name or getattr(node, 'name', '<unknown>'),
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=scope.qualified_name,
            lineno=str(node.pos[1])))

        self.tb.start('Locals')
        self.serialize_local_variables(scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg_entry in scope.arg_entries:
            self.tb.start(arg_entry.name)
            self.tb.end(arg_entry.name)
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')
        return node

    def visit_NameNode(self, node):
        is_stepinto_call = (
            self.register_stepinto
            and node.type is not None
            and node.type.is_cfunction
            and getattr(node, 'is_called', False)
            and node.entry.func_cname is not None)
        if is_stepinto_call:
            # node.entry.in_cinclude is deliberately not consulted, because
            # functions declared via 'cdef extern: ...' are not
            # 'in_cinclude'.  As a result called 'cdef' functions are also
            # listed as step-into functions, which is harmless since they
            # are recognized as Cython functions anyway.
            self.tb.start('StepIntoFunction', attrs={'name': node.entry.func_cname})
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Describe the module-level code as a function (once under its
        Python 2 init name, once under its Python 3 one).  This lets the
        debugger treat it as a relevant frame and know where to break for
        'break modulename'.
        """
        name = node.full_module_name.rpartition('.')[-1]

        py2_attrs = dict(
            name=name,
            cname='init' + name,
            pf_cname='',
            # No qualified_name: module-level breakpoints are set with
            # `cy break modulename:lineno`.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )
        py3_attrs = dict(py2_attrs, cname='PyInit_' + name)

        for attrs in (py2_attrs, py3_attrs):
            self._serialize_modulenode_as_function(node, attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        self.tb.end('Locals')

        # The module pseudo-function takes no arguments.
        self.tb.start('Arguments')
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

    def serialize_local_variables(self, entries):
        for entry in entries.values():
            if not entry.cname:
                # not a local variable
                continue

            vartype = 'PythonObject' if entry.type.is_pyobject else 'CObject'

            if entry.from_closure:
                # Variable of an outer scope accessed from a closure: it is
                # reached through the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)
                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            # Entries without a position (e.g. the global __builtins__ or
            # __doc__) are not in the user's code and get line 0.
            lineno = str(entry.pos[1]) if entry.pos else '0'

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')