ParseTreeTransforms.py 136 KB
Newer Older
1 2
from __future__ import absolute_import

3
import cython
4
cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
5 6
               Options=object, UtilNodes=object, LetNode=object,
               LetRefNode=object, TreeFragment=object, EncodedString=object,
7
               error=object, warning=object, copy=object, _unicode=object)
8

9
import copy
10
import hashlib
11

12 13 14 15 16 17
from . import PyrexTypes
from . import Naming
from . import ExprNodes
from . import Nodes
from . import Options
from . import Builtin
18
from . import Errors
19

20 21
from .Visitor import VisitorTransform, TreeVisitor
from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
22
from .UtilNodes import LetNode, LetRefNode
23
from .TreeFragment import TreeFragment
24
from .StringEncoding import EncodedString, _unicode
25 26
from .Errors import error, warning, CompileError, InternalError
from .Code import UtilityCode
27

28

29
class SkipDeclarations(object):
    """
    Variable and function declarations can often have a deep tree structure,
    and yet most transformations don't need to descend to this depth.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so there
    is no need to use this for transformations after that point.

    Mix this class into a transform: each visit_* method below returns the
    declaration node unchanged without visiting its children.
    """
    def visit_CTypeDefNode(self, node):
        return node

    def visit_CVarDefNode(self, node):
        return node

    def visit_CDeclaratorNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return node

    def visit_CEnumDefNode(self, node):
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return node

class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means less special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the direct children of a statement list
        # (or of a node treated like one, see visit_StatNode).
        self.is_in_statlist = False
        # True while visiting anything below an expression node.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        """Wrap a statement in a StatListNode unless it already sits in a
        statement list or inside an expression.

        If is_listcontainer is true, the node's own children are visited
        as if they were members of a statement list (i.e. not wrapped).
        """
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        if not self.is_in_statlist and not self.is_in_expr:
            return Nodes.StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Save and restore the flag, like visit_StatNode does, instead of
        # resetting it to False.  Otherwise a StatListNode nested directly
        # inside another statement list would make its following siblings
        # look like lone statements that need wrapping.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_PassStatNode(self, node):
        """Eliminate PassStatNode"""
        # Outside a statement list, "pass" becomes an empty block.  Inside
        # one, returning [] presumably drops the node from the list.
        if not self.is_in_statlist:
            return Nodes.StatListNode(pos=node.pos, stats=[])
        else:
            return []

    def visit_ExprStatNode(self, node):
        """Eliminate useless string literals"""
        # A bare string-literal statement has no effect, so treat it like "pass".
        if node.expr.is_string_literal:
            return self.visit_PassStatNode(node)
        else:
            return self.visit_StatNode(node)

    def visit_CDeclaratorNode(self, node):
        # Declarators are not descended into.
        return node


class PostParseError(CompileError):
    """Compile error raised or reported by the post-parse transforms
    (PostParse, PxdPostParse) for invalid but syntactically parseable code."""

# Error strings are checked by unit tests, so they are defined as constants.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'


class PostParse(ScopeTrackingTransform):
    """
    Basic interpretation of the parse tree, as well as validity
    checking that can be done on a very basic level on the parse
    tree (while still not being a problem with the basic syntax,
    as such).

    Specifically:
    - Default values to cdef assignments are turned into single
    assignments following the declaration (everywhere but in class
    bodies, where they raise a compile error)

    - Interpret some node structures into Python runtime values.
    Some nodes take compile-time arguments (currently:
    TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
    which should be interpreted. This happens in a general way
    and other steps should be taken to ensure validity.

    Type arguments cannot be interpreted in this way.

    - For __cythonbufferdefaults__ the arguments are checked for
    validity.

    TemplatedTypeNode has its directives interpreted:
    Any first positional argument goes into the "dtype" attribute,
    any "ndim" keyword argument goes into the "ndim" attribute and
    so on. Also it is checked that the directive combination is valid.
    - __cythonbufferdefaults__ attributes are parsed and put into the
    type information.

    - Lambdas and generator expressions get a corresponding DefNode
    attached (node.def_node).

    - Parallel assignments (a, b = b, a) are split into single/cascaded
    assignments, "del" targets are flattened, and "except ... as name"
    bodies are wrapped so that the name is deleted afterwards.

    Note: Currently Parsing.py does a lot of interpretation and
    reorganization that can be refactored into this transform
    if a more pure Abstract Syntax Tree is wanted.
    """

    def __init__(self, context):
        super(PostParse, self).__init__(context)
        # name of a special cdef class attribute -> handler(declarator)
        self.specialattribute_handlers = {
            '__cythonbufferdefaults__' : self.handle_bufferdefaults
        }

    def visit_LambdaNode(self, node):
        # unpack a lambda expression into the corresponding DefNode
        collector = YieldNodeCollector()
        collector.visitchildren(node.result_expr)
        if collector.has_yield or collector.has_await or isinstance(node.result_expr, ExprNodes.YieldExprNode):
            # A lambda that yields/awaits is a generator/coroutine: its
            # expression is evaluated as a statement, not returned.
            body = Nodes.ExprStatNode(
                node.result_expr.pos, expr=node.result_expr)
        else:
            body = Nodes.ReturnStatNode(
                node.result_expr.pos, value=node.result_expr)
        node.def_node = Nodes.DefNode(
            node.pos, name=node.name,
            args=node.args, star_arg=node.star_arg,
            starstar_arg=node.starstar_arg,
            body=body, doc=None)
        self.visitchildren(node)
        return node

    def visit_GeneratorExpressionNode(self, node):
        # unpack a generator expression into the corresponding DefNode
        collector = YieldNodeCollector()
        collector.visitchildren(node.loop)
        node.def_node = Nodes.DefNode(
            node.pos, name=node.name, doc=None,
            args=[], star_arg=None, starstar_arg=None,
            body=node.loop, is_async_def=collector.has_await)
        self.visitchildren(node)
        return node

    def visit_ComprehensionNode(self, node):
        # enforce local scope also in Py2 for async generators (seriously, that's a Py3.6 feature...)
        if not node.has_local_scope:
            collector = YieldNodeCollector()
            collector.visitchildren(node.loop)
            if collector.has_await:
                node.has_local_scope = True
        self.visitchildren(node)
        return node

    # cdef variables
    def handle_bufferdefaults(self, decl):
        """Store a ``__cythonbufferdefaults__ = {...}`` declaration on the
        enclosing class node; the default must be a dict literal."""
        if not isinstance(decl.default, ExprNodes.DictNode):
            raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
        self.scope_node.buffer_defaults_node = decl.default
        self.scope_node.buffer_defaults_pos = decl.pos

    def visit_CVarDefNode(self, node):
        """Move cdef default values into separate assignments.

        Returns a list [node, assignment, ...], or None if the declaration
        was dropped after reporting a (non-fatal) PostParseError.
        """
        # This assumes only plain names and pointers are assignable on
        # declaration. Also, it makes use of the fact that a cdef decl
        # must appear before the first use, so we don't have to deal with
        # "i = 3; cdef int i = i" and can simply move the nodes around.
        try:
            self.visitchildren(node)
            stats = [node]
            newdecls = []
            for decl in node.declarators:
                declbase = decl
                while isinstance(declbase, Nodes.CPtrDeclaratorNode):
                    declbase = declbase.base
                if isinstance(declbase, Nodes.CNameDeclaratorNode):
                    if declbase.default is not None:
                        if self.scope_type in ('cclass', 'pyclass', 'struct'):
                            if isinstance(self.scope_node, Nodes.CClassDefNode):
                                # Look up the declared name on the unwrapped name
                                # declarator.  The pointer declarators unwrapped above
                                # only expose 'base' here, so 'decl.name' would read a
                                # name from the wrong node.  This lookup also makes
                                # the 'decl is not declbase' check below reachable.
                                handler = self.specialattribute_handlers.get(declbase.name)
                                if handler:
                                    if decl is not declbase:
                                        raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
                                    handler(decl)
                                    continue  # Remove declaration
                            raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
                        first_assignment = self.scope_type != 'module'
                        stats.append(Nodes.SingleAssignmentNode(node.pos,
                            lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
                            rhs=declbase.default, first=first_assignment))
                        declbase.default = None
                newdecls.append(decl)
            node.declarators = newdecls
            return stats
        except PostParseError as e:
            # An error in a cdef clause is ok, simply remove the declaration
            # and try to move on to report more errors
            self.context.nonfatal_error(e)
            return None

    # Split parallel assignments (a,b = b,a) into separate partial
    # assignments that are executed rhs-first using temps.  This
    # restructuring must be applied before type analysis so that known
    # types on rhs and lhs can be matched directly.  It is required in
    # the case that the types cannot be coerced to a Python type in
    # order to assign from a tuple.

    def visit_SingleAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, [node.lhs, node.rhs])

    def visit_CascadedAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, node.lhs_list + [node.rhs])

    def _visit_assignment_node(self, node, expr_list):
        """Flatten parallel assignments into separate single
        assignments or cascaded assignments.

        expr_list is [lhs..., rhs].  Returns the original node if fewer than
        two of its items are sequence constructors or string literals.
        """
        parallel_count = sum(1 for expr in expr_list
                             if expr.is_sequence_constructor or expr.is_string_literal)
        if parallel_count < 2:
            # no parallel assignments => nothing to do
            return node

        expr_list_list = []
        flatten_parallel_assignments(expr_list, expr_list_list)
        temp_refs = []
        eliminate_rhs_duplicates(expr_list_list, temp_refs)

        nodes = []
        for flat_exprs in expr_list_list:
            lhs_list = flat_exprs[:-1]
            rhs = flat_exprs[-1]
            if len(lhs_list) == 1:
                assignment = Nodes.SingleAssignmentNode(
                    rhs.pos, lhs=lhs_list[0], rhs=rhs)
            else:
                assignment = Nodes.CascadedAssignmentNode(
                    rhs.pos, lhs_list=lhs_list, rhs=rhs)
            nodes.append(assignment)

        if not nodes:
            # Only all-empty sequences (e.g. "() = ()") flatten to nothing;
            # there is nothing to assign, so replace the statement with an
            # empty block instead of failing on nodes[0] below.
            return Nodes.StatListNode(node.pos, stats=[])

        if len(nodes) == 1:
            assign_node = nodes[0]
        else:
            assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats=nodes)

        if temp_refs:
            duplicates_and_temps = [(temp.expression, temp)
                                    for temp in temp_refs]
            sort_common_subsequences(duplicates_and_temps)
            # Wrap in reverse order so the first temp in sorted order ends up
            # as the outermost LetNode (presumably set up first).
            for _, temp_ref in duplicates_and_temps[::-1]:
                assign_node = LetNode(temp_ref, assign_node)

        return assign_node

    def _flatten_sequence(self, seq, result):
        # Recursively collect the non-sequence items of seq.args into result.
        for arg in seq.args:
            if arg.is_sequence_constructor:
                self._flatten_sequence(arg, result)
            else:
                result.append(arg)
        return result

    def visit_DelStatNode(self, node):
        # "del a, (b, c)" deletes a, b and c individually
        self.visitchildren(node)
        node.args = self._flatten_sequence(node, [])
        return node

    def visit_ExceptClauseNode(self, node):
        if node.is_except_as:
            # except-as must delete NameNode target at the end
            del_target = Nodes.DelStatNode(
                node.pos,
                args=[ExprNodes.NameNode(
                    node.target.pos, name=node.target.name)],
                ignore_nonexisting=True)
            node.body = Nodes.StatListNode(
                node.pos,
                stats=[Nodes.TryFinallyStatNode(
                    node.pos,
                    body=node.body,
                    finally_clause=Nodes.StatListNode(
                        node.pos,
                        stats=[del_target]))])
        self.visitchildren(node)
        return node

def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.
    Creates a sequence of LetRefNodes that set up the required temps
    and appends them to ref_node_sequence.  The input list is modified
    in-place.

    expr_list_list: lists of the form [lhs..., rhs]; only the last item
        (the rhs) of each list is inspected and possibly replaced.
    ref_node_sequence: receives one LetRefNode per repeated node, in the
        order in which the repetitions are detected.

    Literals and names are never replaced.  Sequence constructors are
    searched recursively.
    """
    seen_nodes = set()
    ref_nodes = {}

    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node in seen_nodes:
            if node not in ref_nodes:
                ref_node = LetRefNode(node)
                ref_nodes[node] = ref_node
                ref_node_sequence.append(ref_node)
        else:
            seen_nodes.add(node)
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)

    for expr_list in expr_list_list:
        find_duplicates(expr_list[-1])
    if not ref_nodes:
        return

    def substitute_nodes(node):
        if node in ref_nodes:
            return ref_nodes[node]
        elif node.is_sequence_constructor:
            node.args = [substitute_nodes(arg) for arg in node.args]
        return node

    # replace nodes inside of the common subexpressions
    for node in ref_nodes:
        if node.is_sequence_constructor:
            node.args = [substitute_nodes(arg) for arg in node.args]

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])

