# Optimize.py

import cython
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
               UtilNodes=object, Naming=object)

7 8
import Nodes
import ExprNodes
9
import PyrexTypes
10
import Visitor
11 12 13 14
import Builtin
import UtilNodes
import TypeSlots
import Symtab
15
import Options
16
import Naming
17

18
from Code import UtilityCode
19
from StringEncoding import EncodedString, BytesLiteral
20
from Errors import error
21 22
from ParseTreeTransforms import SkipDeclarations

23 24
import codecs

25
try:
26 27
    from __builtin__ import reduce
except ImportError:
28 29
    from functools import reduce

30 31 32 33 34
try:
    from __builtin__ import basestring
except ImportError:
    basestring = str # Python 3

35 36 37 38
class FakePythonEnv(object):
    """A fake environment for creating type test nodes etc."""
    # This stand-in environment never reports a nogil context.
    nogil = False

39 40 41 42 43
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Return ``node.arg`` if ``node`` is one of ``coercion_nodes``, else ``node``.

    Only a single level of coercion is removed.  By default the recognised
    wrappers are the to-Python and from-Python coercion nodes.
    """
    return node.arg if isinstance(node, coercion_nodes) else node

44
def unwrap_node(node):
    """Strip every enclosing ``ResultRefNode`` and return the wrapped expression.

    Nodes that are not ``ResultRefNode`` instances are returned unchanged.
    """
    while isinstance(node, UtilNodes.ResultRefNode):
        node = node.expression
    return node
48 49

def is_common_value(a, b):
    """Tell whether two expression nodes syntactically refer to the same value.

    After unwrapping result references, two name nodes match when their
    names are equal.  Two attribute nodes match when the attribute is not a
    Python attribute lookup, the attribute names are equal and their objects
    match recursively.  Every other combination yields False.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False

58 59 60 61
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    """
    # C signature of PyDict_Size(): takes a Python object, returns Py_ssize_t.
    PyDict_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            ])

    PyDict_Size_name = EncodedString("PyDict_Size")

    # Symbol table entry used when building a call to PyDict_Size().
    PyDict_Size_entry = Symtab.Entry(
        PyDict_Size_name, PyDict_Size_name, PyDict_Size_func_type)

    # Nodes without a dedicated visit method are simply traversed.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
Stefan Behnel's avatar
Stefan Behnel committed
76

77 78
    def visit_ModuleNode(self, node):
        """Record the module scope, then transform the module's children."""
        # current_scope is rebound while visiting function bodies;
        # module_scope is only assigned here.
        self.current_scope = node.scope
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Transform a function body with ``current_scope`` set to its own scope."""
        enclosing_scope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        # back to the scope that was active before entering the function
        self.current_scope = enclosing_scope
        return node
89

90 91
    def visit_PrimaryCmpNode(self, node):
        """Expand a pointer/array containment test into an explicit search loop.

        Comparisons for which ``node.is_ptr_contains()`` is false are only
        traversed.  Otherwise the comparison is replaced by a
        ``TempResultFromStatNode`` wrapping a for-in loop over ``operand2``;
        for the ``not_in`` operator the result is additionally negated.
        """
        if not node.is_ptr_contains():
            self.visitchildren(node)
            return node

        # The generated loop is equivalent to:
        #
        # for t in operand2:
        #     if operand1 == t:
        #         res = True
        #         break
        # else:
        #     res = False

        pos = node.pos
        result_ref = UtilNodes.ResultRefNode(node)
        if isinstance(node.operand2, ExprNodes.IndexNode):
            base_type = node.operand2.base.type.base_type
        else:
            base_type = node.operand2.type.base_type
        target_handle = UtilNodes.TempHandle(base_type)
        target = target_handle.ref(pos)
        cmp_node = ExprNodes.PrimaryCmpNode(
            pos, operator=u'==', operand1=node.operand1, operand2=target)
        if_body = Nodes.StatListNode(
            pos,
            stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                     Nodes.BreakStatNode(pos)])
        if_node = Nodes.IfStatNode(
            pos,
            if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
            else_clause=None)
        for_loop = UtilNodes.TempsBlockNode(
            pos,
            temps = [target_handle],
            body = Nodes.ForInStatNode(
                pos,
                target=target,
                iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                body=if_node,
                else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
        for_loop.analyse_expressions(self.current_scope)
        # run this transform over the new loop so it can be optimised as well
        for_loop = self(for_loop)
        new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

        if node.operator == 'not_in':
            new_node = ExprNodes.NotNode(pos, operand=new_node)
        return new_node
140

141 142
    def visit_ForInStatNode(self, node):
        """Transform the loop's children first, then try to optimise the loop itself."""
        self.visitchildren(node)
        return self._optimise_for_loop(node, node.iterator.sequence)
144

145
    def _optimise_for_loop(self, node, iterator, reversed=False):
        """Dispatch a for-in loop to the specialised transformation that fits.

        ``iterator`` is the expression being iterated over and ``reversed``
        tells whether it was wrapped in a ``reversed()`` call.  Returns the
        replacement node, or ``node`` itself when no optimisation applies.
        """
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        if iterator.type.is_ptr or iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator, reversed=reversed)
        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
            return self._transform_string_iteration(node, iterator, reversed=reversed)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            dict_obj = iterator.self or function.obj
            method = function.attribute

            # plain keys()/values()/items() only iterate lazily in Py3 code
            is_py3 = self.module_scope.context.language_level >= 3
            keys = values = False
            if method == 'iterkeys' or (is_py3 and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_py3 and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_py3 and method == 'items'):
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # everything below requires a plain call to a builtin function
        is_builtin_call = iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin

        # enumerate/reversed ?
        if is_builtin_call and function.name in ('enumerate', 'reversed'):
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            if function.name == 'enumerate':
                return self._transform_enumerate_iteration(node, iterator)
            return self._transform_reversed_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if is_builtin_call and function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator, reversed=reversed)

        return node
209

210 211 212 213 214 215 216 217 218 219
    def _transform_reversed_iteration(self, node, reversed_function):
        """Optimise ``for x in reversed(arg)``.

        Reports a compile error and returns ``node`` unchanged unless the
        call has exactly one argument.  For list/tuple typed arguments the
        existing iterator is switched to reversed mode; any other argument is
        passed back to ``_optimise_for_loop()`` with ``reversed=True``.
        """
        args = reversed_function.arg_tuple.args
        if len(args) == 0:
            error(reversed_function.pos,
                  "reversed() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(reversed_function.pos,
                  "reversed() takes exactly 1 argument")
            return node
        arg = args[0]

        # reversed(list/tuple) ?
        if arg.type in (Builtin.tuple_type, Builtin.list_type):
            node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
            node.iterator.reversed = True
            return node

        return self._optimise_for_loop(node, arg, reversed=True)
229

230
    # C signatures of the string access functions/macros that are called by
    # _transform_string_iteration().
    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

250
    def _transform_string_iteration(self, node, slice_node, reversed=False):
        """Iterate over a bytes/unicode object through its C character buffer.

        The string is evaluated once into a None-checked temporary; a slice
        ``buffer[:length]`` built from the matching C-API calls is then
        handed to ``_transform_carray_iteration()``.  Returns ``node``
        unchanged for other types, and for bytes iteration into a
        non-integer target.
        """
        if slice_node.type is Builtin.unicode_type:
            unpack_func = "PyUnicode_AS_UNICODE"
            len_func = "PyUnicode_GET_SIZE"
            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
            len_func_type = self.PyUnicode_GET_SIZE_func_type
        elif slice_node.type is Builtin.bytes_type:
            target_type = node.target.type
            if not target_type.is_int:
                # bytes iteration returns bytes objects in Py2, but
                # integers in Py3
                return node
            unpack_func = "PyBytes_AS_STRING"
            unpack_func_type = self.PyBytes_AS_STRING_func_type
            len_func = "PyBytes_GET_SIZE"
            len_func_type = self.PyBytes_GET_SIZE_func_type
        else:
            return node

        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, unpack_func, unpack_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, len_func, len_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )

        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    ),
                reversed = reversed))
297