def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.  This is needed
    because each rhs item must only be evaluated once, so its value
    must be evaluated first and then reused when packing sequences
    that contain it.

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort (which is very fast for short sequences, the
    normal case in practice).

    items is a list of (expression, ref_node) pairs and is sorted in place.
    Each item is moved in front of the earliest preceding item whose
    expression (recursively, via sequence constructors) contains the
    item's ref_node.
    """
    def contains(seq, x):
        return any(item is x
                   or (item.is_sequence_constructor and contains(item.args, x))
                   for item in seq)

    def lower_than(a, b):
        return b.is_sequence_constructor and contains(b.args, a)

    for pos, item in enumerate(items):
        key = item[1]  # the ResultRefNode which has already been injected into the sequences
        new_pos = pos
        for i in range(pos - 1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos != pos:
            # Move the item to new_pos, shifting items[new_pos:pos] right by
            # one.  The list length is unchanged, so enumerate() stays valid.
            del items[pos]
            items.insert(new_pos, item)


def unpack_string_to_character_literals(literal):
    """Split a string literal node into one literal node per character.

    Every new node is an instance of literal's own class, shares its
    position, and carries a one-character value (re-wrapped in the class
    of literal.value) as both ``value`` and ``constant_result``.
    """
    node_class = literal.__class__
    position = literal.pos
    text = literal.value
    char_type = text.__class__
    return [node_class(position, value=ch, constant_result=ch)
            for ch in map(char_type, text)]

def flatten_parallel_assignments(input, output):
    #  The input is a list of expression nodes, representing the LHSs
    #  and RHS of one (possibly cascaded) assignment statement.  For
    #  sequence constructors, rearranges the matching parts of both
    #  sides into a list of equivalent assignments between the
    #  individual elements.  This transformation is applied
    #  recursively, so that nested structures get matched as well.
    #
    #  Every resulting assignment is appended to 'output' as a list
    #  [lhs1, ..., lhsN, rhs].  Size mismatches are reported through
    #  error(), and the offending [lhs, rhs] pair is passed through as is.
    rhs = input[-1]
    # Both checks are side-effect free, so testing the lhs side first gives
    # the same result as testing the rhs side first.
    if (not any(lhs.is_sequence_constructor for lhs in input[:-1])
            or not (rhs.is_sequence_constructor
                    or isinstance(rhs, ExprNodes.UnicodeNode))):
        output.append(input)
        return

    complete_assignments = []

    if rhs.is_sequence_constructor:
        rhs_args = rhs.args
    else:
        # only a UnicodeNode (a string literal) can get here, see above
        rhs_args = unpack_string_to_character_literals(rhs)

    rhs_size = len(rhs_args)
    # one list of lhs targets per rhs item
    lhs_targets = [[] for _ in range(rhs_size)]
    starred_assignments = []
    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue
        lhs_size = len(lhs.args)
        starred_targets = sum(1 for expr in lhs.args if expr.is_starred)
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
            output.append([lhs, rhs])
            continue
        elif lhs_size - starred_targets > rhs_size:
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, 's' if rhs_size != 1 else ''))
            output.append([lhs, rhs])
            continue
        elif starred_targets:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs_args)
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
            output.append([lhs, rhs])
            continue
        else:
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, item_rhs in zip(lhs_targets, rhs_args):
        if cascade:
            cascade.append(item_rhs)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)

def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    # Appends the fixed-position LHS targets to the target list that
    # appear left and right of the starred argument.
    #
    # The starred_assignments list receives a new pair
    # [lhs_target, ListNode(rhs_values)] that maps the remaining arguments
    # (those that match the starred target) to a list.
    #
    # lhs_targets holds one target list per item of rhs_args.  The caller
    # in this module (flatten_parallel_assignments) only gets here after
    # checking that there are at least as many rhs items as non-starred
    # lhs targets.

    # Locate the starred target by scanning lhs_args on its own rather than
    # zip()-ing it with lhs_targets.  zip() stops after len(rhs_args) items,
    # so it would miss a trailing starred target that only receives an
    # empty list (e.g. "a, *b = x,") and raise the InternalError below.
    for starred, expr in enumerate(lhs_args):
        if expr.is_starred:
            break
    else:
        raise InternalError("no starred arg found when splitting starred assignment")
    lhs_remaining = len(lhs_args) - starred - 1

    # left side of the starred target
    for targets, expr in zip(lhs_targets, lhs_args[:starred]):
        targets.append(expr)

    # right side of the starred target; use an explicit start index because
    # lhs_targets[-0:] would select the whole list
    right_targets = lhs_targets[len(lhs_targets) - lhs_remaining:]
    for targets, expr in zip(right_targets, lhs_args[starred + 1:]):
        targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target  # unpack starred node
    starred_rhs = rhs_args[starred:len(rhs_args) - lhs_remaining]
    pos = starred_rhs[0].pos if starred_rhs else target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])


class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots of a cdef class

    - cdef functions are let through only if they are on the
    top level or in a cdef class body, are declared "inline",
    have private visibility and are not declared "api"

    Rejected functions are reported as non-fatal errors and removed.
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # scope_type tracks whether we are at pxd top level or in a cdef class
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        old = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = old
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        err = self.ERR_INLINE_ONLY

        if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
                and node.name in ('__getbuffer__', '__releasebuffer__')):
            err = None  # allow these slots

        if isinstance(node, Nodes.CFuncDefNode):
            if (u'inline' in node.modifiers and
                    self.scope_type in ('pxd', 'cclass')):
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None  # allow inline function
            else:
                err = self.ERR_INLINE_ONLY

        if err:
            # report and drop the function, but keep going
            self.context.nonfatal_error(PostParseError(node.pos, err))
            return None
        else:
            return node


class TrackNumpyAttributes(VisitorTransform, SkipDeclarations):
    """Flag attribute accesses on cimported ``numpy`` modules.

    Records every name under which ``numpy`` is cimported, then sets
    ``is_numpy_attribute`` on each AttributeNode whose object is such a
    name or is itself an already-flagged attribute (so chained accesses
    like ``np.a.b`` are flagged at every level).
    """
    # TODO: Make name handling as good as in InterpretCompilerDirectives() below - probably best to merge the two.

    def __init__(self):
        super(TrackNumpyAttributes, self).__init__()
        self.numpy_module_names = set()

    def visit_CImportStatNode(self, node):
        # presumably "cimport numpy" or "cimport numpy as <alias>"
        if node.module_name == u"numpy":
            self.numpy_module_names.add(node.as_name or u"numpy")
        return node

    def visit_AttributeNode(self, node):
        # Visit children first so the object of a chained access has
        # already been flagged when it is inspected below.
        self.visitchildren(node)
        obj = node.obj
        if (obj.is_name and obj.name in self.numpy_module_names) or obj.is_numpy_attribute:
            node.is_numpy_attribute = True
        return node

    visit_Node = VisitorTransform.recurse_to_children


class InterpretCompilerDirectives(CythonTransform):
    """
    After parsing, directives can be stored in a number of places:
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments overriding these
    - @cython.directivename decorators
    - with cython.directivename: statements

    This transform is responsible for interpreting these various sources
    and storing the directives in two ways:
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to avoid having to modify the tree
    structure, so that the ModuleNode stays on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
    and which names the "cython" module may have been imported to.
    """
    # 'cython' module function names mapped to the expression node class
    # (or node constructor) associated with them.  Within this class they
    # are only used to extend special_methods below.
    unop_method_nodes = {
        'typeof': ExprNodes.TypeofNode,

        'operator.address': ExprNodes.AmpersandNode,
        'operator.dereference': ExprNodes.DereferenceNode,
        'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
        'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
        'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
        'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),
        'operator.typeid'       : ExprNodes.TypeidNode,

        # For backwards compatibility.
        'address': ExprNodes.AmpersandNode,
    }

    binop_method_nodes = {
        'operator.comma'        : ExprNodes.c_binop_constructor(','),
    }

    # Names that is_cython_directive() accepts besides real directives.
    special_methods = set(['declare', 'union', 'struct', 'typedef',
                           'sizeof', 'cast', 'pointer', 'compiled',
                           'NULL', 'fused_type', 'parallel'])
    special_methods.update(unop_method_nodes)

    # Names that may follow "cython.parallel." in an import.
    valid_parallel_directives = set([
        "parallel",
        "prange",
        "threadid",
        #"threadsavailable",
    ])

    def __init__(self, context, compilation_directive_defaults):
        super(InterpretCompilerDirectives, self).__init__(context)
        # Local names bound to the 'cython' module itself.
        self.cython_module_names = set()
        # Local name -> directive name (e.g. from "from cython cimport x as y").
        self.directive_names = {'staticmethod': 'staticmethod'}
        # Local name -> fully qualified "cython.parallel..." name.
        self.parallel_directives = {}
        # Deep copies, so per-compilation changes never leak into the
        # shared defaults.
        directives = copy.deepcopy(Options.get_directive_defaults())
        for key, value in compilation_directive_defaults.items():
            directives[_unicode(key)] = copy.deepcopy(value)
        self.directives = directives

    def check_directive_scope(self, pos, directive, scope):
        """
        Return False, after reporting a non-fatal error, if ``directive`` is
        not allowed in ``scope``.  Otherwise return True, first reporting an
        error if the directive name is unknown.
        """
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            if directive not in Options.directive_types:
                error(pos, "Invalid directive: '%s'." % (directive,))
            return True

    # Set up processing and handle the cython: comments.
    def visit_ModuleNode(self, node):
        for key in sorted(node.directive_comments):
            if not self.check_directive_scope(node.pos, key, 'module'):
                # check_directive_scope() has already reported the error;
                # just drop the misplaced directive.  (The previous call to
                # the undefined self.wrong_scope_error() raised
                # AttributeError here.)
                del node.directive_comments[key]

        self.module_scope = node.scope

        self.directives.update(node.directive_comments)
        node.directives = self.directives
        node.parallel_directives = self.parallel_directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

    # The following methods track imports and cimports that
    # begin with "cython".
    def is_cython_directive(self, name):
        # Truthy (not necessarily a bool) for known directives, special
        # 'cython' functions and basic type names.
        return (name in Options.directive_types or
                name in self.special_methods or
                PyrexTypes.parse_basic_type(name))

    def is_parallel_directive(self, full_name, pos):
        """
        Checks to see if fullname (e.g. cython.parallel.prange) is a valid
        parallel directive. If it is a star import it also updates the
        parallel_directives.
        """
        result = (full_name + ".").startswith("cython.parallel.")

        if result:
            directive = full_name.split('.')
            if full_name == u"cython.parallel":
                self.parallel_directives[u"parallel"] = u"cython.parallel"
            elif full_name == u"cython.parallel.*":
                for name in self.valid_parallel_directives:
                    self.parallel_directives[name] = u"cython.parallel.%s" % name
            elif (len(directive) != 3 or
                  directive[-1] not in self.valid_parallel_directives):
                error(pos, "No such directive: %s" % full_name)

            self.module_scope.use_utility_code(
                UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))

        return result

    def visit_CImportStatNode(self, node):
        if node.module_name == u"cython":
            self.cython_module_names.add(node.as_name or u"cython")
        elif node.module_name.startswith(u"cython."):
            if node.module_name.startswith(u"cython.parallel."):
                error(node.pos, node.module_name + " is not a module")
            if node.module_name == u"cython.parallel":
                if node.as_name and node.as_name != u"cython":
                    self.parallel_directives[node.as_name] = node.module_name
                else:
                    self.cython_module_names.add(u"cython")
                    self.parallel_directives[
                                    u"cython.parallel"] = node.module_name
                self.module_scope.use_utility_code(
                    UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))
            elif node.as_name:
                # Strip the leading "cython." (7 characters).
                self.directive_names[node.as_name] = node.module_name[7:]
            else:
                self.cython_module_names.add(u"cython")
            # if this cimport was a compiler directive, we don't
            # want to leave the cimport node sitting in the tree
            return None
        return node

    def visit_FromCImportStatNode(self, node):
        if not node.relative_level and (
                node.module_name == u"cython" or node.module_name.startswith(u"cython.")):
            # "" for "cython", "<sub>." for "cython.<sub>".
            submodule = (node.module_name + u".")[7:]
            newimp = []

            for pos, name, as_name, kind in node.imported_names:
                full_name = submodule + name
                qualified_name = u"cython." + full_name

                if self.is_parallel_directive(qualified_name, node.pos):
                    # from cython cimport parallel, or
                    # from cython.parallel cimport parallel, prange, ...
                    self.parallel_directives[as_name or name] = qualified_name
                elif self.is_cython_directive(full_name):
                    self.directive_names[as_name or name] = full_name
                    if kind is not None:
                        self.context.nonfatal_error(PostParseError(pos,
                            "Compiler directive imports must be plain imports"))
                else:
                    newimp.append((pos, name, as_name, kind))

            # Drop the statement entirely if only directives were imported.
            if not newimp:
                return None

            node.imported_names = newimp
        return node

    def visit_FromImportStatNode(self, node):
        module_name = node.module.module_name.value
        if module_name == u"cython" or module_name.startswith(u"cython."):
            submodule = (module_name + u".")[7:]
            newimp = []
            for name, name_node in node.items:
                full_name = submodule + name
                qualified_name = u"cython." + full_name
                if self.is_parallel_directive(qualified_name, node.pos):
                    self.parallel_directives[name_node.name] = qualified_name
                elif self.is_cython_directive(full_name):
                    self.directive_names[name_node.name] = full_name
                else:
                    newimp.append((name, name_node))
            if not newimp:
                return None
            node.items = newimp
        return node

    def visit_SingleAssignmentNode(self, node):
        if isinstance(node.rhs, ExprNodes.ImportNode):
            module_name = node.rhs.module_name.value
            is_parallel = (module_name + u".").startswith(u"cython.parallel.")

            if module_name != u"cython" and not is_parallel:
                return node

            # Handle "import cython[.parallel] ..." through the cimport logic.
            node = Nodes.CImportStatNode(node.pos,
                                         module_name=module_name,
                                         as_name=node.lhs.name)
            node = self.visit_CImportStatNode(node)
        else:
            self.visitchildren(node)

        return node

    def visit_NameNode(self, node):
        if node.name in self.cython_module_names:
            node.is_cython_module = True
        else:
            directive = self.directive_names.get(node.name)
            if directive is not None:
                node.cython_attribute = directive
        return node

    def visit_NewExprNode(self, node):
        self.visit(node.cppclass)
        self.visitchildren(node)
        return node

    def try_to_parse_directives(self, node):
        # If node is the contents of a directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None
        if isinstance(node, ExprNodes.CallNode):
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # Keywords naming sub-directives ("optname.key")
                        # become directives of their own.
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            kwds.key_value_pairs = key_value_pairs
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives
        elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
            self.visit(node)
            optname = node.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype is bool:
                    # A bare boolean directive means "enabled".
                    arg = ExprNodes.BoolNode(node.pos, value=True)
                    return [self.try_to_parse_directive(optname, [arg], None, node.pos)]
                elif directivetype is None:
                    return [(optname, None)]
                else:
                    raise PostParseError(
                        node.pos, "The '%s' directive should be used as a function call." % optname)
        return None

    def try_to_parse_directive(self, optname, args, kwds, pos):
        """
        Convert the arguments of one directive into an (optname, value)
        pair according to Options.directive_types.  A single None argument
        resets the directive to its default.  Raises PostParseError for
        arguments that do not fit the directive type.
        """
        if optname == 'np_pythran' and not self.context.cpp:
            raise PostParseError(pos, 'The %s directive can only be used in C++ mode.' % optname)
        elif optname == 'exceptval':
            # default: exceptval(None, check=True)
            arg_error = len(args) > 1
            check = True
            if kwds and kwds.key_value_pairs:
                kw = kwds.key_value_pairs[0]
                if (len(kwds.key_value_pairs) == 1 and
                        kw.key.is_string_literal and kw.key.value == 'check' and
                        isinstance(kw.value, ExprNodes.BoolNode)):
                    check = kw.value.value
                else:
                    arg_error = True
            if arg_error:
                raise PostParseError(
                    pos, 'The exceptval directive takes 0 or 1 positional arguments and the boolean keyword "check"')
            return ('exceptval', (args[0] if args else None, check))

        directivetype = Options.directive_types.get(optname)
        if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
            return optname, Options.get_directive_defaults()[optname]
        elif directivetype is bool:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time boolean argument' % optname)
            return (optname, args[0].value)
        elif directivetype is int:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.IntNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time integer argument' % optname)
            return (optname, int(args[0].value))
        elif directivetype is str:
            if kwds is not None or len(args) != 1 or not isinstance(
                    args[0], (ExprNodes.StringNode, ExprNodes.UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, str(args[0].value))
        elif directivetype is type:
            if kwds is not None or len(args) != 1:
                raise PostParseError(pos,
                    'The %s directive takes one type argument' % optname)
            return (optname, args[0])
        elif directivetype is dict:
            if len(args) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no prepositional arguments' % optname)
            # A call without keyword arguments (e.g. "@cython.locals()")
            # yields kwds=None; treat it as an empty mapping instead of
            # failing on None.key_value_pairs.
            key_value_pairs = kwds.key_value_pairs if kwds is not None else []
            return optname, dict([(key.value, value) for key, value in key_value_pairs])
        elif directivetype is list:
            if kwds and len(kwds.key_value_pairs) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no keyword arguments' % optname)
            return optname, [ str(arg.value) for arg in args ]
        elif callable(directivetype):
            if kwds is not None or len(args) != 1 or not isinstance(
                    args[0], (ExprNodes.StringNode, ExprNodes.UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, directivetype(optname, str(args[0].value)))
        else:
            assert False

    def visit_with_directives(self, node, directives):
        """
        Visit ``node`` with ``directives`` layered over the current ones.
        If that changes anything, the result is wrapped in a
        CompilerDirectivesNode that records the effective directives.
        """
        if not directives:
            return self.visit_Node(node)

        old_directives = self.directives
        new_directives = dict(old_directives)
        new_directives.update(directives)

        if new_directives == old_directives:
            return self.visit_Node(node)

        self.directives = new_directives
        retbody = self.visit_Node(node)
        self.directives = old_directives

        if not isinstance(retbody, Nodes.StatListNode):
            retbody = Nodes.StatListNode(node.pos, stats=[retbody])
        return Nodes.CompilerDirectivesNode(
            pos=retbody.pos, body=retbody, directives=new_directives)

    # Handle decorators
    def visit_FuncDefNode(self, node):
        directives = self._extract_directives(node, 'function')
        return self.visit_with_directives(node, directives)

    def visit_CVarDefNode(self, node):
        directives = self._extract_directives(node, 'function')
        for name, value in directives.items():
            if name == 'locals':
                node.directive_locals = value
            elif name not in ('final', 'staticmethod'):
                self.context.nonfatal_error(PostParseError(
                    node.pos,
                    "Cdef functions can only take cython.locals(), "
                    "staticmethod, or final decorators, got %s." % name))
        return self.visit_with_directives(node, directives)

    def visit_CClassDefNode(self, node):
        directives = self._extract_directives(node, 'cclass')
        return self.visit_with_directives(node, directives)

    def visit_CppClassNode(self, node):
        directives = self._extract_directives(node, 'cppclass')
        return self.visit_with_directives(node, directives)

    def visit_PyClassDefNode(self, node):
        directives = self._extract_directives(node, 'class')
        return self.visit_with_directives(node, directives)

    def _extract_directives(self, node, scope_name):
        """
        Separate compiler directives from real decorators on ``node``.

        node.decorators is replaced by the real decorators, plus any
        'staticmethod' decorator, which is both a directive and a decorator.
        Returns a dict mapping directive name to value.  Repeated dict/list
        directives are merged; others are overridden, with earlier
        decorators winning.  Raises PostParseError if a cdef
        function/class is left with arbitrary decorators.
        """
        if not node.decorators:
            return {}
        # Split the decorators into two lists -- real decorators and directives
        directives = []
        realdecs = []
        both = []
        # Decorators coming first take precedence.
        for dec in node.decorators[::-1]:
            new_directives = self.try_to_parse_directives(dec.decorator)
            if new_directives is not None:
                for directive in new_directives:
                    if self.check_directive_scope(node.pos, directive[0], scope_name):
                        name, value = directive
                        if self.directives.get(name, object()) != value:
                            directives.append(directive)
                        if directive[0] == 'staticmethod':
                            both.append(dec)
                    # Adapt scope type based on decorators that change it.
                    if directive[0] == 'cclass' and scope_name == 'class':
                        scope_name = 'cclass'
            else:
                realdecs.append(dec)
        if realdecs and (scope_name == 'cclass' or
                         isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))):
            raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
        node.decorators = realdecs[::-1] + both[::-1]
        # merge or override repeated directives
        optdict = {}
        for directive in directives:
            name, value = directive
            if name in optdict:
                old_value = optdict[name]
                # keywords and arg lists can be merged, everything
                # else overrides completely
                if isinstance(old_value, dict):
                    old_value.update(value)
                elif isinstance(old_value, list):
                    old_value.extend(value)
                else:
                    optdict[name] = value
            else:
                optdict[name] = value
        return optdict

    # Handle with-statements
    def visit_WithStatNode(self, node):
        directive_dict = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is not None:
                if node.target is not None:
                    self.context.nonfatal_error(
                        PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                else:
                    name, value = directive
                    if name in ('nogil', 'gil'):
                        # special case: in pure mode, "with nogil" spells "with cython.nogil"
                        node = Nodes.GILStatNode(node.pos, state = name, body = node.body)
                        return self.visit_Node(node)
                    if self.check_directive_scope(node.pos, name, 'with statement'):
                        directive_dict[name] = value
        if directive_dict:
            return self.visit_with_directives(node.body, directive_dict)
        return self.visit_Node(node)
1074

1075

Mark Florisson's avatar
Mark Florisson committed
1076 1077 1078 1079 1080 1081
class ParallelRangeTransform(CythonTransform, SkipDeclarations):
    """
    Transform cython.parallel stuff. The parallel_directives come from the
    module node, set there by InterpretCompilerDirectives.

        x = cython.parallel.threadavailable()   -> ParallelThreadAvailableNode
        with nogil, cython.parallel.parallel(): -> ParallelWithBlockNode
            print cython.parallel.threadid()    -> ParallelThreadIdNode
            for i in cython.parallel.prange(...):  -> ParallelRangeNode
                ...

    (The threadsavailable mapping is currently disabled in
    directive_to_node.)
    """

    # a list of names, maps 'cython.parallel.prange' in the code to
    # ['cython', 'parallel', 'prange']
    parallel_directive = None

    # Indicates whether a namenode in an expression is the cython module
    namenode_is_cython_module = False

    # Keep track of whether we are the context manager of a 'with' statement
    in_context_manager_section = False

    # One of 'prange' or 'with parallel'. This is used to disallow closely
    # nested 'with parallel:' blocks
    state = None

    directive_to_node = {
        u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
        # u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
        u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
        u"cython.parallel.prange": Nodes.ParallelRangeNode,
    }

    def node_is_parallel_directive(self, node):
        return node.name in self.parallel_directives or node.is_cython_module

    def get_directive_class_node(self, node):
        """
        Figure out which parallel directive was used and return the associated
        Node class.

        E.g. for a cython.parallel.prange() call we return ParallelRangeNode

        Reports an error for unknown directives (except for non-parallel
        attributes of the cython module) and always resets the collected
        directive state.
        """
        if self.namenode_is_cython_module:
            directive = '.'.join(self.parallel_directive)
        else:
            directive = self.parallel_directives[self.parallel_directive[0]]
            directive = '%s.%s' % (directive,
                                   '.'.join(self.parallel_directive[1:]))
            directive = directive.rstrip('.')

        cls = self.directive_to_node.get(directive)
        if cls is None and not (self.namenode_is_cython_module and
                                self.parallel_directive[0] != 'parallel'):
            error(node.pos, "Invalid directive: %s" % directive)

        self.namenode_is_cython_module = False
        self.parallel_directive = None

        return cls

    def visit_ModuleNode(self, node):
        """
        If any parallel directives were imported, copy them over and visit
        the AST
        """
        if node.parallel_directives:
            self.parallel_directives = node.parallel_directives
            return self.visit_Node(node)

        # No parallel directives were imported, so they can't be used :)
        return node

    def visit_NameNode(self, node):
        if self.node_is_parallel_directive(node):
            self.parallel_directive = [node.name]
            self.namenode_is_cython_module = node.is_cython_module
        return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        if self.parallel_directive:
            self.parallel_directive.append(node.attribute)
        return node

    def visit_CallNode(self, node):
        self.visit(node.function)
        if not self.parallel_directive:
            # Not a parallel directive itself, but its arguments may contain
            # one, e.g. printf("%d", threadid()) or print(threadid()).
            # Previously they were never visited and so never rewritten.
            # The function has already been visited above, so skip it here.
            self.visitchildren(node, [attr for attr in node.child_attrs
                                      if attr != 'function'])
            return node

        # We are a parallel directive, replace this node with the
        # corresponding ParallelSomethingSomething node

        if isinstance(node, ExprNodes.GeneralCallNode):
            args = node.positional_args.args
            kwargs = node.keyword_args
        else:
            args = node.args
            kwargs = {}

        parallel_directive_class = self.get_directive_class_node(node)
        if parallel_directive_class:
            # Note: in case of a parallel() the body is set by
            # visit_WithStatNode
            node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)

        return node

    def visit_WithStatNode(self, node):
        "Rewrite with cython.parallel.parallel() blocks"
        newnode = self.visit(node.manager)

        if isinstance(newnode, Nodes.ParallelWithBlockNode):
            if self.state == 'parallel with':
                error(node.manager.pos,
                      "Nested parallel with blocks are disallowed")

            # Restore (rather than clear) the enclosing state afterwards,
            # consistent with visit_ForInStatNode, so that later sibling
            # blocks are still checked against their real context.
            previous_state = self.state
            self.state = 'parallel with'
            body = self.visit(node.body)
            self.state = previous_state

            newnode.body = body
            return newnode
        elif self.parallel_directive:
            parallel_directive_class = self.get_directive_class_node(node)

            if not parallel_directive_class:
                # There was an error, stop here and now
                return None

            if parallel_directive_class is Nodes.ParallelWithBlockNode:
                error(node.pos, "The parallel directive must be called")
                return None

        node.body = self.visit(node.body)
        return node

    def visit_ForInStatNode(self, node):
        "Rewrite 'for i in cython.parallel.prange(...):'"
        self.visit(node.iterator)
        self.visit(node.target)

        in_prange = isinstance(node.iterator.sequence,
                               Nodes.ParallelRangeNode)
        previous_state = self.state

        if in_prange:
            # This will replace the entire ForInStatNode, so copy the
            # attributes
            parallel_range_node = node.iterator.sequence

            parallel_range_node.target = node.target
            parallel_range_node.body = node.body
            parallel_range_node.else_clause = node.else_clause

            node = parallel_range_node

            if not isinstance(node.target, ExprNodes.NameNode):
                error(node.target.pos,
                      "Can only iterate over an iteration variable")

            self.state = 'prange'

        self.visit(node.body)
        self.state = previous_state
        self.visit(node.else_clause)
        return node

    def visit(self, node):
        "Visit a node that may be None"
        if node is not None:
            return super(ParallelRangeTransform, self).visit(node)
        return None
Mark Florisson's avatar
Mark Florisson committed
1248 1249


1250
class WithTransform(CythonTransform, SkipDeclarations):
    """
    Expand 'with' / 'async with' statements into explicit calls.

    The manager's __enter__ (or awaited __aenter__) call is stored on the node
    as ``enter_call``.  The body is wrapped in try/except/finally: the except
    clause calls the exit callback with the current exception info and
    re-raises unless that call returns a true value; the finally clause calls
    the exit callback with (None, None, None) and ``test_if_run=True``.
    """

    def visit_WithStatNode(self, node):
        self.visitchildren(node, 'body')
        pos = node.pos
        is_async = node.is_async
        body, target, manager = node.body, node.target, node.manager

        def make_await_expr():
            # A fresh argument-less await wrapper per exit call ('async with'
            # only); plain 'with' statements get None.
            if is_async:
                return ExprNodes.AwaitExprNode(pos, arg=None)
            return None

        enter_method = EncodedString('__aenter__' if is_async else '__enter__')
        enter_call = ExprNodes.SimpleCallNode(
            pos,
            function=ExprNodes.AttributeNode(
                pos,
                obj=ExprNodes.CloneNode(manager),
                attribute=enter_method,
                is_special_lookup=True),
            args=[],
            is_temp=True)
        if is_async:
            enter_call = ExprNodes.AwaitExprNode(pos, arg=enter_call)
        node.enter_call = enter_call

        if target is not None:
            # Bind the 'as' target before running the original body.
            assign_target = Nodes.WithTargetAssignmentStatNode(
                pos, lhs=target, with_node=node)
            body = Nodes.StatListNode(pos, stats=[assign_target, body])

        excinfo_target = ExprNodes.TupleNode(
            pos, slow=True,
            args=[ExprNodes.ExcValueNode(pos) for _ in range(3)])

        # except: if not exit(*excinfo): raise
        exit_on_error = ExprNodes.WithExitCallNode(
            pos, with_stat=node,
            test_if_run=False,
            args=excinfo_target,
            await_expr=make_await_expr())
        reraise_unless_suppressed = Nodes.IfStatNode(
            pos,
            if_clauses=[
                Nodes.IfClauseNode(
                    pos,
                    condition=ExprNodes.NotNode(pos, operand=exit_on_error),
                    body=Nodes.ReraiseStatNode(pos)),
            ],
            else_clause=None)
        except_clause = Nodes.ExceptClauseNode(
            pos,
            body=reraise_unless_suppressed,
            pattern=None,
            target=None,
            excinfo_target=excinfo_target)

        # finally: exit(None, None, None), guarded by test_if_run.
        no_exc_args = ExprNodes.TupleNode(
            pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)])
        exit_on_finally = ExprNodes.WithExitCallNode(
            pos, with_stat=node,
            test_if_run=True,
            args=no_exc_args,
            await_expr=make_await_expr())

        node.body = Nodes.TryFinallyStatNode(
            pos,
            body=Nodes.TryExceptStatNode(
                pos, body=body,
                except_clauses=[except_clause],
                else_clause=None),
            finally_clause=Nodes.ExprStatNode(pos, expr=exit_on_finally),
            handle_error_case=False)
        return node

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node
1315

1316

1317
class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
    """
    Transforms method decorators in cdef classes into nested calls or properties.

    Python-style decorator properties are transformed into a PropertyNode
    with up to the three getter, setter and deleter DefNodes.
    The functional style isn't supported yet.
    """
    # Stack holding one {property name: PropertyNode} dict per cdef class
    # currently being visited (innermost class last).
    _properties = None

    # Maps the attribute in '@<prop>.<attr>' to the special method name used
    # inside the PropertyNode; returns None for any other attribute.
    _map_property_attribute = {
        'getter': '__get__',
        'setter': '__set__',
        'deleter': '__del__',
    }.get

    def visit_CClassDefNode(self, node):
        if self._properties is None:
            self._properties = []
        self._properties.append({})
        super(DecoratorTransform, self).visit_CClassDefNode(node)
        self._properties.pop()
        return node

    def visit_PropertyNode(self, node):
        # Low-level warning for other code until we can convert all our uses over.
        level = 2 if isinstance(node.pos[0], str) else 0
        warning(node.pos, "'property %s:' syntax is deprecated, use '@property'" % node.name, level)
        return node

    def visit_DefNode(self, node):
        """
        Handle decorators of a def method.

        Inside a cdef class, '@property' and '@<prop>.getter/.setter/.deleter'
        turn the method into (part of) a PropertyNode.  Remaining decorators
        are replaced by an explicit reassignment built by chain_decorators().
        Outside cdef classes, or without decorators, the visited node is
        returned unchanged.
        """
        scope_type = self.scope_type
        node = self.visit_FuncDefNode(node)
        if scope_type != 'cclass' or not node.decorators:
            return node

        # transform @property decorators
        properties = self._properties[-1]
        for decorator_node in node.decorators[::-1]:
            decorator = decorator_node.decorator
            if decorator.is_name and decorator.name == 'property':
                if len(node.decorators) > 1:
                    return self._reject_decorated_property(node, decorator_node)
                name = node.name
                node.name = EncodedString('__get__')
                node.decorators.remove(decorator_node)
                stat_list = [node]
                if name in properties:
                    # Redefinition: replace the existing property's contents.
                    prop = properties[name]
                    prop.pos = node.pos
                    prop.doc = node.doc
                    prop.body.stats = stat_list
                    return []
                prop = Nodes.PropertyNode(node.pos, name=name)
                prop.doc = node.doc
                prop.body = Nodes.StatListNode(node.pos, stats=stat_list)
                properties[name] = prop
                return [prop]
            elif (decorator.is_attribute and decorator.obj.is_name
                    and decorator.obj.name in properties):
                # Only '@<name>.<attr>' can refer to a property; other attribute
                # decorators (e.g. '@a.b.setter', '@f().attr') have no '.name'
                # on their object and must not be looked up here.
                handler_name = self._map_property_attribute(decorator.attribute)
                if handler_name:
                    if decorator.obj.name != node.name:
                        # CPython does not generate an error or warning, but not something useful either.
                        error(decorator_node.pos,
                              "Mismatching property names, expected '%s', got '%s'" % (
                                  decorator.obj.name, node.name))
                    elif len(node.decorators) > 1:
                        return self._reject_decorated_property(node, decorator_node)
                    else:
                        return self._add_to_property(properties, node, handler_name, decorator_node)

        # we clear node.decorators, so we need to set the
        # is_staticmethod/is_classmethod attributes now
        for decorator in node.decorators:
            func = decorator.decorator
            if func.is_name:
                node.is_classmethod |= func.name == 'classmethod'
                node.is_staticmethod |= func.name == 'staticmethod'

        # transform normal decorators
        decs = node.decorators
        node.decorators = None
        return self.chain_decorators(node, decs, node.name)

    @staticmethod
    def _reject_decorated_property(node, decorator_node):
        # restrict transformation to outermost decorator as wrapped properties will probably not work
        for deco in node.decorators:
            if deco != decorator_node:
                error(deco.pos, "Property methods with additional decorators are not supported")
        return node

    @staticmethod
    def _add_to_property(properties, node, name, decorator):
        """
        Rename node to the special method 'name' and put it into the existing
        property, replacing a previous method of the same name if present.
        Returns [] because the method now lives inside the PropertyNode.
        """
        prop = properties[node.name]
        node.name = name
        node.decorators.remove(decorator)
        stats = prop.body.stats
        for i, stat in enumerate(stats):
            if stat.name == name:
                stats[i] = node
                break
        else:
            stats.append(node)
        return []

    @staticmethod
    def chain_decorators(node, decorators, name):
        """
        Decorators are applied directly in DefNode and PyClassDefNode to avoid
        reassignments to the function/class name - except for cdef class methods.
        For those, the reassignment is required as methods are originally
        defined in the PyMethodDef struct.

        The IndirectionNode allows DefNode to override the decorator.
        """
        decorator_result = ExprNodes.NameNode(node.pos, name=name)
        for decorator in decorators[::-1]:
            decorator_result = ExprNodes.SimpleCallNode(
                decorator.pos,
                function=decorator.decorator,
                args=[decorator_result])

        name_node = ExprNodes.NameNode(node.pos, name=name)
        reassignment = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=name_node,
            rhs=decorator_result)

        reassignment = Nodes.IndirectionNode([reassignment])
        node.decorator_indirection = reassignment
        return [node, reassignment]
1449

1450

1451 1452 1453 1454 1455 1456 1457 1458 1459
class CnameDirectivesTransform(CythonTransform, SkipDeclarations):
    """
    Only part of the CythonUtilityCode pipeline. Must be run before
    DecoratorTransform in case this is a decorator for a cdef class.
    It filters out @cname('my_cname') decorators and rewrites them to
    CnameDecoratorNodes.
    """

    def handle_function(self, node):
        """
        Replace the first @cname("...") decorator of node (if any) by wrapping
        node in a CnameDecoratorNode, then visit the (possibly wrapped) node.
        Malformed @cname usage raises AssertionError.
        """
        if not getattr(node, 'decorators', None):
            return self.visit_Node(node)

        for index, decorator_node in enumerate(node.decorators):
            call = decorator_node.decorator
            is_cname_call = (
                isinstance(call, ExprNodes.CallNode)
                and call.function.is_name
                and call.function.name == 'cname')
            if not is_cname_call:
                continue

            args, kwargs = call.explicit_args_kwds()
            if kwargs:
                raise AssertionError(
                        "cname decorator does not take keyword arguments")
            if len(args) != 1:
                raise AssertionError(
                        "cname decorator takes exactly one argument")

            cname_arg = args[0]
            if not (cname_arg.is_literal and
                    cname_arg.type == Builtin.str_type):
                raise AssertionError(
                        "argument to cname decorator must be a string literal")

            cname = cname_arg.compile_time_value(None)
            # Only the first @cname decorator is consumed.
            del node.decorators[index]
            node = Nodes.CnameDecoratorNode(pos=node.pos, node=node,
                                            cname=cname)
            break

        return self.visit_Node(node)

    visit_FuncDefNode = handle_function
    visit_CClassDefNode = handle_function
    visit_CEnumDefNode = handle_function
    visit_CStructOrUnionDefNode = handle_function
1496 1497


1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532
class ForwardDeclareTypes(CythonTransform):
    """
    Declare enum, struct/union and extension class types in the module scope
    (presumably so that later phases can refer to them before their
    definitions are fully analysed).
    """

    def visit_CompilerDirectivesNode(self, node):
        # Temporarily expose the block's directives on the module scope.
        scope = self.module_scope
        saved_directives = scope.directives
        scope.directives = node.directives
        self.visitchildren(node)
        scope.directives = saved_directives
        return node

    def visit_ModuleNode(self, node):
        scope = node.scope
        self.module_scope = scope
        scope.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_CDefExternNode(self, node):
        # Mark everything declared inside 'cdef extern' as coming from a C include.
        scope = self.module_scope
        saved_cinclude = scope.in_cinclude
        scope.in_cinclude = 1
        self.visitchildren(node)
        scope.in_cinclude = saved_cinclude
        return node

    def visit_CEnumDefNode(self, node):
        node.declare(self.module_scope)
        return node

    def visit_CStructOrUnionDefNode(self, node):
        scope = self.module_scope
        if node.name not in scope.entries:
            node.declare(scope)
        return node

    def visit_CClassDefNode(self, node):
        scope = self.module_scope
        if node.class_name not in scope.entries:
            node.declare(scope)

        # Expand fused methods of .pxd declared types to construct the final vtable order.
        class_type = scope.entries[node.class_name].type
        if (class_type is None or not class_type.is_extension_type
                or class_type.is_builtin_type or not class_type.scope):
            return node
        for entry in class_type.scope.cfunc_entries:
            if entry.type and entry.type.is_fused:
                entry.type.get_all_specialized_function_types()
        return node

1542

1543
class AnalyseDeclarationsTransform(EnvTransform):
1544

1545 1546 1547 1548 1549 1550
    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
1551
    """, level='c_class', pipeline=[NormalizeTree(None)])
1552 1553 1554 1555 1556 1557 1558 1559
    basic_pyobject_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
1560
    """, level='c_class', pipeline=[NormalizeTree(None)])
1561 1562 1563 1564
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
1565
    """, level='c_class', pipeline=[NormalizeTree(None)])
1566

1567 1568 1569 1570 1571 1572 1573 1574
    struct_or_union_wrapper = TreeFragment(u"""
cdef class NAME:
    cdef TYPE value
    def __init__(self, MEMBER=None):
        cdef int count
        count = 0
        INIT_ASSIGNMENTS
        if IS_UNION and count > 1:
1575
            raise ValueError("At most one union member should be specified.")
1576 1577 1578 1579
    def __str__(self):
        return STR_FORMAT % MEMBER_TUPLE
    def __repr__(self):
        return REPR_FORMAT % MEMBER_TUPLE
1580
    """, pipeline=[NormalizeTree(None)])
1581 1582 1583 1584 1585

    init_assignment = TreeFragment(u"""
if VALUE is not None:
    ATTR = VALUE
    count += 1
1586
    """, pipeline=[NormalizeTree(None)])
1587

1588
    fused_function = None
1589
    in_lambda = 0
1590

1591
    def __call__(self, root):
        # One set of seen names per function/module being visited; needed to
        # determine if a cdef var is declared after it's used.
        self.seen_vars_stack = []
        # Fused functions for which a nesting error was already reported.
        self.fused_error_funcs = set()
        parent = super(AnalyseDeclarationsTransform, self)
        # Bound once so visit_FuncDefNode can delegate to the base implementation.
        self._super_visit_FuncDefNode = parent.visit_FuncDefNode
        return parent.__call__(root)