298
    def _transform_carray_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a C array or pointer slice into a pointer loop.

        ``slice_node`` may be a ``SliceIndexNode``, an ``IndexNode`` whose
        index is a ``SliceNode``, or an expression of C array type.  The
        loop is rewritten as a ``ForFromStatNode`` that advances a temporary
        pointer from the start of the slice to its end and assigns the
        current item to the original loop target at the top of the body.

        Returns ``node`` unchanged when the bounds or the step cannot be
        determined; a compile error is reported in that case unless the
        sliced base is a Python object.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node

        elif isinstance(slice_node, ExprNodes.IndexNode):
            assert isinstance(slice_node.index, ExprNodes.SliceNode)
            slice_base = slice_node.base
            index = slice_node.index
            start = index.start
            stop = index.stop
            step = index.step
            if step:
                if step.constant_result is None:
                    step = None
                elif not isinstance(step.constant_result, (int,long)) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    step_value = step.constant_result
                    if reversed:
                        step_value = -step_value
                    neg_step = step_value < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step_value)),
                                             constant_result=abs(step_value))

        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                # no step exists in this branch, so report at the array itself
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            if start.constant_result is None:
                start = None
            else:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop:
            if stop.constant_result is None:
                stop = None
            else:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop is None:
            if neg_step:
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        if reversed:
            if not start:
                start = ExprNodes.IntNode(slice_node.pos, value="0",  constant_result=0,
                                          type=PyrexTypes.c_py_ssize_t_type)
            # if step was provided, it was already negated above
            start, stop = stop, start

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        if stop and stop.constant_result != 0:
            stop_ptr_node = ExprNodes.AddNode(
                stop.pos,
                operand1=ExprNodes.CloneNode(carray_ptr),
                operator='+',
                operand2=stop,
                type=ptr_type
                ).coerce_to_simple(self.current_scope)
        else:
            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

        # the loop variable is a pointer that walks over the items
        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # general case: dereference the pointer as counter[0]
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=relation1,
            target=counter_temp,
            relation2=relation2, bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489
    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` by a loop with an explicit counter.

        The counter lives in a temporary that is copied to the first loop
        target and then incremented at the top of the loop body.  The loop
        itself is changed to iterate over ``seq`` with the second target and
        is passed to ``_optimise_for_loop()`` again.

        A compile error is reported unless the call has exactly one
        argument.  The loop is left untouched when the target is not a
        2-element sequence whose first item is a plain name, or when the
        counter target is neither a Python object nor a C integer.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # "target = counter; counter = counter + 1"
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))

    def _find_for_from_node_relations(self, neg_step_value, reversed):
        if reversed:
            if neg_step_value:
                return '<', '<='
            else:
                return '>', '>='
        else:
            if neg_step_value:
                return '>=', '>'
            else:
                return '<=', '<'
542

543
    def _transform_range_iteration(self, node, range_function, reversed=False):
        """Turn ``for i in range(...)`` into a C-level ``ForFromStatNode``.

        ``range_function`` is the call node of ``range()``/``xrange()`` and
        ``reversed`` tells whether it was wrapped in ``reversed()``.  A
        non-literal stop bound is evaluated once into a temporary.

        Returns ``node`` unchanged when the step is not a non-zero integer
        constant, or when a reversed range uses a step other than +1/-1.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if reversed and abs(step_value) != 1:
                # The reversed loop below simply starts next to the swapped
                # stop bound, which is only an element of the range for
                # steps of +1/-1: reversed(range(0, 11, 3)) must start at 9,
                # not at 10.  Fall back to generic iteration.
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)

        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

        if reversed:
            bound1, bound2 = bound2, bound1
        # the direction is encoded in the relations, the step itself is
        # always passed on as a positive value
        if step_value < 0:
            step_value = -step_value

        step.value = str(step_value)
        step.constant_result = step_value
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

Stefan Behnel's avatar
Stefan Behnel committed
607
    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Rewrite a loop over a dict into a while loop over a dict-next node.

        ``keys`` and ``values`` select what the loop target(s) receive.  The
        generated code stores the dict, a position counter and the initial
        dict length in temporaries and runs a ``WhileStatNode`` whose body
        starts with a ``DictIterationNextNode`` followed by the assignments
        to the original loop target(s).

        Returns ``node`` unchanged when both keys and values are requested
        but the target is a sequence of a length other than 2.
        """
        py_object_ptr = PyrexTypes.py_object_type

        # temps that live for the whole loop: the dict and the position
        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))

        # temps that are only needed while setting up one iteration
        target_temps = []
        if keys:
            temp = UtilNodes.TempHandle(
                py_object_ptr, needs_cleanup=False) # ref will be stolen
            target_temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(
                py_object_ptr, needs_cleanup=False) # ref will be stolen
            target_temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (node to assign to the target, coercion statement or None).
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                target_temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                # coercion node that stores its result in the temp above and
                # can be executed like a statement
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body_init_stats = [
                Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result)
                ]
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body_init_stats = coercion_stats + assign_stats

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        # keep original length to guard against dict modification
        dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(dict_len_temp)

        body_init_stats.insert(0, Nodes.DictIterationNextNode(
            dict_temp,
            dict_len_temp.ref(dict_obj.pos),
            pos_temp_addr, key_temp_addr, value_temp_addr
            ))
        # the per-iteration setup runs before the original loop body
        body.stats[0:0] = [UtilNodes.TempsBlockNode(
            node.pos,
            temps = target_temps,
            body = Nodes.StatListNode(pos = node.pos,
                                      stats = body_init_stats)
            )]

        # dict = <dict_obj>; pos = 0; length = PyDict_Size(dict); while ...
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_len_temp.ref(dict_obj.pos),
                rhs = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_py_ssize_t_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Size_name,
                        type = self.PyDict_Size_func_type,
                        entry = self.PyDict_Size_entry),
                    args = [dict_temp],
                )),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = None,
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))


779 780
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    # Result of the extract_*() methods for conditions that cannot be
    # mapped to switch cases: (not_in, common_var, conditions).
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """Split a condition into its switch components.

        Returns a tuple (not_in, var, values): 'var' is the tested node,
        'values' the list of nodes it is compared against, and 'not_in'
        tells whether the test is negated ('!=' / 'not in').  Returns
        NO_MATCH if the condition has no supported form.  Negated tests
        are only accepted if 'allow_not_in' is true.
        """
        # strip wrapper nodes around the actual comparison
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                return self.NO_MATCH
            elif cond.is_c_string_contains() and \
                   isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                       cond.operand2.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
            elif not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                             and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                             and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                # negated tests can only be merged by 'and', plain ones by 'or'
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH

    def extract_in_string_conditions(self, string_literal):
        """Return one sorted, duplicate free case value node per character
        of a unicode or bytes string literal node.
        """
        if isinstance(string_literal, ExprNodes.UnicodeNode):
            charvals = list(map(ord, set(string_literal.value)))
            charvals.sort()
            return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
                                       constant_result=charval)
                     for charval in charvals ]
        else:
            # this is a bit tricky as Py3's bytes type returns
            # integers on iteration, whereas Py2 returns 1-char byte
            # strings
            characters = string_literal.value
            characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
            characters.sort()
            return [ ExprNodes.CharNode(string_literal.pos, value=charval,
                                        constant_result=charval)
                     for charval in characters ]

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """Like extract_conditions(), but additionally require that the
        tested node matches 'common_var' (unless that is None) and that
        the tested node and all case values have integer types.
        """
        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        elif common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Return True if the case value nodes may contain duplicates.

        Values with a known constant result are compared by that result,
        all other values by the cname of their entry.  A value that has
        neither is reported as a potential duplicate in order to play
        safe, which makes the callers skip the switch transformation.
        """
        # duplicated values don't work in a switch statement
        seen = set()
        for value in condition_values:
            if value.constant_result is not ExprNodes.not_a_constant:
                if value.constant_result in seen:
                    return True
                seen.add(value.constant_result)
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                entry = getattr(value, 'entry', None)
                cname = getattr(entry, 'cname', None)
                if cname is None or cname in seen:
                    # unknown or repeated name => play safe
                    return True
                seen.add(cname)
        return False

    def visit_IfStatNode(self, node):
        """Replace an if statement by a SwitchStatNode if all of its
        clauses test the same integer value against at least two distinct
        values in total.  Otherwise, visit the children and return the
        node unchanged.
        """
        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                              conditions = conditions,
                                              body = if_clause.body))

        if sum([ len(case.conditions) for case in cases ]) < 2:
            self.visitchildren(node)
            return node
        if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
            self.visitchildren(node)
            return node

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
                                           test = common_var,
                                           cases = cases,
                                           else_clause = node.else_clause)
        return switch_node

    def visit_CondExprNode(self, node):
        """Replace 'a if <switchable test> else b' by a switch statement
        that assigns either value to a temporary result.
        """
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node
        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            node.true_val, node.false_val)

    def visit_BoolBinopNode(self, node):
        """Replace a switchable and/or chain of comparisons by a switch
        statement that evaluates to True or False.
        """
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def visit_PrimaryCmpNode(self, node):
        """Replace a switchable comparison that yields at least two case
        values by a switch statement that evaluates to True or False.
        """
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """Build an expression node that switches on 'common_var' and
        results in 'true_val' if one of the 'conditions' values matches
        and in 'false_val' otherwise (the other way round for 'not_in').
        """
        result_ref = UtilNodes.ResultRefNode(node)
        true_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = true_val,
            first = True)
        false_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = false_val,
            first = True)

        if not_in:
            # a negated test matches in the default branch of the switch
            true_body, false_body = false_body, true_body

        cases = [Nodes.SwitchCaseNode(pos = node.pos,
                                      conditions = conditions,
                                      body = true_body)]

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
                                           test = common_var,
                                           cases = cases,
                                           else_clause = false_body)
        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
983

984

985
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Replace 'x in (a, b, ...)' by 'x == a or x == b or ...' and
        'x not in (a, b, ...)' by 'x != a and x != b and ...'.

        Only non-cascaded comparisons against a non-empty tuple, list or
        set literal are transformed, all other nodes are returned as is.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # Do not fold this into a constant: the left operand still
            # has to be evaluated as it may have side effects.
            return node

        # the left operand is evaluated once and shared by all comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                is_simple_arg = arg.is_simple()
            except Exception:
                is_simple_arg = False
            if not is_simple_arg:
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap in reverse order so that the first temp ends up outermost
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1054 1055


1056 1057 1058 1059 1060 1061
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Disables managed references on the operands if both sides of
        the assignment use exactly the same names.  The node itself is
        always returned unchanged.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        coerced_temps = []

        for assignment in node.stats:
            if not isinstance(assignment, Nodes.SingleAssignmentNode):
                # FIXME: this includes cascaded assignments
                return node
            if not (self._extract_operand(assignment.lhs, lhs_names,
                                          lhs_indices, coerced_temps)
                    and self._extract_operand(assignment.rhs, rhs_names,
                                              rhs_indices, coerced_temps)):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set([ path for path, _ in lhs_names ])
            rhs_paths = set([ path for path, _ in rhs_names ])
            if lhs_paths != rhs_paths:
                return node
            if len(lhs_paths) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids, rhs_ids = [], []
            for index_ids, index_nodes in ((lhs_ids, lhs_indices),
                                           (rhs_ids, rhs_indices)):
                for index_node in index_nodes:
                    index_id = self._extract_index_id(index_node)
                    if not index_id:
                        return node
                    index_ids.append(index_id)

            if set(lhs_ids) != set(rhs_ids):
                return node
            if len(set(lhs_ids)) != len(rhs_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [ temp.arg for temp in coerced_temps ]
        for temp in coerced_temps:
            temp.use_managed_ref = False

        for _, operand in lhs_names + rhs_names:
            # operands below a temp keep their managed reference
            if operand not in temp_args:
                operand.use_managed_ref = False

        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand.

        Appends a (dotted name path, node) pair to 'names' for a plain
        name or a chain of non-Python attributes on a name, or appends
        the node to 'indices' for an integer index into a list typed
        name.  A CoerceToTempNode wrapper is recorded in 'temps' and
        then skipped.  Returns False for anything else, including
        operands that are not Python objects.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down an attribute chain 'a.b.c' to its base object
        path = []
        base = node
        while isinstance(base, ExprNodes.AttributeNode):
            if base.is_py_attr:
                return False
            path.append(base.member)
            base = base.obj

        if isinstance(base, ExprNodes.NameNode):
            path.append(base.name)
            path.reverse()
            names.append( ('.'.join(path), node) )
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if (node.base.type != Builtin.list_type
                or not node.index.type.is_int
                or not isinstance(node.base, ExprNodes.NameNode)):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return a (base name, index name) pair that identifies an
        index node, or None if its index is not a plain name.
        """
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices are not supported
            return None
        return (index_node.base.name, index.name)

1171

1172 1173 1174 1175 1176 1177 1178
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
    """
1184 1185
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1186

1187 1188 1189 1190 1191 1192 1193
    def visit_SimpleCallNode(self, node):
        """Visit the children of a simple call, then pass calls to
        builtin names on to the handler dispatch.
        """
        self.visitchildren(node)
        callee = node.function
        if self._function_is_builtin_name(callee):
            return self._dispatch_to_handler(node, callee, node.args)
        return node

1194
    def visit_GeneralCallNode(self, node):
        """Visit the children of a general call, then pass calls to
        builtin names on to the handler dispatch, together with the
        keyword arguments.  Calls whose positional arguments are not a
        literal tuple are left alone.
        """
        self.visitchildren(node)
        callee = node.function
        if not self._function_is_builtin_name(callee):
            return node
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional.args, node.keyword_args)
        return node
1205

1206 1207 1208
    def _function_is_builtin_name(self, function):
        if not function.is_name:
            return False
1209 1210 1211
        env = self.current_env()
        entry = env.lookup(function.name)
        if entry is not env.builtin_scope().lookup_here(function.name):
1212
            return False
1213
        # if entry is None, it's at least an undeclared name, so likely builtin
1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the function of a call node by a PythonCapiFunctionNode
        that keeps the position and name of the original function.
        """
        original = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            original.pos, original.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with a wrong number of arguments at the position
        of the node.  'expected' is the expected argument count, either
        as a number or as a descriptive string, or None if unknown.
        """
        pos = node.pos

        # sketch of the expected signature: '', 'x' or '...'
        arg_str = ''
        if expected: # not None or 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(pos, message)

    # specific handlers for simple call nodes

1252 1253 1254 1255 1256 1257 1258
    def _handle_simple_function_float(self, node, pos_args):
        """Replace 'float()' by the literal 0.0 and report calls with
        more than one argument.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if arg_count > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

1259 1260 1261
    class YieldNodeCollector(Visitor.TreeVisitor):
        """Collects yield expression nodes and maps them to the
        expression statements that contain them.
        """
        def __init__(self):
            Visitor.TreeVisitor.__init__(self)
            self.yield_nodes = []
            self.yield_stat_nodes = {}

        visit_Node = Visitor.TreeVisitor.visitchildren

        # XXX: inlining is disabled while it's not supported again.
        # NOTE(review): the '__' prefix mangles the names of the methods
        # below, which presumably hides them from the visitor dispatch.
        def __visit_YieldExprNode(self, node):
            self.yield_nodes.append(node)
            self.visitchildren(node)

        def __visit_ExprStatNode(self, node):
            self.visitchildren(node)
            expr = node.expr
            if expr in self.yield_nodes:
                self.yield_stat_nodes[expr] = node

        def __visit_GeneratorExpressionNode(self, node):
            # enable when we support generic generator expressions
            #
            # everything below this node is out of scope
            pass

1282
    def _find_single_yield_expression(self, node):
        """Return the pair (yielded expression, yield statement node) if
        exactly one yield expression was collected below the node and it
        has a known statement node, otherwise (None, None).
        """
        collector = self.YieldNodeCollector()
        collector.visitchildren(node)
        found = collector.yield_nodes
        if len(found) != 1:
            return None, None
        yield_node = found[0]
        try:
            yield_stat_node = collector.yield_stat_nodes[yield_node]
        except KeyError:
            return None, None
        return (yield_node.arg, yield_stat_node)
1292

1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337
    def _handle_simple_function_all(self, node, pos_args):
        """Transform

        _result = all(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not x:
                    _result = False
                    break
            else:
                continue
            break
        else:
            _result = True
        """
        return self._transform_any_all(node, pos_args, False)

    def _handle_simple_function_any(self, node, pos_args):
        """Transform

        _result = any(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if x:
                    _result = True
                    break
            else:
                continue
            break
        else:
            _result = False
        """
        return self._transform_any_all(node, pos_args, True)

    def _transform_any_all(self, node, pos_args, is_any):
        """Inline any(genexpr) (is_any true) or all(genexpr) as a loop
        that breaks out at the first deciding item.  Returns the call
        node unchanged if the only argument is not a generator
        expression with a single yield expression.
        """
        if len(pos_args) != 1:
            return node
        gen_expr_node = pos_args[0]
        if not isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            return node
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        item_pos = yield_expression.pos
        if is_any:
            func_name = 'any'
            condition = yield_expression
        else:
            func_name = 'all'
            condition = ExprNodes.NotNode(item_pos, operand = yield_expression)

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)

        # "if <condition>: _result = <is_any>; break"
        found_stats = [
            Nodes.SingleAssignmentNode(
                node.pos,
                lhs = result_ref,
                rhs = ExprNodes.BoolNode(item_pos, value = is_any,
                                         constant_result = is_any)),
            Nodes.BreakStatNode(node.pos),
            ]
        found_clause = Nodes.IfClauseNode(
            item_pos,
            condition = condition,
            body = Nodes.StatListNode(node.pos, stats = found_stats))
        test_node = Nodes.IfStatNode(
            item_pos,
            else_clause = None,
            if_clauses = [ found_clause ])

        # make a break in an inner loop also leave the outer loops
        outer_loop = loop_node
        while isinstance(outer_loop.body, Nodes.LoopNode):
            inner_loop = outer_loop.body
            outer_loop.body = Nodes.StatListNode(inner_loop.pos, stats = [
                inner_loop,
                Nodes.BreakStatNode(item_pos)
                ])
            inner_loop.else_clause = Nodes.ContinueStatNode(item_pos)
            outer_loop = inner_loop

        # result if the loops run to their end without a break
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = ExprNodes.BoolNode(item_pos, value = not is_any,
                                     constant_result = not is_any))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = func_name)
1387

1388
    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) and sorted([listcomp]) into
        [listcomp].sort().  CPython just reads the iterable into a
        list and calls .sort() on it.  Expanding the iterable in a
        listcomp is still faster and the result can be sorted in
        place.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]

        if isinstance(arg, ExprNodes.ComprehensionNode) \
               and arg.target.type is Builtin.list_type:
            listcomp_node = arg
            loop_node = listcomp_node.loop
        elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
            loop_node = arg.loop
            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
            if yield_expression is None:
                return node

            # turn the generator expression into a list comprehension
            # by appending to a list instead of yielding
            target = ExprNodes.ListNode(node.pos, args = [])
            append_node = ExprNodes.ComprehensionAppendNode(
                yield_expression.pos, expr = yield_expression,
                target = ExprNodes.CloneNode(target))

            Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

            listcomp_node = ExprNodes.ComprehensionNode(
                arg.pos, loop = loop_node, target = target,
                append = append_node, type = Builtin.list_type,
                expr_scope = arg.expr_scope,
                has_local_scope = True)
        else:
            return node

        result_node = UtilNodes.ResultRefNode(
            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)

        # "_result = [listcomp]"
        listcomp_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs = result_node, rhs = listcomp_node, first = True)

        # "_result.sort()"
        sort_method = ExprNodes.AttributeNode(
            node.pos, obj = result_node, attribute = EncodedString('sort'),
            # entry ? type ?
            needs_none_check = False)
        sort_call = ExprNodes.SimpleCallNode(
            node.pos, function = sort_method, args = [])
        sort_node = Nodes.ExprStatNode(node.pos, expr = sort_call)

        sort_node.analyse_declarations(self.current_env())

        stats = [ listcomp_assign_node, sort_node ]
        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats = stats))

1442
    def _handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.

        List comprehensions are only handled if they collect an integer
        literal.  An optional second argument is used as start value
        instead of 0.
        """
        if len(pos_args) not in (1,2):
            return node
        gen_expr_node = pos_args[0]
        if not isinstance(gen_expr_node, (ExprNodes.GeneratorExpressionNode,
                                          ExprNodes.ComprehensionNode)):
            return node
        loop_node = gen_expr_node.loop

        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
            if yield_expression is None:
                return node
        else: # ComprehensionNode
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                is_int_literal = yield_expression.is_literal and yield_expression.type.is_int
            except AttributeError:
                return node # in case we don't have a type yet
            if not is_int_literal:
                return node

        if len(pos_args) == 2:
            start = pos_args[1]
        else:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)

        # "_result = _result + item" replaces the yield/append statement
        accumulated = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = accumulated)

        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)

        # "_result = start", followed by the loop
        init_node = Nodes.SingleAssignmentNode(
            start.pos,
            lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
            rhs = start,
            first = True)
        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [ init_node, loop_node ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)
1497

1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510
    def _handle_simple_function_min(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '<')

    def _handle_simple_function_max(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '>')

    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
        """
        if len(args) <= 1:
            # leave this to Python
            return node

1511
        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533

        last_result = args[0]
        for arg_node in cascaded_nodes:
            result_ref = UtilNodes.ResultRefNode(last_result)
            last_result = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val = arg_node,
                false_val = result_ref,
                test = ExprNodes.PrimaryCmpNode(
                    arg_node.pos,
                    operand1 = arg_node,
                    operator = operator,
                    operand2 = result_ref,
                    )
                )
            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

        for ref_node in cascaded_nodes[::-1]:
            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

        return last_result

1534
    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple() by an empty tuple literal and tuple(genexpr)
        by a list comprehension that is converted into a tuple.
        """
        if not pos_args:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        as_list = self._transform_list_set_genexpr(
            node, pos_args, ExprNodes.ListNode)
        if as_list is node:
            # nothing was transformed
            return node
        return ExprNodes.AsTupleNode(node.pos, arg=as_list)

1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568
    def _handle_simple_function_list(self, node, pos_args):
        """Replace list() by an empty list literal and list(genexpr)
        by a list comprehension."""
        if pos_args:
            return self._transform_list_set_genexpr(
                node, pos_args, ExprNodes.ListNode)
        return ExprNodes.ListNode(node.pos, args=[], constant_result=[])

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set() by an empty set node and set(genexpr) by a
        set comprehension."""
        if pos_args:
            return self._transform_list_set_genexpr(
                node, pos_args, ExprNodes.SetNode)
        return ExprNodes.SetNode(node.pos, args=[], constant_result=set())

    def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
        """Replace set(genexpr) and list(genexpr) by a literal comprehension.
        """
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1569 1570
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1571 1572 1573 1574
            return node

        target_node = container_node_class(node.pos, args=[])
        append_node = ExprNodes.ComprehensionAppendNode(
1575
            yield_expression.pos,
1576
            expr = yield_expression,
Stefan Behnel's avatar
Stefan Behnel committed
1577
            target = ExprNodes.CloneNode(target_node))
1578

1579
        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602

        setcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            target = target_node)
        append_node.target = setcomp
        return setcomp

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1603 1604
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1605 1606 1607 1608 1609 1610 1611 1612 1613
            return node

        if not isinstance(yield_expression, ExprNodes.TupleNode):
            return node
        if len(yield_expression.args) != 2:
            return node

        target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
        append_node = ExprNodes.DictComprehensionAppendNode(
1614
            yield_expression.pos,
1615 1616
            key_expr = yield_expression.args[0],
            value_expr = yield_expression.args[1],
Stefan Behnel's avatar
Stefan Behnel committed
1617
            target = ExprNodes.CloneNode(target_node))
1618

1619
        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630

        dictcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            target = target_node)
        append_node.target = dictcomp
        return dictcomp

1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643
    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        return kwargs


1644
class OptimizeBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common method calls and instantiation patterns
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
    """
1652 1653
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1654

1655
    def visit_GeneralCallNode(self, node):
        """Dispatch a call with keyword arguments to a matching handler.

        The call is left alone unless the callee is a Python object,
        the positional arguments are a literal tuple and the keyword
        arguments (if any) are a literal dict.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        keyword_args = node.keyword_args
        if keyword_args:
            if not isinstance(keyword_args, ExprNodes.DictNode):
                # can't handle **kwargs
                return node
        return self._dispatch_to_handler(
            node, function, arg_tuple.args, keyword_args)
1670 1671 1672

    def visit_SimpleCallNode(self, node):
        """Dispatch a positional-only call to a matching handler.

        For Python callees the arguments are taken from the literal
        argument tuple (the call is left alone if there is none); for
        other callees they are taken from node.args.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            args = node.args
        else:
            arg_tuple = node.arg_tuple
            if not isinstance(arg_tuple, ExprNodes.TupleNode):
                return node
            args = arg_tuple.args
        return self._dispatch_to_handler(node, function, args)
1683

1684 1685
    ### cleanup to avoid redundant coercions to/from Python types

1686 1687 1688
    def _visit_PyTypeTestNode(self, node):
        # disabled - appears to break assignments in some cases, and
        # also drops a None check, which might still be required
1689 1690 1691 1692 1693 1694 1695 1696
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

1697 1698 1699
    def _visit_TypecastNode(self, node):
        # disabled - the user may have had a reason to put a type
        # cast, even if it looks redundant to Cython
1700 1701 1702 1703 1704 1705 1706 1707
        """
        Drop redundant type casts.
        """
        self.visitchildren(node)
        if node.type == node.operand.type:
            return node.operand
        return node

1708 1709 1710 1711 1712 1713 1714 1715 1716
    def visit_ExprStatNode(self, node):
        """Drop useless coercions.

        If the statement's expression is a coercion to a Python type,
        the coercion is stripped and its argument used directly.
        """
        self.visitchildren(node)
        expr = node.expr
        if isinstance(expr, ExprNodes.CoerceToPyTypeNode):
            node.expr = expr.arg
        return node

1717 1718 1719 1720 1721
    def visit_CoerceToBooleanNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        A value that was coerced to a generic Python object or a
        Python bool is coerced to a boolean directly instead.
        """
        self.visitchildren(node)
        inner = node.arg
        if isinstance(inner, ExprNodes.PyTypeTestNode):
            # look through an intermediate type test
            inner = inner.arg
        if (isinstance(inner, ExprNodes.CoerceToPyTypeNode) and
                inner.type in (PyrexTypes.py_object_type, Builtin.bool_type)):
            return inner.arg.coerce_to_boolean(self.current_env())
        return node

1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway, and hand integer-indexed item access to
        _optimise_int_indexing().
        """
        self.visitchildren(node)
        arg = node.arg
        target_type = node.type
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if target_type == arg.type:
                return arg
            return arg.coerce_to(target_type, self.current_env())

        if isinstance(arg, ExprNodes.PyTypeTestNode):
            # look through an intermediate type test
            arg = arg.arg

        if (isinstance(arg, ExprNodes.CoerceToPyTypeNode)
                and arg.type is PyrexTypes.py_object_type
                and node.type.assignable_from(arg.arg.type)):
            # completely redundant C->Py->C coercion
            return arg.arg.coerce_to(node.type, self.current_env())

        if isinstance(arg, ExprNodes.SimpleCallNode):
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
            return node

        if isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
            index_node = arg.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                # use the index value from before its Python coercion
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node

1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773
    # C signature used for the "__Pyx_PyBytes_GetItemInt" call generated
    # in _optimise_int_indexing():
    #   char f(bytes, Py_ssize_t index, int check_bounds)
    # with "((char)-1)" as exception value and the exception check enabled.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value = "((char)-1)",
        exception_check = True)

    def _optimise_int_indexing(self, coerce_node, arg, index_node):
        """Replace "<char>bytes_obj[int_index]" by a C helper call.

        Applies only when the indexed base is typed as 'bytes' and the
        coercion target is (unsigned) char; otherwise 'coerce_node' is
        returned unchanged.  The 'boundscheck' directive is passed on
        to the helper as a 0/1 flag.
        """
        env = self.current_env()
        if env.directives['boundscheck']:
            bound_check_bool = 1
        else:
            bound_check_bool = 0

        if arg.base.type is not Builtin.bytes_type:
            return coerce_node
        char_types = (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type)
        if coerce_node.type not in char_types:
            return coerce_node

        # bytes[index] -> char
        bound_check_node = ExprNodes.IntNode(
            coerce_node.pos,
            value=str(bound_check_bool),
            constant_result=bound_check_bool)
        bytes_node = arg.base.as_none_safe_node(
            "'NoneType' object is not subscriptable")
        c_index_node = index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
        node = ExprNodes.PythonCapiCallNode(
            coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
            self.PyBytes_GetItemInt_func_type,
            args=[bytes_node, c_index_node, bound_check_node],
            is_temp=True,
            utility_code=bytes_index_utility_code)
        if coerce_node.type is not PyrexTypes.c_char_type:
            # the helper is declared to return plain 'char'
            node = node.coerce_to(coerce_node.type, env)
        return node

Stefan Behnel's avatar
Stefan Behnel committed
1795
    def _optimise_numeric_cast_call(self, node, arg):
        """Replace "<C type>int(x)" / "<C type>float(x)" by a C cast.

        'node' is the coercion from a Python object, 'arg' the call it
        wraps.  Only single-argument calls to the builtin names 'int'
        and 'float' on non-Python values are considered; anything else
        returns 'node' unchanged.
        """
        function = arg.function
        if not isinstance(function, ExprNodes.NameNode):
            return node
        if not function.type.is_builtin_type:
            return node
        if not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
            return node
        call_args = arg.arg_tuple.args
        if len(call_args) != 1:
            return node

        func_arg = call_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node

        if function.name == 'int':
            is_candidate = func_arg.type.is_int or node.type.is_int
        elif function.name == 'float':
            is_candidate = func_arg.type.is_float or node.type.is_float
        else:
            is_candidate = False

        if is_candidate:
            if func_arg.type == node.type:
                # already of the requested type, no cast needed
                return func_arg
            if node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                return ExprNodes.TypecastNode(
                    node.pos, operand=func_arg, type=node.type)
        return node

    ### dispatch to specific optimisers

1828 1829 1830 1831 1832 1833 1834
    def _find_handler(self, match_name, has_kwargs):
        call_type = has_kwargs and 'general' or 'simple'
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

1835
    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1836
        if function.is_name:
1837 1838 1839
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
1840 1841
            if not function.entry:
                return node
1842 1843
            is_builtin = function.entry.is_builtin or \
                         function.entry is self.current_env().builtin_scope().lookup_here(function.name)
1844 1845
            if not is_builtin:
                return node
1846 1847 1848 1849 1850
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
1851
                return function_handler(node, arg_list, kwargs)
1852
            else:
1853 1854
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
Stefan Behnel's avatar
Stefan Behnel committed
1855
            attr_name = function.attribute
1856 1857
            self_arg = function.obj
            obj_type = self_arg.type
1858
            is_unbound_method = False
1859 1860 1861 1862 1863 1864 1865
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                         arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    type_name = function.obj.name
                    self_arg = None
1866
                    is_unbound_method = True
1867 1868 1869
                else:
                    type_name = obj_type.name
            else:
1870
                type_name = "object" # safety measure
1871
            method_handler = self._find_handler(
Stefan Behnel's avatar
Stefan Behnel committed
1872
                "method_%s_%s" % (type_name, attr_name), kwargs)
1873
            if method_handler is None:
Stefan Behnel's avatar
Stefan Behnel committed
1874 1875 1876 1877
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
1878 1879
                if method_handler is None:
                    return node
1880 1881 1882
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
1883
                return method_handler(node, arg_list, kwargs, is_unbound_method)
1884
            else:
1885
                return method_handler(node, arg_list, is_unbound_method)
1886
        else:
1887
            return node
1888

1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with the wrong number of arguments at node.pos.

        'expected' may be None, a number, or a descriptive string such
        as '0 or 1'; it selects the argument placeholder shown in the
        message and, unless None, is quoted as the expected count.
        """
        # placeholder for the call signature: f(), f(x) or f(...)
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

1905 1906
    ### builtin types

1907 1908 1909 1910 1911 1912
    # C signature used for the "PyDict_Copy" call generated for
    # dict(some_dict): takes a dict, returns a dict.
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, pos_args):
1913
        """Replace dict(some_dict) by PyDict_Copy(some_dict).
1914
        """
1915
        if len(pos_args) != 1:
1916
            return node
1917
        arg = pos_args[0]
1918
        if arg.type is Builtin.dict_type:
1919
            arg = arg.as_none_safe_node("'NoneType' is not iterable")
1920 1921
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1922
                args = [arg],
1923 1924 1925
                is_temp = node.is_temp
                )
        return node
1926

1927 1928 1929 1930 1931
    # C signature used for the "PyList_AsTuple" call generated for
    # tuple(some_list): takes a list, returns a tuple.
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

1932
    def _handle_simple_function_tuple(self, node, pos_args):
1933 1934
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
1935
        if len(pos_args) != 1:
1936
            return node
1937
        list_arg = pos_args[0]
1938 1939 1940 1941
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
1942
            pos_args[0] = list_arg.as_none_safe_node(
1943
                "'NoneType' object is not iterable")
1944

1945 1946
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1947
            args = pos_args,
1948 1949 1950
            is_temp = node.is_temp
            )

1951 1952 1953 1954 1955 1956 1957
    # C signature used for the "__Pyx_PyObject_AsDouble" call generated
    # in _handle_simple_function_float(): takes a Python object and
    # returns a C double, with "((double)-1)" as exception value and the
    # exception check enabled.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)

1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978
    def _handle_simple_function_set(self, node, pos_args):
        if len(pos_args) == 1 and isinstance(pos_args[0], (ExprNodes.ListNode,
                                                           ExprNodes.TupleNode)):
            # We can optimise set([x,y,z]) safely into a set literal,
            # but only if we create all items before adding them -
            # adding an item may raise an exception if it is not
            # hashable, but creating the later items may have
            # side-effects.
            args = []
            temps = []
            for arg in pos_args[0].args:
                if not arg.is_simple():
                    arg = UtilNodes.LetRefNode(arg)
                    temps.append(arg)
                args.append(arg)
            result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
            for temp in temps[::-1]:
                result = UtilNodes.EvalWithTempExprNode(temp, result)
            return result
        return node

1979
    def _handle_simple_function_float(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
1980 1981 1982
        """Transform float() into either a C type cast or a faster C
        function call.
        """
1983 1984
        # Note: this requires the float() function to be typed as
        # returning a C 'double'
1985
        if len(pos_args) == 0:
Stefan Behnel's avatar
typo  
Stefan Behnel committed
1986
            return ExprNodes.FloatNode(
1987 1988 1989 1990
                node, value="0.0", constant_result=0.0
                ).coerce_to(Builtin.float_type, self.current_env())
        elif len(pos_args) != 1:
            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1991 1992 1993 1994 1995 1996 1997
            return node
        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        if func_arg.type is PyrexTypes.c_double_type:
            return func_arg
        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1998 1999
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
2000 2001 2002 2003 2004 2005 2006 2007
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_AsDouble",
            self.PyObject_AsDouble_func_type,
            args = pos_args,
            is_temp = node.is_temp,
            utility_code = pyobject_as_double_utility_code,
            py_name = "float")

2008 2009 2010
    def _handle_simple_function_bool(self, node, pos_args):
        """Transform bool(x) into a type coercion to a boolean.
        """
2011 2012 2013 2014 2015 2016
        if len(pos_args) == 0:
            return ExprNodes.BoolNode(
                node.pos, value=False, constant_result=False
                ).coerce_to(Builtin.bool_type, self.current_env())
        elif len(pos_args) != 1:
            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
2017
            return node
Craig Citro's avatar
Craig Citro committed
2018
        else:
2019 2020 2021 2022 2023 2024
            # => !!<bint>(x)  to make sure it's exactly 0 or 1
            operand = pos_args[0].coerce_to_boolean(self.current_env())
            operand = ExprNodes.NotNode(node.pos, operand = operand)
            operand = ExprNodes.NotNode(node.pos, operand = operand)
            # coerce back to Python object as that's the result we are expecting
            return operand.coerce_to_pyobject(self.current_env())
2025

2026 2027
    ### builtin functions

2028 2029 2030 2031 2032
    # C signature used for the "strlen" call generated for len(char*):
    # takes a char pointer, returns size_t.
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

2033 2034 2035 2036 2037 2038
    # C signature shared by the size functions/macros that len() is
    # mapped to for builtin types: takes a Python object, returns
    # Py_ssize_t.
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ])

    # Lookup function (the bound .get of a dict) from a builtin type to
    # the name of the C function or macro that returns its length;
    # returns None for types that are not listed.
    _map_to_capi_len_function = {
        Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH",
        Builtin.bytes_type: "PyBytes_GET_SIZE",
        Builtin.list_type: "PyList_GET_SIZE",
        Builtin.tuple_type: "PyTuple_GET_SIZE",
        Builtin.dict_type: "PyDict_Size",
        Builtin.set_type: "PySet_Size",
        Builtin.frozenset_type: "PySet_Size",
    }.get

2048
    def _handle_simple_function_len(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
2049 2050
        """Replace len(char*) by the equivalent call to strlen() and
        len(known_builtin_type) by an equivalent C-API call.
Stefan Behnel's avatar
Stefan Behnel committed
2051
        """
2052 2053 2054 2055 2056 2057
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            arg = arg.arg
2058 2059 2060 2061 2062
        if arg.type.is_string:
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, "strlen", self.Pyx_strlen_func_type,
                args = [arg],
                is_temp = node.is_temp,
2063
                utility_code = Builtin.include_string_h_utility_code)
2064 2065 2066 2067
        elif arg.type.is_pyobject:
            cfunc_name = self._map_to_capi_len_function(arg.type)
            if cfunc_name is None:
                return node
2068 2069
            arg = arg.as_none_safe_node(
                "object of type 'NoneType' has no len()")
2070 2071 2072 2073
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, cfunc_name, self.PyObject_Size_func_type,
                args = [arg],
                is_temp = node.is_temp)
Stefan Behnel's avatar
Stefan Behnel committed
2074
        elif arg.type.is_unicode_char:
2075 2076
            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                     type=node.type)
2077
        else:
Stefan Behnel's avatar
Stefan Behnel committed
2078
            return node
2079
        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
2080
            new_node = new_node.coerce_to(node.type, self.current_env())
2081
        return new_node
2082

2083 2084 2085 2086 2087 2088
    # C signature used for the "Py_TYPE" call generated for type(o):
    # takes a Python object, returns a type object.
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_type(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
2089 2090
        """Replace type(o) by a macro call to Py_TYPE(o).
        """
2091
        if len(pos_args) != 1:
2092 2093
            return node
        node = ExprNodes.PythonCapiCallNode(
2094 2095 2096 2097
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
2098

2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123
    # C signature used for the type check calls generated in
    # _handle_simple_function_isinstance(): declared as taking one
    # Python object and returning a C boolean (bint).
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_isinstance(self, node, pos_args):
        """Replace isinstance() checks against builtin types by the
        corresponding C-API call.

        'node' is the original call node, 'pos_args' its positional
        arguments.  The second argument must be a type or a tuple of
        types.  Each type is mapped to its specific type check function
        or to a generic __Pyx_TypeCheck() call, and the resulting tests
        are joined with 'or'.  The unchanged 'node' is returned whenever
        the call cannot be optimised.
        """
        if len(pos_args) != 2:
            return node
        arg, types = pos_args
        temp = None
        if isinstance(types, ExprNodes.TupleNode):
            types = types.args
            if not types:
                # isinstance(x, ()) yields no test at all; reduce() below
                # would raise on the empty list, so leave the call alone
                return node
            # the tested object is an argument of every generated test,
            # so all tests share one result reference
            arg = temp = UtilNodes.ResultRefNode(arg)
        elif types.type is Builtin.type_type:
            types = [types]
        else:
            return node

        tests = []       # names of the type check functions already emitted
        test_nodes = []  # the generated C-API call nodes
        env = self.current_env()
        for test_type_node in types:
            builtin_type = None
            if isinstance(test_type_node, ExprNodes.NameNode):
                if test_type_node.entry:
                    entry = env.lookup(test_type_node.entry.name)
                    if entry and entry.type and entry.type.is_builtin_type:
                        builtin_type = entry.type
            if builtin_type and builtin_type is not Builtin.type_type:
                type_check_function = builtin_type.type_check_function(exact=False)
                if type_check_function in tests:
                    # same check was already generated for an earlier type
                    continue
                tests.append(type_check_function)
                type_check_args = [arg]
            elif test_type_node.type is Builtin.type_type:
                type_check_function = '__Pyx_TypeCheck'
                type_check_args = [arg, test_type_node]
            else:
                return node
            test_nodes.append(
                ExprNodes.PythonCapiCallNode(
                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                    args = type_check_args,
                    is_temp = True,
                    ))

        def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
            or_node = make_binop_node(node.pos, 'or', a, b)
            or_node.type = PyrexTypes.c_bint_type
            or_node.is_temp = True
            return or_node

        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
        if temp is not None:
            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
        return test_node

2159
    def _handle_simple_function_ord(self, node, pos_args):
        """Unpack ord(Py_UNICODE) and ord('X').

        A unicode character that was coerced to a Python object is
        cast to a C int instead.  A single character unicode literal,
        or a single character string literal with a code of at most
        255, is replaced by an integer constant.  Any other call is
        returned unchanged.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        char = None
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            if arg.arg.type.is_unicode_char:
                return ExprNodes.TypecastNode(
                    arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type
                    ).coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.UnicodeNode):
            if len(arg.value) == 1:
                char = arg.value
        elif isinstance(arg, ExprNodes.StringNode):
            if arg.unicode_value and len(arg.unicode_value) == 1 \
                   and ord(arg.unicode_value) <= 255: # Py2/3 portability
                char = arg.unicode_value
        if char is None:
            return node

        char_code = ord(char)
        # the first argument of a node constructor is the source position,
        # not the character code
        return ExprNodes.IntNode(
            arg.pos, type=PyrexTypes.c_int_type,
            value=str(char_code),
            constant_result=char_code
            ).coerce_to(node.type, self.current_env())

2187 2188
    ### special methods

2189 2190 2191 2192 2193
    # C signature used for the __Pyx_tp_new() call:
    # one type object argument, Python object result.
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)])

Stefan Behnel's avatar
Stefan Behnel committed
2194
    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

        Only the unbound call with exactly one argument is handled, and
        only if both the object and the argument are names known to be
        types that refer to the same type.  Everything else is returned
        unchanged.
        """
        owner = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        new_type = args[0]
        if not (owner.is_name and new_type.is_name):
            # play safe
            return node
        type_type = Builtin.type_type
        if owner.type != type_type or new_type.type != type_type:
            # not a known type, play safe
            return node
        if new_type.type_entry and owner.type_entry:
            if new_type.type_entry != owner.type_entry:
                # different types - may or may not lead to an error at runtime
                return node
        elif owner.name != new_type.name:
            return node
        # at this point, both are types and both name the same type

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not new_type.type_entry:
            # arbitrary variable, needs a None check for safety
            new_type = new_type.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [new_type],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )

2233 2234
    ### methods of builtin types

2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262
    # dict.clear() without a result: a void C function taking the dict
    PyDict_Clear_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_void_type,
        [PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)])

    # dict.clear() with a result: same argument, Python object result
    PyDict_Clear_Retval_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)])

    def _handle_simple_method_dict_clear(self, node, args, is_unbound_method):
        """Optimise dict.clear() differently, depending on the use (or
        non-use) of the return value.

        If the result of the call is used, __Pyx_PyDict_Clear() is
        called, which is declared to return a Python object.  Otherwise
        the void C-API function PyDict_Clear() is called directly.
        Calls with anything but exactly one argument (the dict itself)
        are returned unchanged.
        """
        if len(args) != 1:
            return node
        if node.result_is_used:
            # coerce_to() needs the environment itself, not the bound method
            return self._substitute_method_call(
                node, "__Pyx_PyDict_Clear", self.PyDict_Clear_Retval_func_type,
                'clear', is_unbound_method, args,
                may_return_none=True, is_temp=True,
                utility_code=py_dict_clear_utility_code
                ).coerce_to(node.type, self.current_env())
        else:
            return self._substitute_method_call(
                node, "PyDict_Clear", self.PyDict_Clear_func_type,
                'clear', is_unbound_method, args, is_temp=False)

2263 2264 2265 2266 2267 2268
    # C signature used for the __Pyx_PyObject_Append() call
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ])

2269
    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
        """Optimistically map X.append(item) to __Pyx_PyObject_Append(),
        as the receiver of an append() call is almost always a list.

        Calls that do not have exactly two arguments (object and item)
        are returned unchanged.
        """
        if len(args) == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
                args = args,
                may_return_none = True,
                is_temp = node.is_temp,
                utility_code = append_utility_code
                )
        return node

Robert Bradshaw's avatar
Robert Bradshaw committed
2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295
    # C signature used for the __Pyx_PyObject_Pop() call (no index)
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)])

    # C signature used for the __Pyx_PyObject_PopIndex() call (C integer index)
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
        ])

    def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
        """Optimistically map X.pop() and X.pop(n) to C helper calls,
        as the receiver of a pop() call is almost always a list.

        X.pop() becomes __Pyx_PyObject_Pop().  X.pop(n) becomes
        __Pyx_PyObject_PopIndex(), but only if the index is a C integer
        coerced to a Python object whose type is not wider than
        Py_ssize_t.  All other calls are returned unchanged.
        """
        arg_count = len(args)
        if arg_count == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
                args = args,
                may_return_none = True,
                is_temp = node.is_temp,
                utility_code = pop_utility_code
                )
        if arg_count != 2:
            return node

        index = args[1]
        if not (isinstance(index, ExprNodes.CoerceToPyTypeNode) and index.arg.type.is_int):
            return node
        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        if PyrexTypes.widest_numeric_type(index.arg.type, ssize_t_type) == ssize_t_type:
            # pass the plain C integer instead of its Python object wrapper
            args[1] = index.arg
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
                args = args,
                may_return_none = True,
                is_temp = node.is_temp,
                utility_code = pop_index_utility_code
                )
        return node

    # list.pop() is handled by the same optimisation
    _handle_simple_method_list_pop = _handle_simple_method_object_pop

2324
    # C signature with a single object argument and a return code
    # result; the exception value is declared as -1
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value = "-1")
2329

2330
    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
        """Call PyList_Sort() instead of the 0-argument l.sort().

        Calls that pass anything besides the list itself are returned
        unchanged.  The result of the C call is coerced to the type of
        the original call node.
        """
        if len(args) != 1:
            return node
        # coerce_to() needs the environment itself, not the bound method
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env())
2338

2339 2340 2341 2342 2343
    # C signature used for the __Pyx_PyDict_GetItemDefault() call
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])
2345 2346

    def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
        """Replace dict.get() by a call to __Pyx_PyDict_GetItemDefault().

        A missing default argument is filled in with None.  A wrong
        number of arguments is reported and the call is left unchanged.
        """
        arg_count = len(args)
        if arg_count == 2:
            # no default given => d.get(key) behaves like d.get(key, None)
            args.append(ExprNodes.NoneNode(node.pos))
        elif arg_count != 3:
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node

        return self._substitute_method_call(
            node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
            'get', is_unbound_method, args,
            may_return_none = True,
            utility_code = dict_getitem_default_utility_code)

2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382
    # C signature used for the __Pyx_PyDict_SetDefault() call
    Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])

    def _handle_simple_method_dict_setdefault(self, node, args, is_unbound_method):
        """Replace dict.setdefault() by a call to __Pyx_PyDict_SetDefault().

        A missing default argument is filled in with None.  A wrong
        number of arguments is reported and the call is left unchanged.
        """
        arg_count = len(args)
        if arg_count == 2:
            # no default given => None is passed as the default
            args.append(ExprNodes.NoneNode(node.pos))
        elif arg_count != 3:
            self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
            return node

        return self._substitute_method_call(
            node, "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
            'setdefault', is_unbound_method, args,
            may_return_none = True,
            utility_code = dict_setdefault_utility_code)

2383 2384 2385

    ### unicode type methods

2386 2387
    # C signature of the character predicates: one Py_UCS4 argument,
    # C boolean result
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])

    def _inject_unicode_predicate(self, node, args, is_unbound_method):
        """Replace an is*() method call on a single unicode character
        by the corresponding Py_UNICODE_IS*() call.

        Only bound calls without further arguments are handled, and
        only if the receiver is a C unicode character that was coerced
        to a Python object.  istitle() is mapped to the helper function
        __Pyx_Py_UNICODE_ISTITLE().  The result is coerced to a Python
        object if the original call node had an object type.
        """
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
               not ustring.arg.type.is_unicode_char:
            return node
        uchar = ustring.arg
        method_name = node.function.attribute
        if method_name == 'istitle':
            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
            utility_code = py_unicode_istitle_utility_code
            function_name = '__Pyx_Py_UNICODE_ISTITLE'
        else:
            utility_code = None
            function_name = 'Py_UNICODE_%s' % method_name.upper()
        func_call = self._substitute_method_call(
            node, function_name, self.PyUnicode_uchar_predicate_func_type,
            method_name, is_unbound_method, [uchar],
            utility_code = utility_code)
        if node.type.is_pyobject:
            # the coercion needs the environment itself, not the bound method
            func_call = func_call.coerce_to_pyobject(self.current_env())
        return func_call

    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate

    # C signature of the character conversions: Py_UCS4 in, Py_UCS4 out
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])

    def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
        """Replace lower()/upper()/title() on a single unicode character
        by the corresponding Py_UNICODE_TO*() call.

        Only bound calls without further arguments are handled, and
        only if the receiver is a C unicode character that was coerced
        to a Python object.  The result is coerced to a Python object
        if the original call node had an object type.
        """
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
               not ustring.arg.type.is_unicode_char:
            return node
        uchar = ustring.arg
        method_name = node.function.attribute
        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
        func_call = self._substitute_method_call(
            node, function_name, self.PyUnicode_uchar_conversion_func_type,
            method_name, is_unbound_method, [uchar])
        if node.type.is_pyobject:
            # the coercion needs the environment itself, not the bound method
            func_call = func_call.coerce_to_pyobject(self.current_env())
        return func_call

    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion

2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463
    # C signature used for the PyUnicode_Splitlines() call
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
        ])

    def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
        """Map unicode.splitlines(...) to the C-API function
        PyUnicode_Splitlines().

        A wrong number of arguments is reported and the call is left
        unchanged.  The 'keepends' argument (index 1) is handed to
        _inject_bint_default_argument() with a default of False.
        """
        if not 1 <= len(args) <= 2:
            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
            return node

        keepends_index = 1
        self._inject_bint_default_argument(node, args, keepends_index, False)

        return self._substitute_method_call(
            node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
            'splitlines', is_unbound_method, args)

    # C signature used for the PyUnicode_Split() call
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
        ])

    def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
        """Map unicode.split(...) to the C-API function PyUnicode_Split().

        A wrong number of arguments is reported and the call is left
        unchanged.  A missing separator is passed as NULL, and the
        'maxsplit' argument (index 2) is handed to
        _inject_int_default_argument() with a default of -1.
        """
        if not 1 <= len(args) <= 3:
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node
        if len(args) == 1:
            # no separator given
            args.append(ExprNodes.NullNode(node.pos))
        maxsplit_index = 2
        self._inject_int_default_argument(
            node, args, maxsplit_index, PyrexTypes.c_py_ssize_t_type, "-1")

        return self._substitute_method_call(
            node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
            'split', is_unbound_method, args)

2494
    # C signature shared by the __Pyx_Py*_Tailmatch() helpers;
    # the exception value is declared as -1
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [
            PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
        exception_value = '-1')

    def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
        """Optimise unicode.endswith() through the generic tail match
        handler, using direction +1.
        """
        return self._inject_tailmatch(
            node, args, is_unbound_method,
            type_name='unicode', method_name='endswith',
            utility_code=unicode_tailmatch_utility_code, direction=1)
2508 2509

    def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
        """Optimise unicode.startswith() through the generic tail match
        handler, using direction -1.
        """
        return self._inject_tailmatch(
            node, args, is_unbound_method,
            type_name='unicode', method_name='startswith',
            utility_code=unicode_tailmatch_utility_code, direction=-1)
2513

Vitja Makarov's avatar
Vitja Makarov committed
2514
    def _inject_tailmatch(self, node, args, is_unbound_method, type_name,
                          method_name, utility_code, direction):
        """Replace <type>.startswith(...) and <type>.endswith(...) by a
        call to the matching __Pyx_Py<Type>_Tailmatch() helper.

        'type_name' and 'method_name' name the optimised method and
        select the helper function.  'direction' is appended to the
        call as a C int argument.  A wrong number of arguments is
        reported and the call is left unchanged.  The C result is
        coerced to a Python bool.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
            return node

        # 'start' (index 2) and 'end' (index 3) of the C signature
        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default_value in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, ssize_t_type, default_value)

        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        c_function_name = "__Pyx_Py%s_Tailmatch" % type_name.capitalize()
        tailmatch_call = self._substitute_method_call(
            node, c_function_name,
            self.PyString_Tailmatch_func_type,
            method_name, is_unbound_method, args,
            utility_code = utility_code)
        return tailmatch_call.coerce_to(Builtin.bool_type, self.current_env())
2535

2536 2537 2538 2539 2540 2541 2542 2543 2544 2545 2546 2547 2548 2549 2550 2551 2552 2553 2554 2555 2556 2557 2558 2559 2560 2561
    # C signature used for the PyUnicode_Find() call;
    # the exception value is declared as -2
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
        exception_value = '-2')

    def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
        """Optimise unicode.find() through the generic find handler,
        using direction +1.
        """
        return self._inject_unicode_find(
            node, args, is_unbound_method, method_name='find', direction=1)

    def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
        """Optimise unicode.rfind() through the generic find handler,
        using direction -1.
        """
        return self._inject_unicode_find(
            node, args, is_unbound_method, method_name='rfind', direction=-1)

    def _inject_unicode_find(self, node, args, is_unbound_method,
                             method_name, direction):
        """Map unicode.find(...) and unicode.rfind(...) to the C-API
        function PyUnicode_Find().

        'direction' is appended to the call as a C int argument.  A
        wrong number of arguments is reported and the call is left
        unchanged.  The C integer result is coerced to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        # 'start' (index 2) and 'end' (index 3) of the C signature
        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default_value in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, ssize_t_type, default_value)

        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        find_call = self._substitute_method_call(
            node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
            method_name, is_unbound_method, args)
        return find_call.coerce_to_pyobject(self.current_env())
2573

Stefan Behnel's avatar
Stefan Behnel committed
2574 2575 2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588 2589
    # C signature used for the PyUnicode_Count() call;
    # the exception value is declared as -1
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        ],
        exception_value = '-1')

    def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
        """Map unicode.count(...) to the C-API function PyUnicode_Count().

        A wrong number of arguments is reported and the call is left
        unchanged.  The C integer result is coerced to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
            return node

        # 'start' (index 2) and 'end' (index 3) of the C signature
        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default_value in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, ssize_t_type, default_value)

        count_call = self._substitute_method_call(
            node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
            'count', is_unbound_method, args)
        return count_call.coerce_to_pyobject(self.current_env())
Stefan Behnel's avatar
Stefan Behnel committed
2599

Stefan Behnel's avatar
Stefan Behnel committed
2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614
    # C signature used for the PyUnicode_Replace() call
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
        ])

    def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
        """Map unicode.replace(...) to the C-API function
        PyUnicode_Replace().

        A wrong number of arguments is reported and the call is left
        unchanged.  The 'maxcount' argument (index 3) is handed to
        _inject_int_default_argument() with a default of -1.
        """
        if not 3 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
            return node
        maxcount_index = 3
        self._inject_int_default_argument(
            node, args, maxcount_index, PyrexTypes.c_py_ssize_t_type, "-1")

        return self._substitute_method_call(
            node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
            'replace', is_unbound_method, args)

2622 2623
    # C signature used for the generic PyUnicode_AsEncodedString() call
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type,
        [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])

    # C signature of the codec specific PyUnicode_As<Codec>String() functions
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type,
        [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None)])

    # codec names for which codec specific C function names are generated
    _special_encodings = [
        'UTF8', 'UTF16', 'Latin1', 'ASCII',
        'unicode_escape', 'raw_unicode_escape',
    ]

    # (name, encoder function) pairs for the codecs listed above
    _special_codecs = [
        (name, codecs.getencoder(name))
        for name in _special_encodings
    ]
2639 2640

    def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.

        Without an explicit encoding, PyUnicode_AsEncodedString() is
        called with NULL for encoding and error mode.  With an explicit
        encoding, a constant unicode string is encoded at compile time
        if possible.  Otherwise a codec specific
        PyUnicode_As<Codec>String() function is used for strict error
        handling with a known codec, and the generic
        PyUnicode_AsEncodedString() in all other cases.  A wrong number
        of arguments is reported and the call is left unchanged.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
            return node

        string_node = args[0]

        if len(args) == 1:
            null_node = ExprNodes.NullNode(node.pos)
            return self._substitute_method_call(
                node, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        if isinstance(string_node, ExprNodes.UnicodeNode):
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except Exception:
                # well, looks like we can't - leave it to the runtime, but
                # don't swallow KeyboardInterrupt/SystemExit on the way
                pass
            else:
                value = BytesLiteral(value)
                value.encoding = encoding
                return ExprNodes.BytesNode(
                    string_node.pos, value=value, type=Builtin.bytes_type)

        if error_handling == 'strict':
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            if codec_name is not None:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

    # C signature of the codec specific PyUnicode_Decode<Codec>() functions
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])

    # C signature used for the generic PyUnicode_Decode() call
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])
2705 2706

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handled receivers are a slice of a C string and a C string that
        was coerced to a Python object; everything else is left to
        Python.  For a slice, a non-zero start index is added to the
        string pointer and the length becomes 'stop - start'.  Without
        a stop index, the length is determined by calling strlen().
        A codec specific PyUnicode_Decode<Codec>() function is used for
        known codecs, the generic PyUnicode_Decode() otherwise.
        """
        if not 1 <= len(args) <= 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        let_refs = []   # LetRefNodes that must wrap the final call node
        receiver = args[0]
        if isinstance(receiver, ExprNodes.SliceIndexNode):
            string_node = receiver.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start = receiver.start
            stop = receiver.stop
            if not start or start.constant_result == 0:
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(ssize_t_type, self.current_env())
                if stop:
                    # start is used twice below: pointer offset and length
                    start = UtilNodes.LetRefNode(start)
                    let_refs.append(start)
                string_node = ExprNodes.AddNode(
                    pos=start.pos,
                    operand1=string_node,
                    operator='+',
                    operand2=start,
                    is_temp=False,
                    type=string_node.type,
                    )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(ssize_t_type, self.current_env())
        elif isinstance(receiver, ExprNodes.CoerceToPyTypeNode) \
                 and receiver.arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = receiver.arg
        else:
            # let Python do its job
            return node

        if not stop:
            if start or not string_node.is_name:
                # string_node is used twice: by strlen() and by the decoder
                string_node = UtilNodes.LetRefNode(string_node)
                let_refs.append(string_node)
            strlen_call = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                args = [string_node],
                is_temp = False,
                utility_code = Builtin.include_string_h_utility_code,
                )
            stop = strlen_call.coerce_to(ssize_t_type, self.current_env())
        elif start:
            # the decoder expects a length, not an end index
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = ssize_t_type,
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific encoder function
        if encoding is None:
            codec_name = None
        else:
            codec_name = self._find_special_codec_name(encoding)

        if codec_name is None:
            decode_function = "PyUnicode_Decode"
            decode_func_type = self.PyUnicode_Decode_func_type
            decode_args = [string_node, stop, encoding_node, error_handling_node]
        else:
            decode_function = "PyUnicode_Decode%s" % codec_name
            decode_func_type = self.PyUnicode_DecodeXyz_func_type
            decode_args = [string_node, stop, error_handling_node]

        node = ExprNodes.PythonCapiCallNode(
            node.pos, decode_function,
            decode_func_type,
            args = decode_args,
            is_temp = node.is_temp,
            )

        # wrap the call in its temps, last created temp innermost
        for let_ref in reversed(let_refs):
            node = UtilNodes.EvalWithTempExprNode(let_ref, node)
        return node
2795 2796 2797 2798 2799 2800 2801 2802 2803 2804 2805 2806 2807 2808 2809

    def _find_special_codec_name(self, encoding):
        try:
            requested_codec = codecs.getencoder(encoding)
        except:
            return None
        for name, codec in self._special_codecs:
            if codec == requested_codec:
                if '_' in name:
                    name = ''.join([ s.capitalize()
                                     for s in name.split('_')])
                return name
        return None

    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the 'encoding' and 'errors' arguments of a codec call.

        ``args[1]`` (if present) is taken as the encoding and ``args[2]``
        (if exactly three arguments were passed) as the error handling
        mode.  Each of them may be a string literal, a bytes object or a
        C string; anything else makes the method give up.

        Returns a tuple ``(encoding, encoding_node, error_handling,
        error_handling_node)`` where the plain values are None when they
        are not known at compile time, or None if an argument has an
        unsupported type.
        """
        null_node = ExprNodes.NullNode(pos)

        if len(args) >= 2:
            encoding_node = args[1]
            if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
                # look through the coercion at the underlying C value
                encoding_node = encoding_node.arg
            if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                          ExprNodes.BytesNode)):
                # literal => value is known, pass it on as a char* literal
                encoding = encoding_node.value
                encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                     type=PyrexTypes.c_char_ptr_type)
            elif encoding_node.type is Builtin.bytes_type:
                encoding = None
                encoding_node = encoding_node.coerce_to(
                    PyrexTypes.c_char_ptr_type, self.current_env())
            elif encoding_node.type.is_string:
                encoding = None
            else:
                return None
        else:
            encoding = None
            encoding_node = null_node

        if len(args) == 3:
            error_handling_node = args[2]
            if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
                error_handling_node = error_handling_node.arg
            if isinstance(error_handling_node,
                          (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                           ExprNodes.BytesNode)):
                error_handling = error_handling_node.value
                if error_handling == 'strict':
                    # 'strict' is represented by a NULL pointer, just
                    # like a missing argument (see the else branch below)
                    error_handling_node = null_node
                else:
                    error_handling_node = ExprNodes.BytesNode(
                        error_handling_node.pos, value=error_handling,
                        type=PyrexTypes.c_char_ptr_type)
            elif error_handling_node.type is Builtin.bytes_type:
                error_handling = None
                error_handling_node = error_handling_node.coerce_to(
                    PyrexTypes.c_char_ptr_type, self.current_env())
            elif error_handling_node.type.is_string:
                error_handling = None
            else:
                return None
        else:
            error_handling = 'strict'
            error_handling_node = null_node

        return (encoding, encoding_node, error_handling, error_handling_node)
2860

2861
    def _handle_simple_method_str_endswith(self, node, args, is_unbound_method):
        """Handle str.endswith() by delegating to _inject_tailmatch()
        with the 'str' tailmatch utility code and direction +1.
        """
        return self._inject_tailmatch(
            node, args, is_unbound_method, 'str', 'endswith',
            str_tailmatch_utility_code, +1)
2865 2866

    def _handle_simple_method_str_startswith(self, node, args, is_unbound_method):
        """Handle str.startswith() by delegating to _inject_tailmatch()
        with the 'str' tailmatch utility code and direction -1.
        """
        return self._inject_tailmatch(
            node, args, is_unbound_method, 'str', 'startswith',
            str_tailmatch_utility_code, -1)
2870 2871 2872

    def _handle_simple_method_bytes_endswith(self, node, args, is_unbound_method):
        """Handle bytes.endswith() by delegating to _inject_tailmatch()
        with the 'bytes' tailmatch utility code and direction +1.
        """
        return self._inject_tailmatch(
            node, args, is_unbound_method, 'bytes', 'endswith',
            bytes_tailmatch_utility_code, +1)
2875 2876 2877 2878

    def _handle_simple_method_bytes_startswith(self, node, args, is_unbound_method):
        """Handle bytes.startswith() by delegating to _inject_tailmatch()
        with the 'bytes' tailmatch utility code and direction -1.
        """
        return self._inject_tailmatch(
            node, args, is_unbound_method, 'bytes', 'startswith',
            bytes_tailmatch_utility_code, -1)

2881 2882
    ### helpers

2883
    def _substitute_method_call(self, node, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None, is_temp=None,
                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
        """Build a PythonCapiCallNode that replaces the method call 'node'.

        node              -- the call node being replaced; provides the
                             position, 'is_temp' default and 'result_is_used'
        name, func_type   -- C function name and signature of the new call
        attr_name         -- method name, used in the None check messages
        is_unbound_method -- selects which None check message is used
        args              -- call arguments; the first one is treated as
                             the 'self' argument
        utility_code, is_temp, may_return_none
                          -- passed on to the new call node; 'is_temp'
                             defaults to ``node.is_temp``

        Unless the first argument is a literal, it is wrapped in a None
        check so that calling the method on None fails with an error
        message similar to the one Python would give.
        """
        args = list(args)
        if args and not args[0].is_literal:
            self_arg = args[0]
            if is_unbound_method:
                self_arg = self_arg.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args = [attr_name, node.function.obj.name])
            else:
                self_arg = self_arg.as_none_safe_node(
                    "'NoneType' object has no attribute '%s'",
                    error = "PyExc_AttributeError",
                    format_args = [attr_name])
            args[0] = self_arg
        if is_temp is None:
            is_temp = node.is_temp
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args = args,
            is_temp = is_temp,
            utility_code = utility_code,
            may_return_none = may_return_none,
            result_is_used = node.result_is_used,
            )

2911 2912 2913
    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
        """Make sure ``args[arg_index]`` is an integer argument of C type 'type'.

        If the argument is missing (``len(args) == arg_index``), an
        IntNode holding 'default_value' is appended to 'args'.
        Otherwise, the existing argument is coerced to 'type'.
        The list 'args' is modified in place.
        """
        assert len(args) >= arg_index
        if len(args) == arg_index:
            args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
                                          type=type, constant_result=default_value))
        else:
            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2918 2919 2920 2921

    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
        """Make sure ``args[arg_index]`` is a boolean argument.

        If the argument is missing (``len(args) == arg_index``), a
        BoolNode holding ``bool(default_value)`` is appended to 'args'.
        Otherwise, the existing argument is coerced to a boolean.
        The list 'args' is modified in place.
        """
        assert len(args) >= arg_index
        if len(args) == arg_index:
            default_value = bool(default_value)
            args.append(ExprNodes.BoolNode(node.pos, value=default_value,
                                           constant_result=default_value))
        else:
            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2927

2928

2929 2930 2931 2932
# C helper __Pyx_Py_UNICODE_ISTITLE(): true for characters that are
# either title case or upper case.  The parameter type is Py_UNICODE
# before CPython 3.2a2 and Py_UCS4 from then on.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER()
proto = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar); /* proto */
#endif
''',
impl = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar) {
#endif
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')

2949 2950 2951 2952 2953 2954 2955 2956 2957 2958 2959 2960 2961 2962 2963 2964 2965 2966 2967 2968 2969 2970 2971 2972 2973 2974 2975
# C helper __Pyx_PyUnicode_Tailmatch(): wraps PyUnicode_Tailmatch() and
# additionally accepts a tuple of substrings, in which case it returns
# the first non-zero result (match or error) and 0 if nothing matched.
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.  The tuple case is therefore
    # marked as unlikely in the C code below.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2976

2977 2978 2979 2980 2981 2982 2983 2984 2985 2986 2987 2988 2989
# C helpers for bytes.startswith() / bytes.endswith():
#
# __Pyx_PyBytes_SingleTailmatch() tests a single prefix/suffix 'arg'
# against the bytes object 'self' within [start:end]; direction > 0
# means "endswith".  'arg' may be a bytes object, a unicode object
# (Python 2 only, delegated to PyUnicode_Tailmatch()) or any object
# that supports the buffer interface.
#
# __Pyx_PyBytes_Tailmatch() additionally accepts a tuple of candidates
# and returns the first non-zero result (match or error).
#
# Note that the buffer must be requested from 'arg', not from 'self':
# it provides the substring to compare against.
bytes_tailmatch_utility_code = UtilityCode(
proto="""
static int __Pyx_PyBytes_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
                                   Py_ssize_t end, int direction);
""",
impl = """
static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
                                         Py_ssize_t end, int direction)
{
    const char* self_ptr = PyBytes_AS_STRING(self);
    Py_ssize_t self_len = PyBytes_GET_SIZE(self);
    const char* sub_ptr;
    Py_ssize_t sub_len;
    int retval;

#if PY_VERSION_HEX >= 0x02060000
    Py_buffer view;
    view.obj = NULL;
#endif

    if ( PyBytes_Check(arg) ) {
        sub_ptr = PyBytes_AS_STRING(arg);
        sub_len = PyBytes_GET_SIZE(arg);
    }
#if PY_MAJOR_VERSION < 3
    /* Python 2.x allows mixing unicode and str */
    else if ( PyUnicode_Check(arg) ) {
        return PyUnicode_Tailmatch(self, arg, start, end, direction);
    }
#endif
    else {
#if PY_VERSION_HEX < 0x02060000
        if (unlikely(PyObject_AsCharBuffer(arg, &sub_ptr, &sub_len)))
            return -1;
#else
        if (unlikely(PyObject_GetBuffer(arg, &view, PyBUF_SIMPLE) == -1))
            return -1;
        sub_ptr = (const char*) view.buf;
        sub_len = view.len;
#endif
    }

    if (end > self_len)
        end = self_len;
    else if (end < 0)
        end += self_len;
    if (end < 0)
        end = 0;
    if (start < 0)
        start += self_len;
    if (start < 0)
        start = 0;

    if (direction > 0) {
        /* endswith */
        if (end-sub_len > start)
            start = end - sub_len;
    }

    if (start + sub_len <= end)
        retval = !memcmp(self_ptr+start, sub_ptr, sub_len);
    else
        retval = 0;

#if PY_VERSION_HEX >= 0x02060000
    if (view.obj)
        PyBuffer_Release(&view);
#endif

    return retval;
}

static int __Pyx_PyBytes_Tailmatch(PyObject* self, PyObject* substr, Py_ssize_t start,
                                   Py_ssize_t end, int direction)
{
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = __Pyx_PyBytes_SingleTailmatch(self, PyTuple_GET_ITEM(substr, i),
                                                   start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }

    return __Pyx_PyBytes_SingleTailmatch(self, substr, start, end, direction);
}

""")

3070 3071
# C helper __Pyx_PyStr_Tailmatch(): dispatches to the bytes
# implementation on Python 2 and to the unicode implementation on
# Python 3, both of which are pulled in through 'requires'.
str_tailmatch_utility_code = UtilityCode(
proto = '''
static CYTHON_INLINE int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
                                               Py_ssize_t end, int direction);
''',
# We do not use a C compiler macro here to avoid "unused function"
# warnings for the *_Tailmatch() function that is not being used in
# the specific CPython version.  The C compiler will generate the same
# code anyway, and will usually just remove the unused function.
impl = '''
static CYTHON_INLINE int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
                                               Py_ssize_t end, int direction)
{
    if (PY_MAJOR_VERSION < 3)
        return __Pyx_PyBytes_Tailmatch(self, arg, start, end, direction);
    else
        return __Pyx_PyUnicode_Tailmatch(self, arg, start, end, direction);
}
''',
requires=[unicode_tailmatch_utility_code, bytes_tailmatch_utility_code]
)

3092 3093
# C helper __Pyx_PyDict_GetItemDefault(): looks up 'key' in dict 'd'
# and returns a new reference to the value, or to 'default_value' if
# the key is missing.  On Python 2, the fast PyDict_GetItem() path is
# only taken for exact str/unicode/int keys; other keys go through a
# call to the dict's "get" method.  The function body lives in 'proto',
# 'impl' is empty.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)

3127 3128 3129 3130 3131 3132 3133 3134 3135 3136 3137 3138 3139 3140 3141 3142 3143 3144 3145 3146 3147 3148 3149 3150 3151 3152 3153 3154 3155 3156 3157 3158 3159 3160 3161 3162 3163 3164 3165
# C helper __Pyx_PyDict_SetDefault(): returns a new reference to the
# value stored under 'key' in dict 'd', first storing 'default_value'
# there if the key is missing.  On Python 2, the direct dict access is
# only used for exact str/unicode/int keys; other keys go through a
# call to the dict's "setdefault" method.
dict_setdefault_utility_code = UtilityCode(
proto = """
static PyObject *__Pyx_PyDict_SetDefault(PyObject *, PyObject *, PyObject *); /*proto*/
""",
impl = '''
static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
                return NULL;
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "setdefault");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key, default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''')

3166 3167
# C helper __Pyx_PyObject_Append(): calls PyList_Append() directly for
# exact lists (returning a new reference to None) and falls back to
# calling the object's "append" method otherwise.  The function body
# lives in 'proto', 'impl' is empty.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
3186 3187


Robert Bradshaw's avatar
Robert Bradshaw committed
3188 3189
# C helper __Pyx_PyObject_Pop(): pops the last item of an exact list
# in place as long as the list would not need to be shrunk, uses
# PySet_Pop() for exact sets (Python 2.5+), and otherwise calls the
# object's "pop" method.  The function body lives in 'proto', 'impl'
# is empty.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#if PY_VERSION_HEX >= 0x02050000
    else if (Py_TYPE(L) == (&PySet_Type)) {
        return PySet_Pop(L);
    }
#endif
#endif
    return PyObject_CallMethod(L, (char*)"pop", NULL);
}
""",
impl = ""
)

# C helper __Pyx_PyObject_PopIndex(): removes and returns the item at
# index 'ix'.  For exact lists that would not need to be shrunk and for
# an index within bounds (negative indices count from the end), the
# item is removed in place by shifting the following items down.  All
# other cases call the object's "pop" method with the index.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)


3260 3261 3262 3263 3264 3265 3266 3267 3268 3269
# C helper __Pyx_PyDict_Clear(): calls PyDict_Clear() and returns a new
# reference to None, so that the call yields a Python object result.
# The function body lives in 'proto'; no 'impl' is passed.
py_dict_clear_utility_code = UtilityCode(
proto = '''
static CYTHON_INLINE PyObject* __Pyx_PyDict_Clear(PyObject* d) {
    PyDict_Clear(d);
    Py_INCREF(Py_None);
    return Py_None;
}
''')


3270 3271 3272 3273 3274 3275 3276 3277 3278 3279 3280 3281 3282
# C helpers for converting a Python object to a C double:
#
# The macro __Pyx_PyObject_AsDouble() reads exact float objects
# directly and calls __Pyx__PyObject_AsDouble() for everything else.
# That function uses the object's nb_float slot if available, parses
# exact unicode/bytes objects with PyFloat_FromString(), and otherwise
# calls the float type on the object.  It returns (double)-1 on failure.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)


3309 3310 3311 3312 3313 3314 3315 3316
# C helper __Pyx_PyBytes_GetItemInt(): returns the char at 'index' of a
# bytes object, with negative indices counting from the end.  If
# 'check_bounds' is set, an out-of-range index raises IndexError and
# makes the function return -1.
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)


3330 3331
# C helper __Pyx_tp_new(): calls the tp_new slot of the given type
# object with the shared empty tuple (Naming.empty_tuple) as positional
# arguments and no keyword arguments.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)


3340 3341 3342 3343
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it was already set.

        The children are visited first.  The node's own result is only
        calculated if all children have a constant result; otherwise
        it stays marked as 'not a constant'.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
#            if node.constant_result is not ExprNodes.not_a_constant:
#                print node.__class__.__name__, node.constant_result
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from narrowest to widest
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the class in NODE_TYPE_ORDER that comes last among the
        classes of 'nodes', or None if any node has a class that is not
        listed there.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Calculate the constant result, but keep the node as it is."""
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold unary operators that are applied to literals."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand.is_literal:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
                                     type = PyrexTypes.c_int_type,
                                     constant_result = node.constant_result)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    def _handle_UnaryMinusNode(self, node):
        """Replace the negation of a numeric literal by a negative
        literal where the operand type allows it.
        """
        if isinstance(node.operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
                                      constant_result = node.constant_result)
        if isinstance(node.operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value,
                                       constant_result = node.constant_result)
        node_type = node.operand.type
        if node_type.is_int and node_type.signed or \
               isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
            return ExprNodes.IntNode(node.pos, value = '-' + node.operand.value,
                                     type = node_type,
                                     longness = node.operand.longness,
                                     constant_result = node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary plus that does not change the constant value."""
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace a boolean operator on two literals by the operand
        that provides its constant result.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        if node.constant_result == node.operand1.constant_result and node.operand1.is_literal:
            return node.operand1
        elif node.constant_result == node.operand2.constant_result and node.operand2.is_literal:
            return node.operand2
        else:
            # FIXME: we could do more ...
            return node

    def visit_BinopNode(self, node):
        """Replace a binary operator on two numeric literals by a
        single literal node that holds the calculated value.

        Float results are never merged (see the class docstring).
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.IntNode:
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # NOTE(review): 'node' is the binary operator node here, so
            # this BoolNode test looks like it can never succeed - confirm
            if isinstance(node, ExprNodes.BoolNode):
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_CondExprNode(self, node):
        """Replace a conditional expression with a constant test by
        the selected branch.
        """
        self._calculate_const(node)
        if node.test.constant_result is ExprNodes.not_a_constant:
            return node
        if node.test.constant_result:
            return node.true_val
        else:
            return node.false_val

    def visit_IfStatNode(self, node):
        """Drop if-clauses whose condition is known at compile time.

        Clauses with a constant false condition are removed.  The
        first clause with a constant true condition becomes the else
        clause and ends the search.  If no clause is left, the else
        clause replaces the whole statement.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if not if_clauses:
            return node.else_clause
        node.if_clauses = if_clauses
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3545 3546


3547 3548 3549
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
        - eliminate checks for None and/or types that became redundant after tree changes
    """
3557
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.

        After visiting the children, the target of an assignment that is
        flagged as node.first gets lhs_of_first_assignment set to True.
        The node itself is always returned.
        """
        self.visitchildren(node)
        if node.first:
            node.lhs.lhs_of_first_assignment = True
        return node
3566

3567
    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only calls through a NameNode called 'isinstance' whose function
        type is a C function are considered, and only in the two-argument
        form.  When the second argument has the builtin type 'type', the
        call is redirected to the 'PyObject_TypeCheck' entry of the Cython
        scope and the second argument is cast to a PyTypeObject pointer.
        The (possibly modified) node is always returned.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction
                and isinstance(function, ExprNodes.NameNode)):
            return node
        # The rewrite reads and replaces args[1]; any other arity must be
        # left alone instead of raising an IndexError here.
        if function.name != 'isinstance' or len(node.args) != 2:
            return node
        type_arg = node.args[1]
        if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
            cython_scope = self.context.cython_scope
            function.entry = cython_scope.lookup('PyObject_TypeCheck')
            function.type = function.entry.type
            PyTypeObjectPtr = PyrexTypes.CPtrType(
                cython_scope.lookup('PyTypeObject').type)
            node.args[1] = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
        return node
Stefan Behnel's avatar
Stefan Behnel committed
3582 3583 3584 3585 3586 3587 3588 3589 3590 3591 3592

    def visit_PyTypeTestNode(self, node):
        """Tighten type tests whose argument can never be None.

        A type test that still accepts None as an alternative value
        (node.notnone is false) gets notnone set to True when
        node.arg.may_be_none() reports that None cannot occur.
        The node itself is always returned.
        """
        self.visitchildren(node)
        if node.notnone or node.arg.may_be_none():
            # already strict, or None is still a possible value
            return node
        node.notnone = True
        return node
3593 3594 3595 3596 3597 3598 3599 3600 3601

    def visit_NoneCheckNode(self, node):
        """Drop None checks around expressions that cannot be None.

        Returns the wrapped argument itself when its may_be_none()
        reports false, and the unchanged check node otherwise.
        """
        self.visitchildren(node)
        checked = node.arg
        if checked.may_be_none():
            return node
        # the check is redundant => unwrap the argument
        return checked