1598

1599
    def visit_NameNode(self, node):
        # Record the name in the innermost function/module name set.
        names_seen = self.seen_vars_stack[-1]
        names_seen.add(node.name)
        return node

1603
    def visit_ModuleNode(self, node):
        # Pickling support requires injecting module-level nodes: they are
        # collected while visiting and appended to the module body at the end.
        self.extra_module_declarations = []

        self.seen_vars_stack.append(set())
        node.analyse_declarations(self.current_env())
        self.visitchildren(node)
        self.seen_vars_stack.pop()

        node.body.stats.extend(self.extra_module_declarations)
        return node
Stefan Behnel's avatar
Stefan Behnel committed
1612 1613

    def visit_LambdaNode(self, node):
        # Track lambda nesting; _create_fused_function rejects fused lambdas.
        self.in_lambda += 1
        node.analyse_declarations(self.current_env())
        self.visitchildren(node)
        self.in_lambda -= 1
        return node

1620 1621
    def visit_CClassDefNode(self, node):
        """
        Analyse a cdef class, then (for implemented classes with a body) add
        generated properties for var entries flagged ``needs_property`` and,
        unless the class is extern or defines __reduce__/__reduce_ex__,
        inject pickling support.
        """
        node = self.visit_ClassDefNode(node)
        scope = node.scope
        if not (scope and scope.implemented and node.body):
            return node

        generated_props = []
        for entry in scope.var_entries:
            if not entry.needs_property:
                continue
            prop_node = self.create_Property(entry)
            prop_node.analyse_declarations(scope)
            self.visit(prop_node)
            generated_props.append(prop_node)
        if generated_props:
            node.body.stats += generated_props

        if (node.visibility != 'extern'
                and not scope.lookup('__reduce__')
                and not scope.lookup('__reduce_ex__')):
            self._inject_pickle_methods(node)
        return node
1637

Robert Bradshaw's avatar
Robert Bradshaw committed
1638
    def _inject_pickle_methods(self, node):
        """
        Add ``__reduce_cython__`` / ``__setstate_cython__`` methods to a cdef class.

        Does nothing if the 'auto_pickle' directive is False, or if a
        ``__reduce__`` / ``__reduce_ex__`` is found via the class or any base.
        If the class cannot be pickled automatically (a ``__cinit__`` exists,
        some member cannot be converted to and from a Python object, or there
        are struct/union members without ``@auto_pickle(True)``), methods that
        raise TypeError are injected instead; an error is also reported when
        pickling was explicitly forced.  Otherwise a module-level unpickle
        function is queued in ``self.extra_module_declarations`` and the class
        gets methods that save/restore all members, guarded by a checksum of
        the sorted member names.
        """
        env = self.current_env()
        if node.scope.directives['auto_pickle'] is False:   # None means attempt it.
            # Old behavior of not doing anything.
            return
        auto_pickle_forced = node.scope.directives['auto_pickle'] is True

        # Collect the members of the class and all of its base classes.
        all_members = []
        cls = node.entry.type
        cinit = None
        inherited_reduce = None
        while cls is not None:
            all_members.extend(e for e in cls.scope.var_entries if e.name not in ('__weakref__', '__dict__'))
            cinit = cinit or cls.scope.lookup('__cinit__')
            inherited_reduce = inherited_reduce or cls.scope.lookup('__reduce__') or cls.scope.lookup('__reduce_ex__')
            cls = cls.base_type
        all_members.sort(key=lambda e: e.name)

        if inherited_reduce:
            # This is not failsafe, as we may not know whether a cimported class defines a __reduce__.
            # This is why we define __reduce_cython__ and only replace __reduce__
            # (via ExtensionTypes.SetupReduce utility code) at runtime on class creation.
            return

        non_py = [
            e for e in all_members
            if not e.type.is_pyobject and (not e.type.can_coerce_to_pyobject(env)
                                           or not e.type.can_coerce_from_pyobject(env))
        ]

        structs = [e for e in all_members if e.type.is_struct_or_union]

        if cinit or non_py or (structs and not auto_pickle_forced):
            if cinit:
                # TODO(robertwb): We could allow this if __cinit__ has no required arguments.
                msg = 'no default __reduce__ due to non-trivial __cinit__'
            elif non_py:
                msg = "%s cannot be converted to a Python object for pickling" % ','.join("self.%s" % e.name for e in non_py)
            else:
                # Extern structs may be only partially defined.
                # TODO(robertwb): Limit the restriction to extern
                # (and recursively extern-containing) structs.
                msg = ("Pickling of struct members such as %s must be explicitly requested "
                       "with @auto_pickle(True)" % ','.join("self.%s" % e.name for e in structs))

            if auto_pickle_forced:
                error(node.pos, msg)

            pickle_func = TreeFragment(u"""
                def __reduce_cython__(self):
                    raise TypeError("%(msg)s")
                def __setstate_cython__(self, __pyx_state):
                    raise TypeError("%(msg)s")
                """ % {'msg': msg},
                level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
            pickle_func.analyse_declarations(node.scope)
            self.visit(pickle_func)
            node.body.stats.append(pickle_func)

        else:
            for e in all_members:
                if not e.type.is_pyobject:
                    e.type.create_to_py_utility_code(env)
                    e.type.create_from_py_utility_code(env)
            all_members_names = sorted([e.name for e in all_members])
            # Short hash of the member layout; the unpickle function rejects state
            # produced for a different set of members.
            checksum = '0x%s' % hashlib.sha1(' '.join(all_members_names).encode('utf-8')).hexdigest()[:7]
            unpickle_func_name = '__pyx_unpickle_%s' % node.class_name

            # TODO(robertwb): Move the state into the third argument
            # so it can be pickled *after* self is memoized.
            unpickle_func = TreeFragment(u"""
                def %(unpickle_func_name)s(__pyx_type, long __pyx_checksum, __pyx_state):
                    cdef object __pyx_PickleError
                    cdef object __pyx_result
                    if __pyx_checksum != %(checksum)s:
                        from pickle import PickleError as __pyx_PickleError
                        raise __pyx_PickleError("Incompatible checksums (%%s vs %(checksum)s = (%(members)s))" %% __pyx_checksum)
                    __pyx_result = %(class_name)s.__new__(__pyx_type)
                    if __pyx_state is not None:
                        %(unpickle_func_name)s__set_state(<%(class_name)s> __pyx_result, __pyx_state)
                    return __pyx_result

                cdef %(unpickle_func_name)s__set_state(%(class_name)s __pyx_result, tuple __pyx_state):
                    %(assignments)s
                    if len(__pyx_state) > %(num_members)d and hasattr(__pyx_result, '__dict__'):
                        __pyx_result.__dict__.update(__pyx_state[%(num_members)d])
                """ % {
                    'unpickle_func_name': unpickle_func_name,
                    'checksum': checksum,
                    'members': ', '.join(all_members_names),
                    'class_name': node.class_name,
                    'assignments': '; '.join(
                        '__pyx_result.%s = __pyx_state[%s]' % (v, ix)
                        for ix, v in enumerate(all_members_names)),
                    'num_members': len(all_members_names),
                }, level='module', pipeline=[NormalizeTree(None)]).substitute({})
            unpickle_func.analyse_declarations(node.entry.scope)
            self.visit(unpickle_func)
            self.extra_module_declarations.append(unpickle_func)

            pickle_func = TreeFragment(u"""
                def __reduce_cython__(self):
                    cdef tuple state
                    cdef object _dict
                    cdef bint use_setstate
                    state = (%(members)s)
                    _dict = getattr(self, '__dict__', None)
                    if _dict is not None:
                        state += (_dict,)
                        use_setstate = True
                    else:
                        use_setstate = %(any_notnone_members)s
                    if use_setstate:
                        return %(unpickle_func_name)s, (type(self), %(checksum)s, None), state
                    else:
                        return %(unpickle_func_name)s, (type(self), %(checksum)s, state)

                def __setstate_cython__(self, __pyx_state):
                    %(unpickle_func_name)s__set_state(self, __pyx_state)
                """ % {
                    'unpickle_func_name': unpickle_func_name,
                    'checksum': checksum,
                    # A trailing comma keeps a single member a 1-tuple.
                    'members': ', '.join('self.%s' % v for v in all_members_names) + (',' if len(all_members_names) == 1 else ''),
                    # Even better, we could check PyType_IS_GC.
                    'any_notnone_members' : ' or '.join(['self.%s is not None' % e.name for e in all_members if e.type.is_pyobject] or ['False']),
                },
                level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
            pickle_func.analyse_declarations(node.scope)
            self.visit(pickle_func)
            node.body.stats.append(pickle_func)
Robert Bradshaw's avatar
Robert Bradshaw committed
1768

1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787
    def _handle_fused_def_decorators(self, old_decorators, env, node):
        """
        Create function calls to the decorators and reassignments to
        the function.

        Plain 'staticmethod' / 'classmethod' decorators (not shadowed in env)
        are dropped, as the fused function object handles them directly.
        Returns node itself, or [node, reassignments] if decorators remain.
        """
        remaining = [
            decorator for decorator in old_decorators
            if not (decorator.decorator.is_name
                    and decorator.decorator.name in ('staticmethod', 'classmethod')
                    and not env.lookup_here(decorator.decorator.name))
        ]
        if not remaining:
            return node

        transform = DecoratorTransform(self.context)
        def_node = node.node
        _, reassignments = transform.chain_decorators(
            def_node, remaining, def_node.name)
        reassignments.analyse_declarations(env)
        return [node, reassignments]

    def _handle_def(self, decorators, env, node):
        "Handle def or cpdef fused functions"
        # Create PyCFunction nodes for each specialization
        py_func = node.py_func
        node.stats.insert(0, py_func)
        node.py_func = self.visit(py_func)
        node.update_fused_defnode_entry(env)

        cfunc_node = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, binding=True)
        fused_proxy = ExprNodes.ProxyNode(cfunc_node.coerce_to_temp(env))
        node.resulting_fused_function = fused_proxy
        # Create assignment node for our def function
        node.fused_func_assignment = self._create_assignment(
            node.py_func, ExprNodes.CloneNode(fused_proxy), env)

        if not decorators:
            return node
        return self._handle_fused_def_decorators(decorators, env, node)

    def _create_fused_function(self, env, node):
        "Create a fused function for a DefNode with fused arguments"
        from . import FusedNode

        if self.fused_function or self.in_lambda:
            return self._neutralize_nested_fused_function(node)

        decorators = getattr(node, 'decorators', None)
        fused_node = FusedNode.FusedCFuncDefNode(node, env)
        self.fused_function = fused_node
        self.visitchildren(fused_node)
        self.fused_function = None
        if not fused_node.py_func:
            return fused_node
        return self._handle_def(decorators, env, fused_node)

    def _neutralize_nested_fused_function(self, node):
        """
        Reject a fused function nested in another fused function or a lambda.

        The error is only reported if self.fused_function has not been
        reported before; the node's body is then replaced by 'pass' and each
        fused argument type by the first of its fused types.
        """
        if self.fused_function not in self.fused_error_funcs:
            if self.in_lambda:
                error(node.pos, "Fused lambdas not allowed")
            else:
                error(node.pos, "Cannot nest fused functions")
        self.fused_error_funcs.add(self.fused_function)

        node.body = Nodes.PassStatNode(node.pos)
        for arg in node.args:
            if arg.type.is_fused:
                arg.type = arg.type.get_fused_types()[0]
        return node

    def _handle_nogil_cleanup(self, lenv, node):
        "Handle cleanup for 'with gil' blocks in nogil functions."
        if not (lenv.nogil and lenv.has_with_gil_block):
            return
        # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
        # the entire function body in try/finally.
        # The corresponding release will be taken care of by
        # Nodes.FuncDefNode.generate_function_definitions()
        original_body = node.body
        body_pos = original_body.pos
        node.body = Nodes.NogilTryFinallyStatNode(
            body_pos,
            body=original_body,
            finally_clause=Nodes.EnsureGILNode(body_pos),
            finally_except_clause=Nodes.EnsureGILNode(body_pos))
1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865

    def _handle_fused(self, node):
        """
        Return whether node should be turned into a fused function.

        Fused generators are rejected: the flag is cleared, an error is
        reported and the generator body is replaced by an empty one.
        """
        fused_generator = node.is_generator and node.has_fused_arguments
        if fused_generator:
            node.has_fused_arguments = False
            error(node.pos, "Fused generators not supported")
            empty_body = Nodes.StatListNode(node.pos,
                                            stats=[],
                                            body=Nodes.PassStatNode(node.pos))
            node.gbody = empty_body
        return node.has_fused_arguments

1866
    def visit_FuncDefNode(self, node):
        """
        Declare a function's arguments and its @cython.locals() variables,
        then analyse it.

        A function with fused arguments is replaced (after its arguments have
        been declared) by the node returned from _create_fused_function(),
        whose children are analysed in turn.  Any other function has its body
        analysed directly.
        """
        env = self.current_env()

        self.seen_vars_stack.append(set())
        local_scope = node.local_scope
        node.declare_arguments(local_scope)

        # @cython.locals(...)
        for var_name, type_node in node.directive_locals.items():
            if local_scope.lookup_here(var_name):
                continue   # don't redeclare args
            declared_type = type_node.analyse_as_type(local_scope)
            if declared_type:
                local_scope.declare_var(var_name, declared_type, type_node.pos)
            else:
                error(type_node.pos, "Not a type")

        if self._handle_fused(node):
            node = self._create_fused_function(env, node)
        else:
            node.body.analyse_declarations(local_scope)
            self._handle_nogil_cleanup(local_scope, node)
            self._super_visit_FuncDefNode(node)

        self.seen_vars_stack.pop()
        return node
1901

1902 1903
    def visit_DefNode(self, node):
        """
        Analyse a def function and, if it needs one, append the assignment
        that binds the function object to its name.
        """
        node = self.visit_FuncDefNode(node)
        env = self.current_env()
        if not isinstance(node, Nodes.DefNode):
            return node
        if node.is_wrapper:
            env = env.parent_scope
        if node.fused_py_func or node.is_generator_body:
            return node
        if not node.needs_assignment_synthesis(env):
            return node
        return [node, self._synthesize_assignment(node, env)]

1913 1914 1915
    def visit_GeneratorBodyDefNode(self, node):
        # Only the generic function analysis; no name assignment is synthesised.
        analysed = self.visit_FuncDefNode(node)
        return analysed

1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927
    def _synthesize_assignment(self, node, env):
        """
        Create the assignment that binds the function object of the def
        *node* to its name in *env* (it is placed right after the def node).

        Inside a closure scope an InnerFunctionNode is used; otherwise a
        PyCFunctionNode built according to the 'binding' directive.  In a
        Python class body the function object is always made binding.
        """
        # skip over class scopes to find the scope that determines closure-ness
        outer = env
        while outer.is_py_class_scope or outer.is_c_class_scope:
            outer = outer.outer_scope

        if not outer.is_closure_scope:
            rhs = ExprNodes.PyCFunctionNode.from_defnode(
                node, self.current_directives.get('binding'))
            node.code_object = rhs.code_object
            if node.is_generator:
                node.gbody.code_object = node.code_object
        else:
            rhs = ExprNodes.InnerFunctionNode(
                node.pos, def_node=node,
                pymethdef_cname=node.entry.pymethdef_cname,
                code_object=ExprNodes.CodeObjectNode(node))
            node.py_cfunc_node = rhs

        if env.is_py_class_scope:
            rhs.binding = True
        node.is_cyfunction = rhs.binding
        return self._create_assignment(node, rhs, env)

    def _create_assignment(self, def_node, rhs, env):
        """
        Build ``name = decorators(rhs)`` for *def_node*, declare it in *env*
        and return it.

        Decorators are wrapped around *rhs* starting with the last-listed one,
        and are then cleared from *def_node*.
        """
        decorators = def_node.decorators
        if decorators:
            for deco in reversed(decorators):
                rhs = ExprNodes.SimpleCallNode(
                    deco.pos, function=deco.decorator, args=[rhs])
            def_node.decorators = None

        target = ExprNodes.NameNode(def_node.pos, name=def_node.name)
        assignment = Nodes.SingleAssignmentNode(def_node.pos, lhs=target, rhs=rhs)
        assignment.analyse_declarations(env)
        return assignment

    def visit_ScopedExprNode(self, node):
        """
        Run declaration analysis for a scoped expression node.

        If the node has its own scope, its children are visited inside
        ``node.expr_scope``, starting from a copy of the enclosing set of seen
        variables.  Otherwise everything happens in the current scope.
        """
        env = self.current_env()
        node.analyse_declarations(env)
        if not node.has_local_scope:
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
            return node

        inner_scope = node.expr_scope
        self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
        self.enter_scope(node, inner_scope)
        node.analyse_scoped_declarations(inner_scope)
        self.visitchildren(node)
        self.exit_scope()
        self.seen_vars_stack.pop()
        return node

    def visit_TempResultFromStatNode(self, node):
        """Visit the children first, then declare the node in the current scope."""
        self.visitchildren(node)
        env = self.current_env()
        node.analyse_declarations(env)
        return node

    def visit_CppClassNode(self, node):
        """Drop extern C++ class declarations; treat all others as classes."""
        if node.visibility != 'extern':
            return self.visit_ClassDefNode(node)
        return None

    def visit_CStructOrUnionDefNode(self, node):
        """
        Drop struct/union definitions from the tree.

        The code after the unconditional early return is a currently disabled
        generator for a Python wrapper class around the struct/union.  It
        needs the struct type information (so it cannot run before this
        phase) but creates new objects that still have to be declared (so it
        cannot run later).  The original node is not returned, because it is
        never used after this phase.
        """
        if True:  # private (default)
            return None

        struct_type = node.entry.type
        var_entries = struct_type.scope.var_entries
        self_value = ExprNodes.AttributeNode(
            pos=node.pos,
            obj=ExprNodes.NameNode(pos=node.pos, name=u"self"),
            attribute=EncodedString(u"value"))
        attributes = [
            ExprNodes.AttributeNode(pos=entry.pos, obj=self_value, attribute=entry.name)
            for entry in var_entries]
        members = list(zip(var_entries, attributes))

        # __init__ assignments
        # TODO: branch on visibility
        init_assignments = [
            self.init_assignment.substitute({
                u"VALUE": ExprNodes.NameNode(entry.pos, name=entry.name),
                u"ATTR": attr,
            }, pos=entry.pos)
            for entry, attr in members]

        # create the class
        str_format = u"%s(%s)" % (struct_type.name, u", ".join([u"%s"] * len(attributes)))
        repr_format = str_format.replace("%s", "%r")
        wrapper_class = self.struct_or_union_wrapper.substitute({
            u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats=init_assignments),
            u"IS_UNION": ExprNodes.BoolNode(node.pos, value=not struct_type.is_struct),
            u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
            u"STR_FORMAT": ExprNodes.StringNode(node.pos, value=EncodedString(str_format)),
            u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value=EncodedString(repr_format)),
        }, pos=node.pos).stats[0]
        wrapper_class.class_name = node.name
        wrapper_class.shadow = True
        class_body = wrapper_class.body.stats

        # fix value type
        assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
        class_body[0].base_type.name = node.name

        # fix __init__ arguments: replace the template argument by one per member
        init_method = class_body[1]
        assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
        arg_template = init_method.args.pop(1)
        if not struct_type.is_struct:
            arg_template.kw_only = True
        for entry in var_entries:
            arg = copy.deepcopy(arg_template)
            arg.declarator.name = entry.name
            init_method.args.append(arg)

        # setters/getters
        # TODO: branch on visibility
        for entry, attr in members:
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
            prop = template.substitute({u"ATTR": attr}, pos=entry.pos).stats[0]
            prop.name = entry.name
            wrapper_class.body.stats.append(prop)

        wrapper_class.analyse_declarations(self.current_env())
        return self.visit_CClassDefNode(wrapper_class)

    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        """Keep the declarator, but descend so every CNameDeclaratorNode is visited."""
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        """Keep typedefs unchanged without visiting their children."""
        return node

    def visit_CBaseTypeNode(self, node):
        """Base type nodes are dropped at this point."""
        return None

    def visit_CEnumDefNode(self, node):
        """Keep public enum definitions; drop all others."""
        return node if node.visibility == 'public' else None

    def visit_CNameDeclaratorNode(self, node):
        """
        Warn about a cdef declaration of a name that is already in the
        innermost ``seen_vars_stack`` set.

        No warning is issued if the looked-up entry is extern or belongs to a
        C class scope.
        """
        name = node.name
        if name in self.seen_vars_stack[-1]:
            entry = self.current_env().lookup(name)
            suppressed = entry is not None and (
                entry.visibility == 'extern' or entry.scope.is_c_class_scope)
            if not suppressed:
                warning(node.pos, "cdef variable '%s' declared after it is used" % name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        """Visit the children (so their CNameDeclaratorNodes are seen), then drop the node."""
        self.visitchildren(node)
        return None

    def visit_CnameDecoratorNode(self, node):
        """
        Visit the decorated node and re-attach the result to the decorator.

        If visiting yields nothing, the decorator node is dropped.  If it
        yields a list (the visited child followed by synthesized statements,
        as returned by visit_DefNode), the first element becomes the wrapped
        node.  The remaining statements are returned after the decorator node.
        """
        child_node = self.visit(node.node)
        if not child_node:
            return None
        if type(child_node) is list:  # Assignment synthesized
            # Store the visited child in node.node, exactly as in the
            # single-node case below (it was assigned to a misspelled
            # 'child_node' attribute before).
            node.node = child_node[0]
            return [node] + child_node[1:]
        node.node = child_node
        return node

    def create_Property(self, entry):
        """
        Build a property node for a 'public' or 'readonly' attribute *entry*.

        Public entries use basic_pyobject_property (Python object types) or
        basic_property.  Readonly entries use basic_property_ro.  The
        template's ``ATTR`` placeholder becomes ``self.<entry name>``.  The
        property takes the entry's name and doc.
        """
        visibility = entry.visibility
        if visibility == 'public':
            template = (self.basic_pyobject_property if entry.type.is_pyobject
                        else self.basic_property)
        elif visibility == 'readonly':
            template = self.basic_property_ro
        prop = template.substitute({
            u"ATTR": ExprNodes.AttributeNode(
                pos=entry.pos,
                obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
                attribute=entry.name),
        }, pos=entry.pos).stats[0]
        prop.name = entry.name
        prop.doc = entry.doc
        return prop


class CalculateQualifiedNamesTransform(EnvTransform):
    """
    Compute ``__qualname__`` and the global module name for class,
    class namespace, function and function object nodes.

    ``self.qualified_name`` holds the name components at the current
    nesting level.  While a function body is visited, the list ends with
    ``'<locals>'``.
    """
    def visit_ModuleNode(self, node):
        self.module_name = self.global_scope().qualified_name
        self.qualified_name = []
        parent = super(CalculateQualifiedNamesTransform, self)
        self._super_visit_FuncDefNode = parent.visit_FuncDefNode
        self._super_visit_ClassDefNode = parent.visit_ClassDefNode
        self.visitchildren(node)
        return node

    def _set_qualname(self, node, name=None):
        # Join the current components (plus *name*, if given) into node.qualname.
        parts = self.qualified_name + [name] if name else self.qualified_name
        node.qualname = EncodedString('.'.join(parts))
        node.module_name = self.module_name

    def _append_entry(self, entry):
        # Python globals (other than class attributes) restart the name.
        if entry.is_pyglobal and not entry.is_pyclass_attr:
            self.qualified_name = [entry.name]
        else:
            self.qualified_name.append(entry.name)

    def visit_ClassNode(self, node):
        self._set_qualname(node, node.name)
        self.visitchildren(node)
        return node

    def visit_PyClassNamespaceNode(self, node):
        # the class name was already added by the parent node
        self._set_qualname(node)
        self.visitchildren(node)
        return node

    def visit_PyCFunctionNode(self, node):
        saved = self.qualified_name[:]
        names = self.qualified_name
        if node.def_node.is_wrapper and names and names[-1] == '<locals>':
            names.pop()
            self._set_qualname(node)
        else:
            self._set_qualname(node, node.def_node.name)
        self.visitchildren(node)
        self.qualified_name = saved
        return node

    def visit_DefNode(self, node):
        if not (node.is_wrapper and self.qualified_name):
            self._set_qualname(node, node.name)
            self.visit_FuncDefNode(node)
            return node
        # wrapper: qualify it at the level of the enclosing function
        assert self.qualified_name[-1] == '<locals>', self.qualified_name
        saved = self.qualified_name[:]
        self.qualified_name.pop()
        self._set_qualname(node)
        self._super_visit_FuncDefNode(node)
        self.qualified_name = saved
        return node

    def visit_FuncDefNode(self, node):
        saved = self.qualified_name[:]
        if getattr(node, 'name', None) == '<lambda>':
            self.qualified_name.append('<lambda>')
        else:
            self._append_entry(node.entry)
        # _append_entry() may have rebound the list, so look it up again
        self.qualified_name.append('<locals>')
        self._super_visit_FuncDefNode(node)
        self.qualified_name = saved
        return node

    def visit_ClassDefNode(self, node):
        saved = self.qualified_name[:]
        entry = getattr(node, 'entry', None)  # PyClass
        if not entry:
            entry = self.current_env().lookup_here(node.name)  # CClass
        self._append_entry(entry)
        self._super_visit_ClassDefNode(node)
        self.qualified_name = saved
        return node


class AnalyseExpressionsTransform(CythonTransform):
    """
    Infer types and analyse expressions in module bodies, function bodies
    and scoped expressions.
    """

    def _analyse_body(self, node, scope):
        # Infer the types of *scope*, analyse node.body in it, then recurse.
        scope.infer_types()
        node.body = node.body.analyse_expressions(scope)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        return self._analyse_body(node, node.scope)

    def visit_FuncDefNode(self, node):
        return self._analyse_body(node, node.local_scope)

    def visit_ScopedExprNode(self, node):
        if node.has_local_scope:
            inner_scope = node.expr_scope
            inner_scope.infer_types()
            node = node.analyse_scoped_expressions(inner_scope)
        self.visitchildren(node)
        return node

    def visit_IndexNode(self, node):
        """
        Replace an index node that specialises a fused cdef function by its
        base, i.e. the Attribute- or NameNode that refers to the function.
        The specialisation properties need to be present on that attribute
        or name node; this method does not copy them itself.

        The indexing may be a Python indexing operation on a fused function
        or (usually) a Cython indexing operation, so the types are analysed
        again first.
        """
        self.visit_Node(node)
        if node.is_fused_index and not node.type.is_error:
            return node.base
        return node


class ReplacePropertyNode(CythonTransform):
    """
    Turn ``@property``-decorated cdef functions into C getters.

    Only a lone ``@property`` decorator is supported.  Every other
    decorator on such a function is reported as an error.
    """
    def visit_CFuncDefNode(self, node):
        if not node.decorators:
            return node
        decorator = self.find_first_decorator(node, 'property')
        if decorator:
            # transform class functions into c-getters
            if len(node.decorators) > 1:
                # Report the extra decorators.  Previously this called a
                # method missing from this class, with an undefined name,
                # and so crashed the compiler.
                self._reject_decorated_property(node, decorator)
            node.entry.is_cgetter = True
            # Add a func_cname to be output instead of the attribute.
            # NOTE(review): assumes the body's first statement is a
            # return of a simple call, e.g. ``return f(self)``.
            node.entry.func_cname = node.body.stats[0].value.function.name
            node.decorators.remove(decorator)
        return node

    @staticmethod
    def _reject_decorated_property(node, decorator_node):
        # Report an error for every decorator other than *decorator_node*.
        for deco in node.decorators:
            if deco is not decorator_node:
                error(deco.pos, "Property methods with additional decorators are not supported")
        return node

    def find_first_decorator(self, node, name):
        # Return the last-listed decorator node that is the plain name
        # *name*, or None if there is none.
        for decorator_node in node.decorators[::-1]:
            decorator = decorator_node.decorator
            if decorator.is_name and decorator.name == name:
                return decorator_node
        return None


class FindInvalidUseOfFusedTypes(CythonTransform):
    """
    Report fused types that remain where they cannot be specialised.

    This covers the return types of functions without fused arguments and
    the types of expressions.
    """

    def visit_FuncDefNode(self, node):
        # Errors related to use in functions with fused args will already
        # have been detected
        if node.has_fused_arguments:
            return node
        if node.is_generator_body or not node.return_type.is_fused:
            self.visitchildren(node)
        else:
            error(node.pos, "Return type is not specified as argument type")
        return node

    def visit_ExprNode(self, node):
        if not (node.type and node.type.is_fused):
            self.visitchildren(node)
        else:
            error(node.pos, "Invalid use of fused types, type cannot be specialized")
        return node


class ExpandInplaceOperators(EnvTransform):
    """
    Rewrite ``lhs op= rhs`` as ``lhs = lhs op rhs``.

    Sub-expressions of the target are bound to LetRefNode temporaries.
    C++ class targets and buffer indexing targets are left unchanged.
    """

    def visit_InPlaceAssignmentNode(self, node):
        lhs = node.lhs
        rhs = node.rhs
        # C++ classes: no getting around this exact operator here.
        # BufferIndexNode: there is code to handle this case in InPlaceAssignmentNode.
        if lhs.type.is_cpp_class or isinstance(lhs, ExprNodes.BufferIndexNode):
            return node

        env = self.current_env()

        def side_effect_free_reference(expr, setting=False):
            # Return (reference, temps): a rebuilt reference to *expr* whose
            # sub-expressions are LetRefNodes, plus the list of those temps.
            if expr.is_name:
                return expr, []
            if expr.type.is_pyobject and not setting:
                ref = LetRefNode(expr)
                return ref, [ref]
            if expr.is_subscript:
                base, temps = side_effect_free_reference(expr.base)
                index = LetRefNode(expr.index)
                return ExprNodes.IndexNode(expr.pos, base=base, index=index), temps + [index]
            if expr.is_attribute:
                obj, temps = side_effect_free_reference(expr.obj)
                return ExprNodes.AttributeNode(expr.pos, obj=obj, attribute=expr.attribute), temps
            if isinstance(expr, ExprNodes.BufferIndexNode):
                raise ValueError("Don't allow things like attributes of buffer indexing operations")
            ref = LetRefNode(expr)
            return ref, [ref]

        try:
            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
        except ValueError:
            return node
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(
            node.pos, operator=node.operator, operand1=dup, operand2=rhs, inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        result = Nodes.SingleAssignmentNode(
            node.pos, lhs=lhs, rhs=binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects: the first temp ends up outermost.
        for temp in reversed(let_ref_nodes):
            result = LetNode(temp, result)
        return result

    def visit_ExprNode(self, node):
        # In-place assignments can't happen within an expression.
        return node


class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
    """
    Adjust function and class definitions according to these decorator
    directives:

    - @cython.cfunc
    - @cython.cclass
    - @cython.ccall
    - @cython.inline
    - @cython.nogil
    """

    def visit_ModuleNode(self, node):
        self.directives = node.directives
        self.in_py_class = False
        self.visitchildren(node)
        return node

    def visit_CompilerDirectivesNode(self, node):
        saved_directives = self.directives
        self.directives = node.directives
        self.visitchildren(node)
        self.directives = saved_directives
        return node

    def visit_DefNode(self, node):
        directives = self.directives
        modifiers = ['inline'] if 'inline' in directives else []
        nogil = directives.get('nogil')
        except_val = directives.get('exceptval')
        return_type_node = directives.get('returns')
        if return_type_node is None and directives['annotation_typing']:
            return_type_node = node.return_type_annotation
            # for Python annotations, prefer safe exception handling by default
            if return_type_node is not None and except_val is None:
                except_val = (None, True)  # except *
        elif except_val is None:
            # backward compatible default: no exception check
            except_val = (None, False)

        if 'ccall' in directives:
            cfunc = node.as_cfunction(
                overridable=True, modifiers=modifiers, nogil=nogil,
                returns=return_type_node, except_val=except_val)
            return self.visit(cfunc)

        if 'cfunc' in directives:
            if not self.in_py_class:
                cfunc = node.as_cfunction(
                    overridable=False, modifiers=modifiers, nogil=nogil,
                    returns=return_type_node, except_val=except_val)
                return self.visit(cfunc)
            error(node.pos, "cfunc directive is not allowed here")

        if 'inline' in modifiers:
            error(node.pos, "Python functions cannot be declared 'inline'")
        if nogil:
            # TODO: turn this into a "with gil" declaration.
            error(node.pos, "Python functions cannot be declared 'nogil'")
        self.visitchildren(node)
        return node

    def visit_LambdaNode(self, node):
        # No directives should modify lambdas or generator expressions (and also nothing in them).
        return node

    def _visit_class_body(self, node, in_py_class):
        # Visit the class children with in_py_class temporarily overridden.
        saved = self.in_py_class
        self.in_py_class = in_py_class
        self.visitchildren(node)
        self.in_py_class = saved
        return node

    def visit_PyClassDefNode(self, node):
        if 'cclass' in self.directives:
            return self.visit(node.as_cclass())
        return self._visit_class_body(node, True)

    def visit_CClassDefNode(self, node):
        return self._visit_class_body(node, False)


class AlignFunctionDefinitions(CythonTransform):
    """
    Apply the signatures declared in a .pxd file to the matching def
    functions (and cdef classes) of a .py file.
    """

    def visit_ModuleNode(self, node):
        self.scope = node.scope
        self.directives = node.directives
        self.imported_names = set()  # hack, see visit_FromImportStatNode()
        self.visitchildren(node)
        return node

    @staticmethod
    def _report_redeclaration(node, pxd_def):
        # Report a clash between a .py definition and its .pxd entry.
        error(node.pos, "'%s' redeclared" % node.name)
        if pxd_def.pos:
            error(pxd_def.pos, "previous declaration here")

    def visit_PyClassDefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        if pxd_def.scope and pxd_def.scope.is_builtin_scope:
            return node
        self._report_redeclaration(node, pxd_def)
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        if not pxd_def.defined_in_pxd:
            return node
        outer_scope = self.scope
        self.scope = pxd_def.type.scope
        self.visitchildren(node)
        self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        declared_in_pxd = pxd_def and (not pxd_def.scope or not pxd_def.scope.is_builtin_scope)
        if declared_in_pxd:
            if not pxd_def.is_cfunction:
                self._report_redeclaration(node, pxd_def)
                return None
            node = node.as_cfunction(pxd_def)
        elif (self.scope.is_module_scope and self.directives['auto_cpdef']
              and node.name not in self.imported_names
              and node.is_cdef_func_compatible()):
            # FIXME: cpdef-ing should be done in analyse_declarations()
            node = node.as_cfunction(scope=self.scope)
        # Enable this when nested cdef functions are allowed.
        # self.visitchildren(node)
        return node

    def visit_FromImportStatNode(self, node):
        # hack to prevent conditional import fallback functions from
        # being cpdef-ed (global Python variables currently conflict
        # with imports)
        if self.scope.is_module_scope:
            self.imported_names.update(name for name, _ in node.items)
        return node

    def visit_ExprNode(self, node):
        # ignore lambdas and everything else that appears in expressions
        return node


class RemoveUnreachableCode(CythonTransform):
    """
    Remove statements that follow a terminating statement in a statement
    list (only if the 'remove_unreachable' directive is enabled).

    The ``is_terminator`` flag is also propagated up through if and try
    statements.
    """
    def visit_StatListNode(self, node):
        if not self.current_directives['remove_unreachable']:
            return node
        self.visitchildren(node)
        for next_idx, stat in enumerate(node.stats, 1):
            if not stat.is_terminator:
                continue
            if next_idx < len(node.stats):
                if self.current_directives['warn.unreachable']:
                    warning(node.stats[next_idx].pos, "Unreachable code", 2)
                node.stats = node.stats[:next_idx]
            node.is_terminator = True
            break
        return node

    def visit_IfClauseNode(self, node):
        self.visitchildren(node)
        if node.body.is_terminator:
            node.is_terminator = True
        return node

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        # terminating only if the else clause and every if clause terminate
        if node.else_clause and node.else_clause.is_terminator:
            if all(clause.is_terminator for clause in node.if_clauses):
                node.is_terminator = True
        return node

    def visit_TryExceptStatNode(self, node):
        self.visitchildren(node)
        if node.body.is_terminator and node.else_clause:
            if self.current_directives['warn.unreachable']:
                warning(node.else_clause.pos, "Unreachable code", 2)
            node.else_clause = None
        return node

    def visit_TryFinallyStatNode(self, node):
        self.visitchildren(node)
        if node.finally_clause.is_terminator:
            node.is_terminator = True
        return node


class YieldNodeCollector(TreeVisitor):
    """
    Collect the yield and await expressions, return statements and try
    statements of a single function body.

    Nested classes, functions, lambdas, generator expressions and
    argument declarations are not entered.
    """

    def __init__(self):
        super(YieldNodeCollector, self).__init__()
        self.yields = []
        self.returns = []
        self.finallys = []
        self.excepts = []
        self.has_return_value = False
        self.has_yield = False
        self.has_await = False

    def visit_Node(self, node):
        self.visitchildren(node)

    def visit_YieldExprNode(self, node):
        self.yields.append(node)
        self.has_yield = True
        self.visitchildren(node)

    def visit_AwaitExprNode(self, node):
        # await expressions are recorded in self.yields as well
        self.yields.append(node)
        self.has_await = True
        self.visitchildren(node)

    def visit_ReturnStatNode(self, node):
        self.visitchildren(node)
        if node.value:
            self.has_return_value = True
        self.returns.append(node)

    def visit_TryFinallyStatNode(self, node):
        self.visitchildren(node)
        self.finallys.append(node)

    def visit_TryExceptStatNode(self, node):
        self.visitchildren(node)
        self.excepts.append(node)

    def _do_not_enter(self, node):
        # nested scopes are not searched
        pass

    visit_ClassDefNode = _do_not_enter
    visit_FuncDefNode = _do_not_enter
    visit_LambdaNode = _do_not_enter
    visit_GeneratorExpressionNode = _do_not_enter

    def visit_CArgDeclNode(self, node):
        # do not look into annotations
        # FIXME: support (yield) in default arguments (currently crashes)
        pass


class MarkClosureVisitor(CythonTransform):
    """
    Set ``needs_closure`` on function and lambda nodes.

    def functions that contain yield/await expressions, and async def
    functions, are replaced by generator or coroutine nodes.
    """

    def visit_ModuleNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def _mark_closure(self, node):
        # Visit *node*'s children with the flag reset, record the result on
        # the node, and set the flag for the enclosing level.
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True

    def _coroutine_type(self, node, collector):
        # Select the node class replacing *node*, or None to keep the node.
        if node.is_async_def:
            if collector.has_yield:
                for yield_expr in collector.yields + collector.returns:
                    yield_expr.in_async_gen = True
                return Nodes.AsyncGenNode
            if self.current_directives['iterable_coroutine']:
                return Nodes.IterableAsyncDefNode
            return Nodes.AsyncDefNode
        if collector.has_await:
            found = next(y for y in collector.yields if y.is_await)
            error(found.pos, "'await' not allowed in generators (use 'yield')")
            return None
        if collector.has_yield:
            return Nodes.GeneratorDefNode
        return None

    def visit_FuncDefNode(self, node):
        self._mark_closure(node)

        collector = YieldNodeCollector()
        collector.visitchildren(node)

        coroutine_type = self._coroutine_type(node, collector)
        if coroutine_type is None:
            return node

        for label, yield_expr in enumerate(collector.yields, 1):
            yield_expr.label_num = label
        for stat in collector.returns + collector.finallys + collector.excepts:
            stat.in_generator = True

        gbody = Nodes.GeneratorBodyDefNode(
            pos=node.pos, name=node.name, body=node.body,
            is_async_gen_body=node.is_async_def and collector.has_yield)
        return coroutine_type(
            pos=node.pos, name=node.name, args=node.args,
            star_arg=node.star_arg, starstar_arg=node.starstar_arg,
            doc=node.doc, decorators=node.decorators,
            gbody=gbody, lambda_name=node.lambda_name)

    def visit_CFuncDefNode(self, node):
        self._mark_closure(node)
        if node.needs_closure and node.overridable:
            error(node.pos, "closures inside cpdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        self._mark_closure(node)
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node


class CreateClosureClasses(CythonTransform):
    """
    Output closure classes in module scope for all functions
    that really need it.
    """

    def __init__(self, context):
        super(CreateClosureClasses, self).__init__(context)
        # Function nodes currently being visited by visit_FuncDefNode
        # (innermost last).
        self.path = []
        # True while the children of a LambdaNode are being visited.
        self.in_lambda = False

    def visit_ModuleNode(self, node):
        # Closure classes are declared in the module scope.
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def find_entries_used_in_closures(self, node):
        """
        Split the named entries of all local scopes of *node* into two lists
        of (name, entry) pairs and return them as (from_closure, in_closure).

        An entry marked both from_closure and in_closure is only reported in
        from_closure.
        """
        from_closure = []
        in_closure = []
        for scope in node.local_scope.iter_local_scopes():
            for name, entry in scope.entries.items():
                if not name:
                    continue
                if entry.from_closure:
                    from_closure.append((name, entry))
                elif entry.in_closure:
                    in_closure.append((name, entry))
        return from_closure, in_closure

    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
        """
        Declare, if needed, the closure class of *node* and attach it to the
        function's local scope.

        node: function node whose local scopes are inspected.
        target_module_scope: scope in which the closure class is declared.
        inner_node: node whose needs_self_code flag is cleared when no outer
            closure is used (visit_LambdaNode passes the LambdaNode itself);
            defaults to node.py_cfunc_node.

        Raises InternalError if inner_node is needed but node has no
        py_cfunc_node.
        """
        # move local variables into closure
        if node.is_generator:
            # Generators keep all their own (non-global) variables in the closure.
            for scope in node.local_scope.iter_local_scopes():
                for entry in scope.entries.values():
                    if not (entry.from_closure or entry.is_pyglobal or entry.is_cglobal):
                        entry.in_closure = True

        from_closure, in_closure = self.find_entries_used_in_closures(node)
        # Sort by name only.  Variables from different local scopes (see the
        # in_subscope handling below) can share a name, and sorting the
        # (name, entry) tuples would then compare Entry objects, which are not
        # assumed to be orderable.  The sort is stable, so ties keep scope order.
        in_closure.sort(key=lambda name_and_entry: name_and_entry[0])

        # Now from the beginning
        node.needs_closure = False
        node.needs_outer_scope = False

        func_scope = node.local_scope
        # The closure scope we may inherit from is the nearest enclosing scope
        # that is not a class scope.
        cscope = node.entry.scope
        while cscope.is_py_class_scope or cscope.is_c_class_scope:
            cscope = cscope.outer_scope

        if not from_closure and (self.path or inner_node):
            if not inner_node:
                if not node.py_cfunc_node:
                    raise InternalError("DefNode does not have assignment node")
                inner_node = node.py_cfunc_node
            inner_node.needs_self_code = False
            node.needs_outer_scope = False

        if not node.is_generator:
            if not in_closure and not from_closure:
                # Nothing is captured: no closure class is needed.
                return
            if not in_closure:
                # Only outer variables are used: share the outer closure class.
                func_scope.is_passthrough = True
                func_scope.scope_class = cscope.scope_class
                node.needs_outer_scope = True
                return

        as_name = '%s_%s' % (
            target_module_scope.next_id(Naming.closure_class_prefix),
            node.entry.cname)

        entry = target_module_scope.declare_c_class(
            name=as_name, pos=node.pos, defining=True,
            implementing=True)
        entry.type.is_final_type = True

        func_scope.scope_class = entry
        class_scope = entry.type.scope
        class_scope.is_internal = True
        class_scope.is_closure_class_scope = True
        if node.is_async_def or node.is_generator:
            # Generators need their closure intact during cleanup as they resume to handle GeneratorExit
            class_scope.directives['no_gc_clear'] = True
        if Options.closure_freelist_size:
            class_scope.directives['freelist'] = Options.closure_freelist_size

        if from_closure:
            assert cscope.is_closure_scope
            # Reference to the enclosing function's closure object.
            class_scope.declare_var(pos=node.pos,
                                    name=Naming.outer_scope_cname,
                                    cname=Naming.outer_scope_cname,
                                    type=cscope.scope_class.type,
                                    is_cdef=True)
            node.needs_outer_scope = True
        for name, closure_var in in_closure:
            closure_entry = class_scope.declare_var(
                pos=closure_var.pos,
                # Subscope variables are declared without a name so that they
                # cannot clash with same-named variables of the function.
                name=closure_var.name if not closure_var.in_subscope else None,
                cname=closure_var.cname,
                type=closure_var.type,
                is_cdef=True)
            if closure_var.is_declared_generic:
                closure_entry.is_declared_generic = 1
        node.needs_closure = True
        # Do it here because other classes are already checked
        target_module_scope.check_c_class(func_scope.scope_class)

    def visit_LambdaNode(self, node):
        """Create the closure class of a lambda, then visit its children."""
        if not isinstance(node.def_node, Nodes.DefNode):
            # fused function, an error has been previously issued
            return node

        was_in_lambda = self.in_lambda
        self.in_lambda = True
        self.create_class_from_scope(node.def_node, self.module_scope, node)
        self.visitchildren(node)
        self.in_lambda = was_in_lambda
        return node

    def visit_FuncDefNode(self, node):
        """
        Create a closure class for functions that need one or that are nested
        in a function visited here.  Functions matching neither condition (and
        not inside a lambda) are returned without visiting their children.
        """
        if self.in_lambda:
            self.visitchildren(node)
            return node
        if node.needs_closure or self.path:
            self.create_class_from_scope(node, self.module_scope)
            self.path.append(node)
            self.visitchildren(node)
            self.path.pop()
        return node

    def visit_GeneratorBodyDefNode(self, node):
        # Visit the children without creating a closure class for this node.
        self.visitchildren(node)
        return node

    def visit_CFuncDefNode(self, node):
        # Overridable (cpdef) functions only get their children visited.
        if not node.overridable:
            return self.visit_FuncDefNode(node)
        else:
            self.visitchildren(node)
            return node
2813

2814

2815 2816
class InjectGilHandling(VisitorTransform, SkipDeclarations):
    """
    Allow certain Python operations inside of nogil blocks by implicitly
    acquiring the GIL.

    Must run before the AnalyseDeclarationsTransform to make sure the
    GILStatNodes get set up, parallel sections know that the GIL is acquired
    inside of them, etc.
    """
    def __call__(self, root):
        # Visiting starts outside of any nogil section.
        self.nogil = False
        return super(InjectGilHandling, self).__call__(root)

    def _visit_children_with_nogil(self, node, nogil):
        """Visit the children of *node* with self.nogil temporarily set to *nogil*."""
        outer_nogil = self.nogil
        self.nogil = nogil
        self.visitchildren(node)
        self.nogil = outer_nogil
        return node

    # special node handling

    def visit_RaiseStatNode(self, node):
        """Allow raising exceptions in nogil sections by wrapping them in a 'with gil' block."""
        if not self.nogil:
            return node
        return Nodes.GILStatNode(node.pos, state='gil', body=node)

    # further candidates:
    # def visit_AssertStatNode(self, node):
    # def visit_ReraiseStatNode(self, node):

    # nogil tracking

    def visit_GILStatNode(self, node):
        return self._visit_children_with_nogil(node, node.state == 'nogil')

    def visit_CFuncDefNode(self, node):
        declarator = node.declarator
        if isinstance(declarator, Nodes.CFuncDeclaratorNode):
            # A 'with gil' function body holds the GIL even if declared nogil.
            body_nogil = declarator.nogil and not declarator.with_gil
        else:
            # Other declarators keep the surrounding state.
            body_nogil = self.nogil
        return self._visit_children_with_nogil(node, body_nogil)

    def visit_ParallelRangeNode(self, node):
        return self._visit_children_with_nogil(node, node.nogil)

    def visit_ExprNode(self, node):
        # No special GIL handling inside of expressions for now.
        return node

    visit_Node = VisitorTransform.recurse_to_children


2869 2870 2871 2872 2873
class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on each node inside a `nogil` environment so
    that Python operations there are reported as errors.

    Additionally, report errors for 'with gil' / 'with nogil' statements
    that do not change the current GIL state.
    """

    def __call__(self, root):
        # Module level code starts out holding the GIL.
        self.env_stack = [root.scope]
        self.nogil = False
        # Set while visiting the body of a 'cdef func() nogil:' function.  Such a
        # function may still be called with the GIL held, so nested
        # 'with gil' / 'with nogil' statements are not reported as mismatched.
        self.nogil_declarator_only = False
        return super(GilCheck, self).__call__(root)

    def _visit_scoped_children(self, node, gil_state):
        """
        Visit node's children with self.nogil set to *gil_state*, except for
        node.outer_attrs, which are visited with the GIL state of the enclosing
        scope.  self.nogil is restored afterwards.
        """
        saved_nogil = self.nogil
        outer = node.outer_attrs
        if outer and len(self.env_stack) > 1:
            self.nogil = self.env_stack[-2].nogil
            self.visitchildren(node, outer)
        self.nogil = gil_state
        self.visitchildren(node, exclude=outer)
        self.nogil = saved_nogil

    def visit_FuncDefNode(self, node):
        scope = node.local_scope
        self.env_stack.append(scope)
        body_is_nogil = scope.nogil
        if body_is_nogil:
            self.nogil_declarator_only = True
            if node.nogil_check:
                node.nogil_check(scope)
        self._visit_scoped_children(node, body_is_nogil)
        # This cannot be nested, so it doesn't need backup/restore
        self.nogil_declarator_only = False
        self.env_stack.pop()
        return node

    def visit_GILStatNode(self, node):
        if self.nogil and node.nogil_check:
            node.nogil_check()

        outer_nogil = self.nogil
        wants_nogil = (node.state == 'nogil')
        if outer_nogil == wants_nogil and not self.nogil_declarator_only:
            if outer_nogil:
                message = "Trying to release the GIL while it was previously released."
            else:
                message = "Trying to acquire the GIL while it is already held."
            error(node.pos, message)

        if isinstance(node.finally_clause, Nodes.StatListNode):
            # The finally clause of the GILStatNode is a GILExitNode,
            # which is wrapped in a StatListNode. Just unpack that.
            [node.finally_clause] = node.finally_clause.stats

        self._visit_scoped_children(node, wants_nogil)
        return node

    def visit_ParallelRangeNode(self, node):
        if node.nogil:
            # Replace the flag by an explicit 'with nogil' block around the
            # loop and check that block instead.
            node.nogil = False
            wrapper = Nodes.GILStatNode(node.pos, state='nogil', body=node)
            return self.visit_GILStatNode(wrapper)

        if not self.nogil:
            error(node.pos, "prange() can only be used without the GIL")
            # Skip the body so that no further GIL errors are reported for it.
            return None

        node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ParallelWithBlockNode(self, node):
        if not self.nogil:
            error(node.pos, "The parallel section may only be used without the GIL")
            return None

        if node.nogil_check:
            # Not implemented by the node at the moment; checked anyway so a
            # future implementation is picked up.
            node.nogil_check(self.env_stack[-1])

        self.visitchildren(node)
        return node

    def visit_TryFinallyStatNode(self, node):
        """
        Take care of try/finally statements in nogil code sections.
        """
        if self.nogil and not isinstance(node, Nodes.GILStatNode):
            node.nogil_check = None
            node.is_try_finally_in_nogil = True
            self.visitchildren(node)
            return node
        return self.visit_Node(node)

    def visit_Node(self, node):
        """Run the node's nogil_check inside nogil code, then visit its children."""
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        if not node.outer_attrs:
            self.visitchildren(node)
        else:
            self._visit_scoped_children(node, self.nogil)
        if self.nogil:
            node.in_nogil_context = True
        return node

2992

Robert Bradshaw's avatar
Robert Bradshaw committed
2993
class TransformBuiltinMethods(EnvTransform):
    """
    Replace Cython's own cython.* builtins by the corresponding tree nodes.
    """

    def visit_SingleAssignmentNode(self, node):
        # Assignments that are only declarations are replaced by None.
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """
        Replace a reference to a cython.* attribute by the node it stands
        for, or report an error if the attribute is unknown.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'__version__':
                from .. import __version__ as version
                node = ExprNodes.StringNode(node.pos, value=EncodedString(version))
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset', u'staticmethod'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif PyrexTypes.parse_basic_type(attribute):
                pass
            elif self.context.cython_scope.lookup_qualified_name(attribute):
                pass
            else:
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_ExecStatNode(self, node):
        """Pass globals() (and locals() outside module scope) to a one-argument exec."""
        lenv = self.current_env()
        self.visitchildren(node)
        if len(node.args) == 1:
            node.args.append(ExprNodes.GlobalsExprNode(node.pos))
            if not lenv.is_module_scope:
                node.args.append(
                    ExprNodes.LocalsExprNode(
                        node.pos, self.current_scope_node(), lenv))
        return node

    def _inject_locals(self, node, func_name):
        """
        Replace argument-less calls of the locals()/dir()/vars() builtins by
        nodes computing the result from the current scope.

        Calls whose name is declared in the current scope, or that pass an
        argument, are returned unchanged (after reporting a wrong arg count).
        """
        # locals()/dir()/vars() builtins
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry:
            # not the builtin
            return node
        pos = node.pos
        # Errors below are reported at the call's position (`pos`); this
        # transform has no position of its own.
        if func_name in ('locals', 'vars'):
            if func_name == 'locals' and len(node.args) > 0:
                error(pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d"
                      % len(node.args))
                return node
            elif func_name == 'vars':
                if len(node.args) > 1:
                    error(pos, "Builtin 'vars()' called with wrong number of args, expected 0-1, got %d"
                          % len(node.args))
                if len(node.args) > 0:
                    return node  # nothing to do
            return ExprNodes.LocalsExprNode(pos, self.current_scope_node(), lenv)
        else:  # dir()
            if len(node.args) > 1:
                error(pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d"
                      % len(node.args))
            if len(node.args) > 0:
                # optimised in Builtin.py
                return node
            if lenv.is_py_class_scope or lenv.is_module_scope:
                if lenv.is_py_class_scope:
                    pyclass = self.current_scope_node()
                    locals_dict = ExprNodes.CloneNode(pyclass.dict)
                else:
                    locals_dict = ExprNodes.GlobalsExprNode(pos)
                return ExprNodes.SortedDictKeysNode(locals_dict)
            # Function scope: the names are known at compile time.
            local_names = sorted(var.name for var in lenv.entries.values() if var.name)
            items = [ExprNodes.IdentifierStringNode(pos, value=var)
                     for var in local_names]
            return ExprNodes.ListNode(pos, args=items)

    def visit_PrimaryCmpNode(self, node):
        # special case: for in/not-in test, we do not need to sort locals()
        self.visitchildren(node)
        if node.operator in ('in', 'not_in'):
            if isinstance(node.operand2, ExprNodes.SortedDictKeysNode):
                arg = node.operand2.arg
                if isinstance(arg, ExprNodes.NoneCheckNode):
                    arg = arg.arg
                node.operand2 = arg
        return node

    def visit_CascadedCmpNode(self, node):
        return self.visit_PrimaryCmpNode(node)

    def _inject_eval(self, node, func_name):
        """Pass globals() (and locals() outside module scope) to a one-argument builtin eval()."""
        lenv = self.current_env()
        entry = lenv.lookup(func_name)
        if len(node.args) != 1 or (entry and not entry.is_builtin):
            return node
        # Inject globals and locals
        node.args.append(ExprNodes.GlobalsExprNode(node.pos))
        if not lenv.is_module_scope:
            node.args.append(
                ExprNodes.LocalsExprNode(
                    node.pos, self.current_scope_node(), lenv))
        return node

    def _inject_super(self, node, func_name):
        """
        Fill in the arguments of a no-argument super() call inside a method of
        a Python class or an extension type.
        """
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry or node.args:
            return node
        # Inject no-args super
        def_node = self.current_scope_node()
        if (not isinstance(def_node, Nodes.DefNode) or not def_node.args or
                len(self.env_stack) < 2):
            return node
        class_node, class_scope = self.env_stack[-2]
        if class_scope.is_py_class_scope:
            def_node.requires_classobj = True
            class_node.class_cell.is_active = True
            node.args = [
                ExprNodes.ClassCellNode(
                    node.pos, is_generator=def_node.is_generator),
                ExprNodes.NameNode(node.pos, name=def_node.args[0].name)
                ]
        elif class_scope.is_c_class_scope:
            node.args = [
                ExprNodes.NameNode(
                    node.pos, name=class_node.scope.name,
                    entry=class_node.entry),
                ExprNodes.NameNode(node.pos, name=def_node.args[0].name)
                ]
        return node

    def visit_SimpleCallNode(self, node):
        """
        Replace calls of cython.* functions by the corresponding nodes, then
        special-case calls of the dir/locals/vars/eval/super builtins.
        """
        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](
                        node.function.pos, operand=node.args[0])
            elif function in InterpretCompilerDirectives.binop_method_nodes:
                if len(node.args) != 2:
                    error(node.function.pos, u"%s() takes exactly two arguments" % function)
                else:
                    node = InterpretCompilerDirectives.binop_method_nodes[function](
                        node.function.pos, operand1=node.args[0], operand2=node.args[1])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos,
                          u"cast() takes exactly two arguments and an optional typecheck keyword")
                else:
                    cast_type = node.args[0].analyse_as_type(self.current_env())
                    if cast_type:
                        node = ExprNodes.TypecastNode(
                            node.function.pos, type=cast_type, operand=node.args[1], typecheck=False)
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    arg_type = node.args[0].analyse_as_type(self.current_env())
                    if arg_type:
                        node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=arg_type)
                    else:
                        node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == u'set':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
            elif function == u'staticmethod':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('staticmethod'))
            elif self.context.cython_scope.lookup_qualified_name(function):
                pass
            else:
                error(node.function.pos,
                      u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)

        if isinstance(node, ExprNodes.SimpleCallNode) and node.function.is_name:
            func_name = node.function.name
            if func_name in ('dir', 'locals', 'vars'):
                return self._inject_locals(node, func_name)
            if func_name == 'eval':
                return self._inject_eval(node, func_name)
            if func_name == 'super':
                return self._inject_super(node, func_name)
        return node

    def visit_GeneralCallNode(self, node):
        """Handle cython.cast() calls that pass the 'typecheck' keyword."""
        function = node.function.as_cython_attribute()
        if function == u'cast':
            # NOTE: assuming simple tuple/dict nodes for positional_args and keyword_args
            args = node.positional_args.args
            kwargs = node.keyword_args.compile_time_value(None)
            if (len(args) != 2 or len(kwargs) > 1 or
                    (len(kwargs) == 1 and 'typecheck' not in kwargs)):
                error(node.function.pos,
                      u"cast() takes exactly two arguments and an optional typecheck keyword")
            else:
                cast_type = args[0].analyse_as_type(self.current_env())
                if cast_type:
                    typecheck = kwargs.get('typecheck', False)
                    node = ExprNodes.TypecastNode(
                        node.function.pos, type=cast_type, operand=args[1], typecheck=typecheck)
                else:
                    error(args[0].pos, "Not a type")

        self.visitchildren(node)
        return node

3230

3231 3232 3233 3234 3235 3236 3237 3238 3239 3240 3241 3242 3243 3244
class ReplaceFusedTypeChecks(VisitorTransform):
    """
    Not part of the normal pipeline: this transform is run on each specific
    version of a cdef function with fused argument types and filters out
    the type branches that do not match, e.g.

        if fused_t is mytype:
            ...
        elif fused_t in other_fused_type:
            ...
    """
    def __init__(self, local_scope):
        super(ReplaceFusedTypeChecks, self).__init__()
        self.local_scope = local_scope
        # defer the import until now to avoid circular import time dependencies
        from .Optimize import ConstantFolding
        self.transform = ConstantFolding(reevaluate=True)

    def visit_IfStatNode(self, node):
        """
        Filters out any if clauses with false compile time type check
        expression.
        """
        self.visitchildren(node)
        return self.transform(node)

    def visit_PrimaryCmpNode(self, node):
        """
        Replace a comparison between two types by a constant BoolNode.
        Comparisons whose operands are not both types are left unchanged.
        """
        scope = self.local_scope
        # Operands that are not types are expected here; silence their errors.
        with Errors.local_errors(ignore=True):
            type1 = node.operand1.analyse_as_type(scope)
            type2 = node.operand2.analyse_as_type(scope)

        if not (type1 and type2):
            return node

        false_node = ExprNodes.BoolNode(node.pos, value=False)
        true_node = ExprNodes.BoolNode(node.pos, value=True)

        type1 = self.specialize_type(type1, node.operand1.pos)
        op = node.operator

        if op in ('is', 'is_not', '==', '!='):
            type2 = self.specialize_type(type2, node.operand2.pos)
            is_same = type1.same_as(type2)
            wants_same = op in ('is', '==')
            if bool(is_same) == wants_same:
                return true_node

        elif op in ('in', 'not_in'):
            # We have to do an instance check directly, as operand2
            # needs to be a fused type and not a type with a subtype
            # that is fused. First unpack the typedef
            if isinstance(type2, PyrexTypes.CTypedefType):
                type2 = type2.typedef_base_type

            if type1.is_fused:
                error(node.operand1.pos, "Type is fused")
            elif not type2.is_fused:
                error(node.operand2.pos,
                      "Can only use 'in' or 'not in' on a fused type")
            else:
                candidates = PyrexTypes.get_specialized_types(type2)
                if any(type1.same_as(candidate) for candidate in candidates):
                    return true_node if op == 'in' else false_node
                if op == 'not_in':
                    return true_node

        # Any other outcome (including unsupported operators) folds to False.
        return false_node

    def specialize_type(self, type, pos):
        """Specialize *type* for this function version; report an error if that fails."""
        try:
            return type.specialize(self.local_scope.fused_to_specific)
        except KeyError:
            error(pos, "Type is not specific")
            return type

    def visit_Node(self, node):
        self.visitchildren(node)
        return node


Mark Florisson's avatar
Mark Florisson committed
3319
class DebugTransform(CythonTransform):
    """
    Write debug information for this Cython module.

    Everything is emitted through the tree builder
    ``self.context.gdb_debug_outputwriter`` (see
    Cython.Debugger.debug_output.CythonDebugWriter) as nested
    ``start``/``end`` pairs: a 'Module' element containing a 'Functions'
    section and a 'Globals' section.  The 'Module' element is deliberately
    left open here; it is closed after the line number mapping has been
    written (Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map).
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # Qualified names of all functions seen so far; visit_ModuleNode
        # keeps these names out of the 'Globals' section.
        self.visited = set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        # 'options' is not used; the C file name comes from the result.
        self.c_output_file = result.c_file

        # Closure support, basically treat nested functions as if the AST were
        # never nested: nested definitions are collected here while their
        # enclosing function is visited, and serialized afterwards.
        self.nested_funcdefs = []

        # tells visit_NameNode whether it should register step-into functions
        # (and visit_FuncDefNode that a definition it meets is nested)
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        """
        Serialize the module: every function (regular, nested, and the
        module-level code presented as a function), then the module's
        global variables.  Returns the node unchanged.
        """
        self.tb.module_name = node.full_module_name
        attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)

        self.tb.start('Module', attrs)

        # serialize functions
        self.tb.start('Functions')
        # First, serialize functions normally...
        self.visitchildren(node)

        # ... then, serialize nested functions.  Serializing one nested
        # function may append more deeply nested ones to this list; Python's
        # list iteration picks up those appended items, so every nesting
        # level gets serialized.
        for nested_funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(nested_funcdef)

        # The step-into flag is managed inside this call.
        self.serialize_modulenode_as_function(node)
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables, skipping functions
        # already serialized above, internal '__pyx_' names, C functions and
        # extension types.
        self.tb.start('Globals')
        entries = dict(
            (name, entry) for name, entry in node.scope.entries.items()
            if entry.qualified_name not in self.visited and
            not entry.name.startswith('__pyx_') and
            not entry.type.is_cfunction and
            not entry.type.is_extension_type)
        self.serialize_local_variables(entries)
        self.tb.end('Globals')
        # self.tb.end('Module') # end Module after the line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        """
        Serialize one function: its names and line number, its local
        variables, its argument names, and the C functions it calls
        ("step into" functions).

        Wrapper functions are only recorded as visited.  A definition met
        while step-into registration is active (i.e. while the body of an
        enclosing function or the module-level code is being visited) is not
        serialized in place; it is appended to ``self.nested_funcdefs``.
        Returns the node unchanged.
        """
        self.visited.add(node.local_scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            self.nested_funcdefs.append(node)
            return node

        if node.py_func is None:
            pf_cname = ''
        else:
            pf_cname = node.py_func.entry.func_cname

        attrs = dict(
            name=node.entry.name or getattr(node, 'name', '<unknown>'),
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1]))

        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg in node.local_scope.arg_entries:
            # each argument becomes an empty element named after it
            self.tb.start(arg.name)
            self.tb.end(arg.name)
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self._visitchildren_registering_stepinto(node)
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

        return node

    def _visitchildren_registering_stepinto(self, node):
        # Visit the children of 'node' with register_stepinto enabled, so
        # called C functions are recorded by visit_NameNode and function
        # definitions are deferred by visit_FuncDefNode.  The previous flag
        # value is restored afterwards, also when visiting raises.
        previous = self.register_stepinto
        self.register_stepinto = True
        try:
            self.visitchildren(node)
        finally:
            self.register_stepinto = previous

    def visit_NameNode(self, node):
        """
        While step-into registration is active, record a called C function
        that has a C name as a 'StepIntoFunction' element.  Returns the node
        unchanged.
        """
        if (self.register_stepinto and
                node.type is not None and
                node.type.is_cfunction and
                getattr(node, 'is_called', False) and
                node.entry.func_cname is not None):
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            attrs = dict(name=node.entry.func_cname)
            self.tb.start('StepIntoFunction', attrs=attrs)
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Serialize the module-level code as a function so the debugger will know
        it's a "relevant frame" and it will know where to set the breakpoint
        for 'break modulename'.

        Two 'Function' elements are written, one per module init function
        name: 'init<name>' and 'PyInit_<name>'.
        """
        name = node.full_module_name.rpartition('.')[-1]

        cname_py2 = 'init' + name
        cname_py3 = 'PyInit_' + name

        py2_attrs = dict(
            name=name,
            cname=cname_py2,
            pf_cname='',
            # Ignore the qualified_name, breakpoints should be set using
            # `cy break modulename:lineno` for module-level breakpoints.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )

        py3_attrs = dict(py2_attrs, cname=cname_py3)

        self._serialize_modulenode_as_function(node, py2_attrs)
        self._serialize_modulenode_as_function(node, py3_attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        # Write one 'Function' element for the module-level code: the module
        # scope entries as its locals, no arguments, and the C functions
        # called at module level as step-into functions.
        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        # NOTE(review): function definitions met here are appended to
        # self.nested_funcdefs (see visit_FuncDefNode) after that list has
        # already been consumed by visit_ModuleNode; they were serialized
        # earlier, so this does not change the output.
        self._visitchildren_registering_stepinto(node)
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

    def serialize_local_variables(self, entries):
        """
        Write one 'LocalVar' element for each entry in the ``entries``
        mapping that has a C name.

        Each element records the variable's name, the C expression used to
        reach it (through the current scope object for closure variables),
        its qualified name, its kind ('PythonObject' or 'CObject') and its
        source line ('0' when the entry has no source position).
        """
        for entry in entries.values():
            if not entry.cname:
                # not a local variable
                continue

            if entry.type.is_pyobject:
                vartype = 'PythonObject'
            else:
                vartype = 'CObject'

            if entry.from_closure:
                # We're dealing with a closure where a variable from an outer
                # scope is accessed, get it from the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)

                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                # The variable lives in this function's own closure scope
                # object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            if not entry.pos:
                # this happens for variables that are not in the user's code,
                # e.g. for the global __builtins__, __doc__, etc. We can just
                # set the lineno to 0 for those.
                lineno = '0'
            else:
                lineno = str(entry.pos[1])

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')