Optimize.py 195 KB
Newer Older
1 2
from __future__ import absolute_import

3 4 5
import sys
import copy
import codecs
6
import itertools
7

8 9
from . import TypeSlots
from .ExprNodes import not_a_constant
10
import cython
11
# Globals that Cython should treat as plain Python objects when this module
# is itself compiled with Cython.
cython.declare(
    UtilityCode=object, EncodedString=object, bytes_literal=object,
    Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
    UtilNodes=object, _py_int_types=object)

# Types accepted by isinstance() checks on compile-time integer constants:
# Python 2 has both 'int' and 'long', Python 3 only 'int'.
_py_int_types = int if sys.version_info[0] >= 3 else (int, long)  # noqa: F821
19 20 21 22 23 24 25 26 27

from . import Nodes
from . import ExprNodes
from . import PyrexTypes
from . import Visitor
from . import Builtin
from . import UtilNodes
from . import Options

28
from .Code import UtilityCode, TempitaUtilityCode
29
from .StringEncoding import EncodedString, bytes_literal
30 31
from .Errors import error
from .ParseTreeTransforms import SkipDeclarations
32

33
try:
34 35
    from __builtin__ import reduce
except ImportError:
36 37
    from functools import reduce

38 39 40 41 42
try:
    from __builtin__ import basestring
except ImportError:
    basestring = str # Python 3

43

44
def load_c_utility(name):
    """Load the utility code section *name* from "Optimize.c" via UtilityCode.load_cached()."""
    utility_file = "Optimize.c"
    return UtilityCode.load_cached(name, utility_file)
46

47

48 49 50 51 52
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from *node*.

    Returns ``node.arg`` when *node* is an instance of one of
    *coercion_nodes*, otherwise *node* itself (only one level is removed).
    """
    return node.arg if isinstance(node, coercion_nodes) else node

53

54
def unwrap_node(node):
    """Follow a chain of ResultRefNode wrappers down to the wrapped expression."""
    current = node
    while isinstance(current, UtilNodes.ResultRefNode):
        current = current.expression
    return current
58

59

60
def is_common_value(a, b):
    """Return True if *a* and *b* are known to denote the same value.

    ResultRefNode wrappers are looked through first.  Only two cases are
    recognised: two names with the same name, and two attribute accesses of
    the same attribute on common objects (where *a* is not a Python-level
    attribute).  Anything else is conservatively reported as False.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)

    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name

    both_attributes = (isinstance(a, ExprNodes.AttributeNode)
                       and isinstance(b, ExprNodes.AttributeNode))
    if not both_attributes or a.is_py_attr:
        return False
    return is_common_value(a.obj, b.obj) and a.attribute == b.attribute

69

70 71 72 73 74
def filter_none_node(node):
    """Map a node whose compile-time constant value is ``None`` to ``None``.

    Used for optional slice components (start/stop/step), where an explicit
    ``None`` means the same as an omitted value.  Any other node, and
    ``None`` itself, is returned unchanged.
    """
    if node is None or node.constant_result is not None:
        return node
    return None

75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103

class _YieldNodeCollector(Visitor.TreeVisitor):
    """
    Collect the YieldExprNodes below a node (for generator expressions),
    together with the ExprStatNodes that use them as plain statements.

    Nested scopes (generator expressions, lambdas, functions) are not entered.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        # yield expression -> the ExprStatNode that wraps it as a statement
        self.yield_stat_nodes = {}
        # every yield expression found, in the order visited
        self.yield_nodes = []

    # default: just recurse into all children
    visit_Node = Visitor.TreeVisitor.visitchildren

    def visit_YieldExprNode(self, node):
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def visit_ExprStatNode(self, node):
        self.visitchildren(node)
        expr = node.expr
        if expr in self.yield_nodes:
            self.yield_stat_nodes[expr] = node

    def _skip_scope(self, node):
        # everything below these nodes is out of scope
        pass

    visit_GeneratorExpressionNode = _skip_scope
    visit_LambdaNode = _skip_scope
    visit_FuncDefNode = _skip_scope

107 108

def _find_single_yield_expression(node):
    """Return ``(yield_arg, yield_stat_node)`` for the only yield statement
    below *node*, or ``(None, None)`` if there is not exactly one.
    """
    yield_statements = _find_yield_statements(node)
    if len(yield_statements) == 1:
        return yield_statements[0]
    return None, None
113 114


115
def _find_yield_statements(node):
    """Collect ``(yield_arg, ExprStatNode)`` pairs for the yields below *node*.

    Returns an empty list if any yield expression is used as a value
    rather than as a statement of its own.
    """
    collector = _YieldNodeCollector()
    collector.visitchildren(node)

    stat_nodes = collector.yield_stat_nodes
    pairs = []
    for yield_node in collector.yield_nodes:
        stat_node = stat_nodes.get(yield_node)
        if stat_node is None:
            # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
            return []
        pairs.append((yield_node.arg, stat_node))
    return pairs
127 128


129
class IterationTransform(Visitor.EnvTransform):
130 131 132
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over C arrays/pointers, bytes and unicode strings becomes a C loop
    - 'in' tests against C pointers/arrays become an explicit search loop
    """
136 137
    def visit_PrimaryCmpNode(self, node):
        """Expand ``x in ptr`` / ``x not in ptr`` into an explicit search loop.

        Comparisons for which ``node.is_ptr_contains()`` is false only get
        their children visited.  Otherwise the comparison is replaced by::

            for t in operand2:
                if operand1 == t:
                    res = True
                    break
            else:
                res = False

        wrapped so that ``res`` is the expression result (negated for 'not_in').
        """
        if not node.is_ptr_contains():
            self.visitchildren(node)
            return node

        pos = node.pos
        result_ref = UtilNodes.ResultRefNode(node)
        sequence = node.operand2
        if sequence.is_subscript:
            item_type = sequence.base.type.base_type
        else:
            item_type = sequence.type.base_type
        item_handle = UtilNodes.TempHandle(item_type)
        item_ref = item_handle.ref(pos)

        equality_test = ExprNodes.PrimaryCmpNode(
            pos, operator=u'==', operand1=node.operand1, operand2=item_ref)
        found_stats = Nodes.StatListNode(
            pos,
            stats=[
                Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                Nodes.BreakStatNode(pos),
            ])
        match_check = Nodes.IfStatNode(
            pos,
            if_clauses=[Nodes.IfClauseNode(pos, condition=equality_test, body=found_stats)],
            else_clause=None)

        item_iterator = ExprNodes.IteratorNode(sequence.pos, sequence=sequence)
        not_found = Nodes.SingleAssignmentNode(
            pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))
        search_loop = Nodes.ForInStatNode(
            pos,
            target=item_ref,
            iterator=item_iterator,
            body=match_check,
            else_clause=not_found)
        loop_block = UtilNodes.TempsBlockNode(pos, temps=[item_handle], body=search_loop)

        loop_block = loop_block.analyse_expressions(self.current_env())
        loop_block = self.visit(loop_block)

        replacement = UtilNodes.TempResultFromStatNode(result_ref, loop_block)
        if node.operator == 'not_in':
            replacement = ExprNodes.NotNode(pos, operand=replacement)
        return replacement
184

185 186
    def visit_ForInStatNode(self, node):
        """Optimise the loop's children first, then try to specialise the loop itself."""
        self.visitchildren(node)
        iterated = node.iterator.sequence
        return self._optimise_for_loop(node, iterated)
188

189
    def _optimise_for_loop(self, node, iterator, reversed=False):
        """Try to replace the for-in loop *node* over *iterator* by a specialised loop.

        *iterator* is the iterated expression (for ``reversed(x)`` loops, the
        argument ``x`` with *reversed* set to True).  The dispatch looks at the
        static type of *iterator* (or a ``Dict``/``Set`` typing annotation on
        it), then at calls to dict key/value/item methods, ``enumerate()``,
        ``reversed()`` and ``range()``/``xrange()``.  Returns the transformed
        node, or *node* itself when nothing applies.
        """
        # Container type suggested by a typing annotation of the iterated name/attribute.
        annotation_type = None
        is_annotated = ((iterator.is_name or iterator.is_attribute)
                        and iterator.entry and iterator.entry.annotation)
        if is_annotated:
            annotation = iterator.entry.annotation
            if annotation.is_subscript:
                annotation = annotation.base  # container base type
            # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
            if annotation.is_name:
                qualified_name = annotation.entry.qualified_name if annotation.entry else None
                if qualified_name == 'typing.Dict' or annotation.name == 'Dict':
                    annotation_type = Builtin.dict_type
                # checked after the dict test, so a set match wins
                if (qualified_name in ('typing.Set', 'typing.FrozenSet')
                        or annotation.name in ('Set', 'FrozenSet')):
                    annotation_type = Builtin.set_type

        static_types = (iterator.type, annotation_type)
        if Builtin.dict_type in static_types:
            # like iterating over dict.keys()
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_dict_iteration(
                node, dict_obj=iterator, method=None, keys=True, values=False)

        if Builtin.set_type in static_types or Builtin.frozenset_type in static_types:
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_set_iteration(node, iterator)

        iterator_type = iterator.type
        # C array (slice) iteration?
        if iterator_type.is_ptr or iterator_type.is_array:
            return self._transform_carray_iteration(node, iterator, reversed=reversed)
        if iterator_type is Builtin.bytes_type:
            return self._transform_bytes_iteration(node, iterator, reversed=reversed)
        if iterator_type is Builtin.unicode_type:
            return self._transform_unicode_iteration(node, iterator, reversed=reversed)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        # number of explicit call arguments, not counting a bound 'self'
        if iterator.args is not None:
            arg_count = len(iterator.args)
            if arg_count and iterator.self is not None:
                arg_count -= 1
        elif iterator.arg_tuple:
            arg_count = len(iterator.arg_tuple.args)
        else:
            arg_count = 0

        function = iterator.function

        # dict iteration?
        if function.is_attribute and not reversed and not arg_count:
            base_obj = iterator.self or function.obj
            method = function.attribute
            # in Py3, items() is equivalent to Py2's iteritems()
            is_safe_iter = self.global_scope().context.language_level >= 3

            if (not is_safe_iter and method in ('keys', 'values', 'items')
                    and isinstance(base_obj, ExprNodes.CallNode)):
                # try to reduce this to the corresponding .iter*() methods
                inner_function = base_obj.function
                if (inner_function.is_name and inner_function.name == 'dict'
                        and inner_function.entry
                        and inner_function.entry.is_builtin):
                    # e.g. dict(something).items() => safe to use .iter*()
                    is_safe_iter = True

            if method == 'iteritems' or (is_safe_iter and method == 'items'):
                keys = values = True
            elif method == 'iterkeys' or (is_safe_iter and method == 'keys'):
                keys, values = True, False
            elif method == 'itervalues' or (is_safe_iter and method == 'values'):
                keys, values = False, True
            else:
                keys = values = False

            if keys or values:
                return self._transform_dict_iteration(
                    node, base_obj, method, keys, values)

        calls_builtin = (iterator.self is None and function.is_name
                         and function.entry and function.entry.is_builtin)

        # enumerate/reversed ?
        if calls_builtin and function.name in ('enumerate', 'reversed'):
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            if function.name == 'enumerate':
                return self._transform_enumerate_iteration(node, iterator)
            return self._transform_reversed_iteration(node, iterator)

        # range() iteration?
        is_range_call = (Options.convert_range and arg_count >= 1 and calls_builtin
                         and function.name in ('range', 'xrange'))
        if not is_range_call:
            return node

        target_type = node.target.type
        if target_type.is_int or target_type.is_enum:
            return self._transform_range_iteration(node, iterator, reversed=reversed)
        if target_type.is_pyobject:
            # Assume that small integer ranges (C long >= 32bit) are best handled in C as well.
            range_args = iterator.arg_tuple.args if iterator.args is None else iterator.args
            if all(isinstance(arg, ExprNodes.IntNode) and arg.has_constant_result()
                   and -2**30 <= arg.constant_result < 2**30
                   for arg in range_args):
                return self._transform_range_iteration(node, iterator, reversed=reversed)
        return node
302

303 304 305 306 307 308 309 310 311 312
    def _transform_reversed_iteration(self, node, reversed_function):
        """Optimise ``for x in reversed(seq)``.

        For list/tuple arguments the existing IteratorNode is flagged as
        reversed and given a None-checked sequence.  Any other argument is
        passed back to _optimise_for_loop() with ``reversed=True``.  A wrong
        argument count reports a compile error and returns *node* unchanged.
        """
        args = reversed_function.arg_tuple.args
        arg_count = len(args)
        if arg_count != 1:
            if arg_count == 0:
                message = "reversed() requires an iterable argument"
            else:
                message = "reversed() takes exactly 1 argument"
            error(reversed_function.pos, message)
            return node

        sequence = args[0]
        # reversed(list/tuple) ?
        if sequence.type not in (Builtin.tuple_type, Builtin.list_type):
            return self._optimise_for_loop(node, sequence, reversed=True)

        loop_iterator = node.iterator
        loop_iterator.sequence = sequence.as_none_safe_node("'NoneType' object is not iterable")
        loop_iterator.reversed = True
        return node
322

323 324 325 326 327 328 329 330 331 332
    # C signature of PyBytes_AS_STRING(): bytes object -> char* buffer
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type,
        [PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)])

    # C signature of PyBytes_GET_SIZE(): bytes object -> Py_ssize_t length
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)])

333 334
    def _transform_bytes_iteration(self, node, slice_node, reversed=False):
        """Turn ``for c in some_bytes`` into a C loop over the bytes buffer.

        Only applied when the loop target is a C integer or typed as bytes.
        The None-checked bytes object is bound to a LetRefNode and iterated
        as the C slice ``PyBytes_AS_STRING(b)[:PyBytes_GET_SIZE(b)]`` by
        _transform_carray_iteration().
        """
        target_type = node.target.type
        if not (target_type.is_int or target_type is Builtin.bytes_type):
            # bytes iteration returns bytes objects in Py2, but
            # integers in Py3
            return node

        bytes_ref = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))
        pos = slice_node.pos

        buffer_start = ExprNodes.PythonCapiCallNode(
            pos, "PyBytes_AS_STRING", self.PyBytes_AS_STRING_func_type,
            args=[bytes_ref], is_temp=0)
        buffer_length = ExprNodes.PythonCapiCallNode(
            pos, "PyBytes_GET_SIZE", self.PyBytes_GET_SIZE_func_type,
            args=[bytes_ref], is_temp=0)
        char_slice = ExprNodes.SliceIndexNode(
            pos,
            base=buffer_start,
            start=None,
            step=None,
            stop=buffer_length,
            type=buffer_start.type,
            is_temp=1)

        c_loop = self._transform_carray_iteration(node, char_slice, reversed=reversed)
        return UtilNodes.LetNode(bytes_ref, c_loop)
370

371 372 373 374 375 376 377
    # __Pyx_PyUnicode_READ(kind, data, index) -> Py_UCS4 code point
    PyUnicode_READ_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type,
        [PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
         PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
         PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)])

    # __Pyx_init_unicode_iteration(s, &length, &data, &kind) -> int, -1 signals an exception
    init_unicode_iteration_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type,
        [PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
         PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
         PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)],
        exception_value='-1')
386 387

    def _transform_unicode_iteration(self, node, slice_node, reversed=False):
        """Turn ``for ch in some_unicode`` into a C loop over its characters.

        Unicode literals that encode to Latin-1 are iterated as a constant
        ``const unsigned char*`` buffer via _transform_carray_iteration().
        Otherwise the None-checked string is bound to a LetRefNode,
        ``__Pyx_init_unicode_iteration()`` fills in its length, data pointer
        and kind, and a counting ForFromStatNode reads each character with
        ``__Pyx_PyUnicode_READ()``.  With ``reversed=True`` the counter runs
        from the length down to 0.
        """
        if slice_node.is_literal:
            # try to reduce to byte iteration for plain Latin-1 strings
            try:
                bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
            except UnicodeEncodeError:
                pass
            else:
                bytes_slice = ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base=ExprNodes.BytesNode(
                        slice_node.pos, value=bytes_value,
                        constant_result=bytes_value,
                        type=PyrexTypes.c_const_char_ptr_type).coerce_to(
                            PyrexTypes.c_const_uchar_ptr_type, self.current_env()),
                    start=None,
                    stop=ExprNodes.IntNode(
                        slice_node.pos, value=str(len(bytes_value)),
                        constant_result=len(bytes_value),
                        type=PyrexTypes.c_py_ssize_t_type),
                    type=Builtin.unicode_type,  # hint for Python conversion
                )
                return self._transform_carray_iteration(node, bytes_slice, reversed=reversed)

        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        start_node = ExprNodes.IntNode(
            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
        length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        end_node = length_temp.ref(node.pos)
        if reversed:
            # for length > i >= 0
            relation1, relation2 = '>', '>='
            start_node, end_node = end_node, start_node
        else:
            # for 0 <= i < length
            relation1, relation2 = '<=', '<'

        kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
        counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

        # target = __Pyx_PyUnicode_READ(kind, data, i), coerced to the target type
        target_value = ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_PyUnicode_READ",
            self.PyUnicode_READ_func_type,
            args=[kind_temp.ref(slice_node.pos),
                  data_temp.ref(slice_node.pos),
                  counter_temp.ref(node.target.pos)],
            is_temp=False,
            )
        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())
        target_assign = Nodes.SingleAssignmentNode(
            pos=node.target.pos,
            lhs=node.target,
            rhs=target_value)
        body = Nodes.StatListNode(
            node.pos,
            stats=[target_assign, node.body])

        loop_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_node, relation1=relation1,
            target=counter_temp.ref(node.target.pos),
            relation2=relation2, bound2=end_node,
            step=None, body=body,
            else_clause=node.else_clause,
            from_range=True)

        # __Pyx_init_unicode_iteration(s, &length, &data, &kind) before the loop
        setup_node = Nodes.ExprStatNode(
            node.pos,
            expr=ExprNodes.PythonCapiCallNode(
                slice_node.pos, "__Pyx_init_unicode_iteration",
                self.init_unicode_iteration_func_type,
                args=[unpack_temp_node,
                      ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
                                              type=PyrexTypes.c_py_ssize_t_ptr_type),
                      ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
                                              type=PyrexTypes.c_void_ptr_ptr_type),
                      ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
                                              type=PyrexTypes.c_int_ptr_type),
                      ],
                is_temp=True,
                result_is_used=False,
                # use the module's shared Optimize.c loader like the rest of this file
                utility_code=load_c_utility("unicode_iter"),
                ))
        return UtilNodes.LetNode(
            unpack_temp_node,
            UtilNodes.TempsBlockNode(
                node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
                body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
478

479
    def _transform_carray_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a C array / C pointer slice into a pointer-based ForFromStatNode.

        *slice_node* may be a SliceIndexNode (``ptr[start:stop]``), a subscript
        whose index is a SliceNode (``ptr[start:stop:step]``), or a C array of
        known size.  The loop counter is a pointer temp running from
        ``base + start`` to ``base + stop``; the loop target is assigned from
        it on every iteration.  When the bounds or step cannot be determined,
        *node* is returned unchanged, and an error is reported unless the
        base is a Python object.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            # ptr[start:stop] - a stop index is required to bound the loop
            slice_base = slice_node.base
            start = filter_none_node(slice_node.start)
            stop = filter_none_node(slice_node.stop)
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node

        elif slice_node.is_subscript:
            # ptr[start:stop:step]
            assert isinstance(slice_node.index, ExprNodes.SliceNode)
            slice_base = slice_node.base
            index = slice_node.index
            start = filter_none_node(index.start)
            stop = filter_none_node(index.stop)
            step = filter_none_node(index.step)
            if step:
                # the step must be a non-zero compile-time integer, and the
                # bound in the direction of iteration must be given
                if not isinstance(step.constant_result, _py_int_types) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    step_value = step.constant_result
                    if reversed:
                        step_value = -step_value
                    neg_step = step_value < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step_value)),
                                             constant_result=abs(step_value))

        elif slice_node.type.is_array:
            # whole C array: iterate over [0:size], which must be known
            if slice_node.type.size is None:
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop is None:
            # Only reachable for the subscript case without a stop index.
            if neg_step:
                # counting down without a stop: end just below index 0
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        if reversed:
            if not start:
                start = ExprNodes.IntNode(slice_node.pos, value="0",  constant_result=0,
                                          type=PyrexTypes.c_py_ssize_t_type)
            # if step was provided, it was already negated above
            # NOTE(review): with an explicit step, the swapped bounds combined with
            # _find_for_from_node_relations(neg_step, True) look like they describe an
            # empty range (e.g. step>0 gives "base+stop < p <= base+start") - confirm
            # whether reversed() over a stepped C slice is handled correctly.
            start, stop = stop, start

        # the loop counter is an element pointer; arrays decay to their element pointer type
        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        # presumably made "simple" so it can be re-used through CloneNode for the stop bound
        carray_ptr = slice_base.coerce_to_simple(self.current_env())

        # start bound: base + start (or just base for a start of 0 / no start)
        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        # stop bound: base + stop (or just base for a stop of 0)
        if stop and stop.constant_result != 0:
            stop_ptr_node = ExprNodes.AddNode(
                stop.pos,
                operand1=ExprNodes.CloneNode(carray_ptr),
                operator='+',
                operand2=stop,
                type=ptr_type
                ).coerce_to_simple(self.current_env())
        else:
            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        # Build the value assigned to the loop target from the counter pointer.
        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes/unicode
            if slice_node.type is Builtin.unicode_type:
                # set by _transform_unicode_iteration for Latin-1 literals:
                # produce characters via Py_UCS4 instead of bytes
                target_value = ExprNodes.CastNode(
                    ExprNodes.DereferenceNode(
                        node.target.pos, operand=counter_temp,
                        type=ptr_type.base_type),
                    PyrexTypes.c_py_ucs4_type).coerce_to(
                        node.target.type, self.current_env())
            else:
                # char* -> bytes coercion requires slicing, not indexing
                target_value = ExprNodes.SliceIndexNode(
                    node.target.pos,
                    start=ExprNodes.IntNode(node.target.pos, value='0',
                                            constant_result=0,
                                            type=PyrexTypes.c_int_type),
                    stop=ExprNodes.IntNode(node.target.pos, value='1',
                                           constant_result=1,
                                           type=PyrexTypes.c_int_type),
                    base=counter_temp,
                    type=Builtin.bytes_type,
                    is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # TODO: can this safely be replaced with DereferenceNode() as above?
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())

        # loop body: assign the target first, then run the original body
        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

        # for start_ptr <relation1> counter <relation2> stop_ptr [by step]
        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=relation1,
            target=counter_temp,
            relation2=relation2, bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

643 644 645 646 647 648
    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Optimise ``for i, x in enumerate(seq[, start])``.

        The counter lives in a LetRefNode temp.  At the top of each iteration
        it is assigned to the first target and then incremented.  The loop
        itself is rewritten into ``for x in seq`` and passed back to
        _optimise_for_loop().  *node* is returned unchanged when the target is
        not a two-element sequence, or the counter target is neither a Python
        object nor a C integer.  A wrong argument count reports a compile
        error and also returns *node* unchanged.
        """
        args = enumerate_function.arg_tuple.args
        if not args:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        if len(args) > 2:
            error(enumerate_function.pos,
                  "enumerate() takes at most 2 arguments")
            return node

        # leave anything but a plain two-name target untouched for now
        if not node.target.is_sequence_constructor or len(node.target.args) != 2:
            return node

        counter_target, item_target = node.target.args
        counter_type = counter_target.type
        if not (counter_type.is_pyobject or counter_type.is_int):
            # nothing we can do here, I guess
            return node

        if len(args) == 2:
            initial_value = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
        else:
            initial_value = ExprNodes.IntNode(enumerate_function.pos,
                                              value='0',
                                              type=counter_type,
                                              constant_result=0)
        counter_temp = UtilNodes.LetRefNode(initial_value)

        # counter + 1 (an in-place operation is not worth using for Py ints)
        increment = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1=counter_temp,
            operand2=ExprNodes.IntNode(node.pos, value='1',
                                       type=counter_type,
                                       constant_result=1),
            operator='+',
            type=counter_type,
            is_temp=counter_type.is_pyobject)

        counter_stats = [
            Nodes.SingleAssignmentNode(
                pos=counter_target.pos, lhs=counter_target, rhs=counter_temp),
            Nodes.SingleAssignmentNode(
                pos=counter_target.pos, lhs=counter_temp, rhs=increment),
        ]

        old_body = node.body
        if isinstance(old_body, Nodes.StatListNode):
            old_body.stats = counter_stats + old_body.stats
        else:
            counter_stats.append(old_body)
            node.body = Nodes.StatListNode(old_body.pos, stats=counter_stats)

        node.target = item_target
        node.item = node.item.coerce_to(item_target.type, self.current_env())
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        inner_loop = self._optimise_for_loop(node, node.iterator.sequence)
        return UtilNodes.LetNode(counter_temp, inner_loop)
715

716 717 718 719 720 721 722 723 724 725 726
    def _find_for_from_node_relations(self, neg_step_value, reversed):
        if reversed:
            if neg_step_value:
                return '<', '<='
            else:
                return '>', '>='
        else:
            if neg_step_value:
                return '>=', '>'
            else:
                return '<=', '<'
727

728
    def _transform_range_iteration(self, node, range_function, reversed=False):
        """
        Replace a loop over ``range(...)`` by a C-level ForFromStatNode.

        With *reversed* true, the loop is over ``reversed(range(...))`` and
        runs from the last value of the range back to its start.  The
        original *node* is returned unchanged when the step is not a
        compile-time integer constant or is zero.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1', constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, _py_int_types):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                     constant_result=step_value)

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_env())
        else:
            bound1 = args[0].coerce_to_integer(self.current_env())
            bound2 = args[1].coerce_to_integer(self.current_env())

        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

        bound2_ref_node = None
        if reversed:
            bound1, bound2 = bound2, bound1
            abs_step = abs(step_value)
            if abs_step != 1:
                # The original stop is not necessarily hit by the step, so
                # bound1 is moved to one step past the last element of the
                # forward range (relation1 is exclusive).
                if (isinstance(bound1.constant_result, _py_int_types) and
                        isinstance(bound2.constant_result, _py_int_types)):
                    # calculate final bounds now
                    if step_value < 0:
                        begin_value = bound2.constant_result
                        end_value = bound1.constant_result
                        bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1
                    else:
                        begin_value = bound1.constant_result
                        end_value = bound2.constant_result
                        bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1

                    bound1 = ExprNodes.IntNode(
                        bound1.pos, value=str(bound1_value), constant_result=bound1_value,
                        type=PyrexTypes.spanning_type(bound1.type, bound2.type))
                else:
                    # evaluate the same expression as above at runtime
                    bound2_ref_node = UtilNodes.LetRefNode(bound2)
                    bound1 = self._build_range_step_calculation(
                        bound1, bound2_ref_node, step, step_value)

        if step_value < 0:
            step_value = -step_value
        step.value = str(step_value)
        step.constant_result = step_value
        step = step.coerce_to_integer(self.current_env())

        if bound2_ref_node is not None:
            # bound1 was built from a reference to bound2, so that reference
            # must be bound by a LetNode even when bound2 is a literal.
            bound2_is_temp = True
            bound2 = bound2_ref_node
        elif not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)
        for_node.set_up_loop(self.current_env())

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874
    def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value):
        """
        Build the runtime expression for the first bound of a reversed range
        whose step is not +/-1:

            bound2 (+|-) abs_step * ((begin - end - 1) // abs_step) (+|-) 1

        where ``begin``/``end`` and the sign depend on the step direction.
        *bound2_ref_node* is referenced twice, hence it must be a ref node.
        """
        pos = bound1.pos
        abs_step = abs(step_value)
        spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type)
        # Avoid loss of integer precision warnings.
        if step.type.is_int and abs_step < 0x7FFF:
            step_operand_type = PyrexTypes.c_int_type
        else:
            step_operand_type = step.type
        spanning_step_type = PyrexTypes.spanning_type(spanning_type, step_operand_type)

        if step_value >= 0:
            begin_value, end_value, final_op = bound1, bound2_ref_node, '+'
        else:
            begin_value, end_value, final_op = bound2_ref_node, bound1, '-'

        # begin - end
        distance = ExprNodes.SubNode(
            pos,
            operand1=begin_value,
            operator='-',
            operand2=end_value,
            type=spanning_type)
        # (begin - end) - 1
        distance_minus_one = ExprNodes.SubNode(
            pos,
            operand1=distance,
            operator='-',
            operand2=ExprNodes.IntNode(pos, value='1', constant_result=1),
            type=spanning_step_type)
        # ((begin - end) - 1) // abs_step
        step_count = ExprNodes.DivNode(
            pos,
            operand1=distance_minus_one,
            operator='//',
            operand2=ExprNodes.IntNode(
                pos, value=str(abs_step), constant_result=abs_step,
                type=spanning_step_type),
            type=spanning_step_type)
        # abs_step * step_count
        offset = ExprNodes.MulNode(
            pos,
            operand1=ExprNodes.IntNode(
                pos, value=str(abs_step), constant_result=abs_step,
                type=spanning_step_type),
            operator='*',
            operand2=step_count,
            type=spanning_step_type)
        # bound2 +/- offset
        last_value = ExprNodes.binop_node(
            pos,
            operand1=bound2_ref_node,
            operator=final_op,
            operand2=offset,
            type=spanning_step_type)
        # (bound2 +/- offset) +/- 1
        return ExprNodes.binop_node(
            pos,
            operand1=last_value,
            operator=final_op,
            operand2=ExprNodes.IntNode(pos, value='1', constant_result=1),
            type=spanning_type)

875
    def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
        """
        Rewrite a for-loop over *dict_obj* into a while-loop driven by the
        ``__Pyx_dict_iterator()`` C helper and a DictIterationNextNode.

        *method* is the name of the method being iterated over (passed to the
        helper as a string), or a false value to pass NULL instead.  *keys*
        and *values* select what is assigned to the loop target; if both are
        requested, the target is unpacked into (key, value) or receives a
        tuple.  Returns *node* unchanged when the target is a sequence whose
        length is not 2 while both keys and values are requested.
        """
        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)

        key_target = value_target = tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target
        elif keys:
            key_target = node.target
        else:
            value_target = node.target

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos=node.body.pos,
                                      stats=[node.body])

        # keep original length to guard against dict modification
        dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(dict_len_temp)
        dict_len_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=dict_len_temp.ref(dict_obj.pos),
            type=PyrexTypes.c_ptr_type(dict_len_temp.type))
        temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        temps.append(temp)
        is_dict_temp = temp.ref(node.pos)
        is_dict_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=is_dict_temp,
            type=PyrexTypes.c_ptr_type(temp.type))

        # fetch the next item at the start of every loop iteration
        iter_next_node = Nodes.DictIterationNextNode(
            dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
            key_target, value_target, tuple_target,
            is_dict_temp)
        iter_next_node = iter_next_node.analyse_expressions(self.current_env())
        body.stats[0:0] = [iter_next_node]

        if method:
            method_node = ExprNodes.StringNode(
                dict_obj.pos, is_identifier=True, value=method)
            dict_obj = dict_obj.as_none_safe_node(
                "'NoneType' object has no attribute '%{0}s'".format('.30' if len(method) <= 30 else ''),
                error="PyExc_AttributeError",
                format_args=[method])
        else:
            method_node = ExprNodes.NullNode(dict_obj.pos)
            dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")

        # 1 if the object is statically known to be exactly a dict, else 0
        is_exact_dict = 1 if dict_obj.type is Builtin.dict_type else 0

        result_code = [
            Nodes.SingleAssignmentNode(
                node.pos,
                lhs=pos_temp,
                rhs=ExprNodes.IntNode(node.pos, value='0',
                                      constant_result=0)),
            Nodes.SingleAssignmentNode(
                dict_obj.pos,
                lhs=dict_temp,
                rhs=ExprNodes.PythonCapiCallNode(
                    dict_obj.pos,
                    "__Pyx_dict_iterator",
                    self.PyDict_Iterator_func_type,
                    utility_code=UtilityCode.load_cached("dict_iter", "Optimize.c"),
                    args=[dict_obj,
                          ExprNodes.IntNode(node.pos, value=str(is_exact_dict),
                                            constant_result=is_exact_dict),
                          method_node, dict_len_temp_addr, is_dict_temp_addr,
                          ],
                    is_temp=True,
                )),
            Nodes.WhileStatNode(
                node.pos,
                condition=None,
                body=body,
                else_clause=node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats=result_code
                ))

974 975 976 977 978 979 980 981 982
    # C signature of the __Pyx_dict_iterator() utility ("dict_iter" in
    # Optimize.c) that starts a dict iteration.
    PyDict_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,  # return type
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
        ])

983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061
    # C signature of the __Pyx_set_iterator() utility ("set_iter" in
    # Optimize.c) that starts a set iteration.
    PySet_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,  # return type
        [
            PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_set", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None),
        ])

    def _transform_set_iteration(self, node, set_obj):
        """
        Rewrite a for-loop over *set_obj* into a while-loop driven by the
        ``__Pyx_set_iterator()`` C helper and a SetIterationNextNode.
        """
        temps = []

        def new_temp(temp_type):
            # allocate a temp handle and register it with the temps block
            handle = UtilNodes.TempHandle(temp_type)
            temps.append(handle)
            return handle

        set_temp = new_temp(PyrexTypes.py_object_type).ref(set_obj.pos)
        pos_temp = new_temp(PyrexTypes.c_py_ssize_t_type).ref(node.pos)

        body = node.body
        if not isinstance(body, Nodes.StatListNode):
            body = Nodes.StatListNode(pos=body.pos, stats=[body])

        # the original length is kept to guard against set modification
        set_len_handle = new_temp(PyrexTypes.c_py_ssize_t_type)
        set_len_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=set_len_handle.ref(set_obj.pos),
            type=PyrexTypes.c_ptr_type(set_len_handle.type))
        is_set_handle = new_temp(PyrexTypes.c_int_type)
        is_set_temp = is_set_handle.ref(node.pos)
        is_set_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=is_set_temp,
            type=PyrexTypes.c_ptr_type(is_set_handle.type))

        # fetch the next item at the start of every loop iteration
        next_item = Nodes.SetIterationNextNode(
            set_temp, set_len_handle.ref(set_obj.pos), pos_temp, node.target, is_set_temp)
        body.stats.insert(0, next_item.analyse_expressions(self.current_env()))

        # 1 if the object is statically known to be exactly a set, else 0
        is_exact_set = 1 if set_obj.type is Builtin.set_type else 0
        start_iteration = ExprNodes.PythonCapiCallNode(
            set_obj.pos,
            "__Pyx_set_iterator",
            self.PySet_Iterator_func_type,
            utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
            args=[set_obj,
                  ExprNodes.IntNode(node.pos, value=str(is_exact_set),
                                    constant_result=is_exact_set),
                  set_len_temp_addr, is_set_temp_addr],
            is_temp=True)

        stats = [
            Nodes.SingleAssignmentNode(
                node.pos,
                lhs=pos_temp,
                rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0)),
            Nodes.SingleAssignmentNode(
                set_obj.pos, lhs=set_temp, rhs=start_iteration),
            Nodes.WhileStatNode(
                node.pos, condition=None, body=body,
                else_clause=node.else_clause),
        ]
        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(node.pos, stats=stats))

1062

1063
class SwitchTransform(Visitor.EnvTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """
        Try to interpret *cond* as a set-membership test on a single value.

        Returns ``(not_in, var, values)``.  ``var`` is the tested expression
        and ``values`` is the list of constants it is compared against.
        ``not_in`` is true for the negated forms ('!=', 'not in', or
        'and'-joined negated tests), which are only accepted when
        *allow_not_in* is true.  Returns ``NO_MATCH`` for any other shape.
        """
        # strip wrapper nodes around the actual comparison
        while True:
            if isinstance(cond, (ExprNodes.CoerceToTempNode,
                                 ExprNodes.CoerceToBooleanNode)):
                cond = cond.arg
            elif isinstance(cond, ExprNodes.BoolBinopResultNode):
                cond = cond.arg.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                return self.NO_MATCH
            elif cond.is_c_string_contains() and \
                    isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                        cond.operand2.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
            elif not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                            and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                            and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    # 'or' may only join positive tests (union of values) and
                    # 'and' only negated tests (complement of the union);
                    # e.g. "x == 1 and x == 2" must NOT become "x in (1, 2)".
                    if not_in_1 == allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH

    def extract_in_string_conditions(self, string_literal):
        """
        Return sorted, de-duplicated constant nodes for the characters of a
        unicode or bytes literal used in an "in" test.
        """
        if isinstance(string_literal, ExprNodes.UnicodeNode):
            charvals = sorted(map(ord, set(string_literal.value)))
            return [ExprNodes.IntNode(string_literal.pos, value=str(charval),
                                      constant_result=charval)
                    for charval in charvals]
        else:
            # this is a bit tricky as Py3's bytes type returns
            # integers on iteration, whereas Py2 returns 1-char byte
            # strings
            characters = string_literal.value
            characters = sorted(set([characters[i:i+1] for i in range(len(characters))]))
            return [ExprNodes.CharNode(string_literal.pos, value=charval,
                                       constant_result=charval)
                    for charval in characters]

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """
        Like extract_conditions(), but additionally require the tested value
        to match *common_var* (if given) and all involved types to be C ints
        or enums.
        """
        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        elif common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        elif not (var.type.is_int or var.type.is_enum) or any(
                not (cond.type.is_int or cond.type.is_enum) for cond in conditions):
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Return True if any two case values are (or may be) equal."""
        # duplicated values don't work in a switch statement
        seen = set()
        for value in condition_values:
            if value.has_constant_result():
                if value.constant_result in seen:
                    return True
                seen.add(value.constant_result)
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                try:
                    if value.entry.cname in seen:
                        return True
                except AttributeError:
                    return True  # play safe
                seen.add(value.entry.cname)
        return False

    def visit_IfStatNode(self, node):
        """
        Replace an if/elif chain whose conditions all test the same C value
        against distinct constants by a SwitchStatNode.
        """
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos,
                                              conditions=conditions,
                                              body=if_clause.body))

        condition_values = [
            cond for case in cases for cond in case.conditions]
        if len(condition_values) < 2:
            self.visitchildren(node)
            return node
        if self.has_duplicate_values(condition_values):
            self.visitchildren(node)
            return node

        # The clause bodies and the else clause move into the switch without
        # having been visited yet => recurse into them so that nested if
        # statements get optimised as well.
        self.visitchildren(node, ['else_clause'])
        for case in cases:
            self.visitchildren(case, ['body'])

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos=node.pos,
                                           test=common_var,
                                           cases=cases,
                                           else_clause=node.else_clause)
        return switch_node

    def visit_CondExprNode(self, node):
        """Turn "a if x in (...) else b" into a switch on x when possible."""
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if common_var is None \
                or len(conditions) < 2 \
                or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        # the result values are moved into the switch => visit them first
        self.visitchildren(node, ['true_val', 'false_val'])
        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            node.true_val, node.false_val)

    def visit_BoolBinopNode(self, node):
        """Turn an or/and chain of comparisons on one value into a switch."""
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
                or len(conditions) < 2 \
                or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            node.wrap_operands(self.current_env())  # in case we changed the operands
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def visit_PrimaryCmpNode(self, node):
        """Turn a string-contains test on a C character into a switch."""
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
                or len(conditions) < 2 \
                or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """
        Build an expression that evaluates a single-case switch on
        *common_var* and yields *true_val* if it matches one of *conditions*
        (or *false_val*, swapped when *not_in* is true).
        """
        result_ref = UtilNodes.ResultRefNode(node)
        true_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=result_ref,
            rhs=true_val.coerce_to(node.type, self.current_env()),
            first=True)
        false_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs=result_ref,
            rhs=false_val.coerce_to(node.type, self.current_env()),
            first=True)

        if not_in:
            true_body, false_body = false_body, true_body

        cases = [Nodes.SwitchCaseNode(pos=node.pos,
                                      conditions=conditions,
                                      body=true_body)]

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos=node.pos,
                                           test=common_var,
                                           cases=cases,
                                           else_clause=false_body)
        replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node)
        return replacement

    def visit_EvalWithTempExprNode(self, node):
        """Drop the temp of a flattened "in" test if it is no longer used."""
        if not self.current_directives.get('optimize.use_switch'):
            self.visitchildren(node)
            return node

        # drop unused expression temp from FlattenInListTransform
        orig_expr = node.subexpression
        temp_ref = node.lazy_temp
        self.visitchildren(node)
        if node.subexpression is not orig_expr:
            # node was restructured => check if temp is still used
            if not Visitor.tree_contains(node.subexpression, temp_ref):
                return node.subexpression
        return node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1310

1311

1312
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """
        Rewrite "x in <tuple/list/set literal>" into "x == v1 or x == v2 ..."
        and "x not in ..." into "x != v1 and x != v2 ...".  The left operand
        and non-simple right-hand items are evaluated once via temps.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # note: lhs may have side effects
            return node

        if any(getattr(arg, 'is_starred', False) for arg in args):
            # starred items ("x in (*a, b)") cannot be turned into
            # individual comparisons
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                is_simple_arg = arg.is_simple()
            except Exception:
                is_simple_arg = False
            if not is_simple_arg:
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                pos=node.pos,
                operand1=lhs,
                operator=eq_or_neq,
                operand2=arg,
                cascade=None)
            conds.append(ExprNodes.TypecastNode(
                pos=node.pos,
                operand=cond,
                type=PyrexTypes.c_bint_type))

        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                pos=node.pos,
                operator=conjunction,
                operand1=left,
                operand2=right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap innermost-last so that the temps are evaluated in source order
        for temp in reversed(temps):
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1382 1383


1384 1385 1386 1387 1388 1389
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently only parallel assignments whose left-hand names are a
    non-redundant permutation of their right-hand names (e.g. 'a,b = b,a')
    are handled.  Any statement the transform does not understand makes it
    return the node unchanged.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Collects the operands of every SingleAssignmentNode in ``node.stats``.
        If they form a permutation of plain names, the participating name
        nodes and coercion temps get ``use_managed_ref = False``.  The
        (possibly modified) ``node`` is always returned.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            if isinstance(stat, Nodes.SingleAssignmentNode):
                if not self._extract_operand(stat.lhs, left_names,
                                             left_indices, temps):
                    return node
                if not self._extract_operand(stat.rhs, right_names,
                                             right_indices, temps):
                    return node
            elif isinstance(stat, Nodes.CascadedAssignmentNode):
                # FIXME
                return node
            else:
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lnames = [path for path, n in left_names]
            rnames = [path for path, n in right_names]
            if set(lnames) != set(rnames):
                return node
            if len(set(lnames)) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lindices = []
            for lhs_node in left_indices:
                index_id = self._extract_index_id(lhs_node)
                if not index_id:
                    return node
                lindices.append(index_id)
            rindices = []
            for rhs_node in right_indices:
                index_id = self._extract_index_id(rhs_node)
                if not index_id:
                    return node
                rindices.append(index_id)

            if set(lindices) != set(rindices):
                return node
            if len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [t.arg for t in temps]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            # names wrapped by a coercion temp are handled via the temp above
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        # Currently a no-op: any index operand makes the branch above return
        # early.  Kept for when subscript swapping gets supported.
        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand and record it.

        Appends ``(dotted_name_path, node)`` to ``names`` for (C-attribute)
        names, the node to ``indices`` for ``list_name[int]`` subscripts, and
        any CoerceToTempNode wrapper to ``temps``.  Returns False if the
        operand is not a Python object or has an unsupported form.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        name_path = []
        obj_node = node
        while obj_node.is_attribute:
            if obj_node.is_py_attr:
                # Python attribute access may run arbitrary code
                return False
            name_path.append(obj_node.member)
            obj_node = obj_node.obj

        if obj_node.is_name:
            name_path.append(obj_node.name)
            names.append(('.'.join(name_path[::-1]), node))
        elif node.is_subscript:
            if node.base.type != Builtin.list_type:
                return False
            if not node.index.type.is_int:
                return False
            if not node.base.is_name:
                return False
            indices.append(node)
        else:
            return False
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base_name, index_name)`` for a subscript whose index is a
        NameNode, or None for any other index expression.
        """
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant (ConstNode) indices could be supported as well.
            return None
        return (index_node.base.name, index.name)

1499

1500 1501 1502 1503 1504 1505 1506
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Dispatch a call of an unshadowed builtin name to its
        ``_handle_simple_function_<name>`` handler, if one exists."""
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        return self._dispatch_to_handler(node, function, node.args)

    def visit_GeneralCallNode(self, node):
        """Dispatch a general call of an unshadowed builtin name to its
        ``_handle_general_function_<name>`` handler, provided the positional
        arguments are a plain TupleNode."""
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)

    def _function_is_builtin_name(self, function):
        """Return True if ``function`` is a name whose lookup in the current
        scope yields the same entry as the builtin scope (i.e. not shadowed)."""
        if not function.is_name:
            return False
        env = self.current_env()
        entry = env.lookup(function.name)
        if entry is not env.builtin_scope().lookup_here(function.name):
            return False
        # if entry is None, it's at least an undeclared name, so likely builtin
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call the matching handler method, or return ``node`` unchanged.

        ``kwargs is None`` selects ``_handle_simple_function_<name>(node, args)``;
        otherwise ``_handle_general_function_<name>(node, args, kwargs)``.
        """
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace ``node.function`` by a direct C-API function reference."""
        node.function = ExprNodes.PythonCapiFunctionNode(
            node.function.pos, node.function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a compile error for a builtin called with a bad arg count.

        ``expected`` may be None/0, an int, or a descriptive string.
        """
        if not expected: # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by the constant 0.0 and float(x) by x when x is
        already typed as a C double or Python float.
        """
        if not pos_args:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
            # Do not go on to replace an erroneous call by its first argument.
            return node
        # types are usually not known yet before type analysis
        arg_type = getattr(pos_args[0], 'type', None)
        if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
            return pos_args[0]
        return node

    def _handle_simple_function_slice(self, node, pos_args):
        """Replace slice(...) with 1-3 positional args by a SliceNode,
        filling missing start/step with None."""
        arg_count = len(pos_args)
        start = step = None
        if arg_count == 1:
            stop, = pos_args
        elif arg_count == 2:
            start, stop = pos_args
        elif arg_count == 3:
            start, stop, step = pos_args
        else:
            self._error_wrong_arg_count('slice', node, pos_args)
            return node
        # explicit None checks instead of relying on the truthiness of nodes
        if start is None:
            start = ExprNodes.NoneNode(node.pos)
        if step is None:
            step = ExprNodes.NoneNode(node.pos)
        return ExprNodes.SliceNode(
            node.pos,
            start=start,
            stop=stop,
            step=step)

    def _handle_simple_function_ord(self, node, pos_args):
        """Unpack ord('X').
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
            if len(arg.value) == 1:
                return ExprNodes.IntNode(
                    arg.pos, type=PyrexTypes.c_long_type,
                    value=str(ord(arg.value)),
                    constant_result=ord(arg.value)
                )
        elif isinstance(arg, ExprNodes.StringNode):
            if arg.unicode_value and len(arg.unicode_value) == 1 \
                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
                return ExprNodes.IntNode(
                    arg.pos, type=PyrexTypes.c_int_type,
                    value=str(ord(arg.unicode_value)),
                    constant_result=ord(arg.unicode_value)
                )
        return node

    # sequence processing

    def _handle_simple_function_all(self, node, pos_args):
        """Transform

        _result = all(p(x) for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not p(x):
                    return False
        else:
            return True
        """
        return self._transform_any_all(node, pos_args, False)

    def _handle_simple_function_any(self, node, pos_args):
        """Transform

        _result = any(p(x) for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if p(x):
                    return True
        else:
            return False
        """
        return self._transform_any_all(node, pos_args, True)

    def _transform_any_all(self, node, pos_args, is_any):
        """Inline any()/all() over a generator expression with a single
        yield, short-circuiting on the first deciding item."""
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        generator_body = gen_expr_node.def_node.gbody
        loop_node = generator_body.body
        yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)

        test_node = Nodes.IfStatNode(
            yield_expression.pos, else_clause=None, if_clauses=[
                Nodes.IfClauseNode(
                    yield_expression.pos,
                    condition=condition,
                    body=Nodes.ReturnStatNode(
                        node.pos,
                        value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any))
                )]
        )
        # loop ran to completion without deciding => the opposite result
        loop_node.else_clause = Nodes.ReturnStatNode(
            node.pos,
            value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any))

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')

    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) and sorted([listcomp]) into
        [listcomp].sort().  CPython just reads the iterable into a
        list and calls .sort() on it.  Expanding the iterable in a
        listcomp is still faster and the result can be sorted in
        place.
        """
        if len(pos_args) != 1:
            return node

        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
            list_node = arg
            loop_node = list_node.loop

        elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
            gen_expr_node = arg
            loop_node = gen_expr_node.loop
            yield_statements = _find_yield_statements(loop_node)
            if not yield_statements:
                return node

            list_node = ExprNodes.InlinedGeneratorExpressionNode(
                node.pos, gen_expr_node, orig_func='sorted',
                comprehension_type=Builtin.list_type)

            for yield_expression, yield_stat_node in yield_statements:
                append_node = ExprNodes.ComprehensionAppendNode(
                    yield_expression.pos,
                    expr=yield_expression,
                    target=list_node.target)
                Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        elif arg.is_sequence_constructor:
            # sorted([a, b, c]) or sorted((a, b, c)).  The result is always a list,
            # so starting off with a fresh one is more efficient.
            list_node = loop_node = arg.as_list()

        else:
            # Interestingly, PySequence_List works on a lot of non-sequence
            # things as well.
            list_node = loop_node = ExprNodes.PythonCapiCallNode(
                node.pos, "PySequence_List", self.PySequence_List_func_type,
                args=pos_args, is_temp=True)

        result_node = UtilNodes.ResultRefNode(
            pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
        list_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs=result_node, rhs=list_node, first=True)

        sort_method = ExprNodes.AttributeNode(
            node.pos, obj=result_node, attribute=EncodedString('sort'),
            # entry ? type ?
            needs_none_check=False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr=ExprNodes.SimpleCallNode(
                node.pos, function=sort_method, args=[]))

        sort_node.analyse_declarations(self.current_env())

        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node]))

    # NOTE: the double underscore name-mangles this method, so the
    # getattr() lookup in _dispatch_to_handler() never finds it; it is
    # effectively disabled.
    def __handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.
        """
        if len(pos_args) not in (1,2):
            return node
        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
            # FIXME: currently nonfunctional
            yield_expression = None
            if yield_expression is None:
                return node
        else:  # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                if not yield_expression.is_literal or not yield_expression.type.is_int:
                    return node
            except AttributeError:
                return node # in case we don't have a type yet
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)

        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)

    def _handle_simple_function_min(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '<')

    def _handle_simple_function_max(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '>')

    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
        """
        if len(args) <= 1:
            if len(args) == 1 and args[0].is_sequence_constructor:
                args = args[0].args
            if len(args) <= 1:
                # leave this to Python
                return node

        # each further argument is evaluated exactly once into a temp
        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))

        last_result = args[0]
        for arg_node in cascaded_nodes:
            result_ref = UtilNodes.ResultRefNode(last_result)
            last_result = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val = arg_node,
                false_val = result_ref,
                test = ExprNodes.PrimaryCmpNode(
                    arg_node.pos,
                    operand1 = arg_node,
                    operator = operator,
                    operand2 = result_ref,
                    )
                )
            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

        for ref_node in cascaded_nodes[::-1]:
            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

        return last_result

    # builtin type creation

    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        if not pos_args:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
        if result is not node:
            return ExprNodes.AsTupleNode(node.pos, arg=result)
        return node

    def _handle_simple_function_frozenset(self, node, pos_args):
        """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
        """
        if len(pos_args) != 1:
            return node
        if pos_args[0].is_sequence_constructor and not pos_args[0].args:
            # frozenset(()) / frozenset([]) -> frozenset()
            del pos_args[0]
        elif isinstance(pos_args[0], ExprNodes.ListNode):
            pos_args[0] = pos_args[0].as_tuple()
        return node

    def _handle_simple_function_list(self, node, pos_args):
        """Replace list() by [] and list(genexpr) by a list comprehension."""
        if not pos_args:
            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set() by an empty set literal and set(genexpr) by a
        set comprehension."""
        if not pos_args:
            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)

    def _transform_list_set_genexpr(self, node, pos_args, target_type):
        """Replace set(genexpr) and list(genexpr) by an inlined comprehension.
        """
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_statements = _find_yield_statements(loop_node)
        if not yield_statements:
            return node

        result_node = ExprNodes.InlinedGeneratorExpressionNode(
            node.pos, gen_expr_node,
            orig_func='set' if target_type is Builtin.set_type else 'list',
            comprehension_type=target_type)

        for yield_expression, yield_stat_node in yield_statements:
            append_node = ExprNodes.ComprehensionAppendNode(
                yield_expression.pos,
                expr=yield_expression,
                target=result_node.target)
            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        return result_node

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_statements = _find_yield_statements(loop_node)
        if not yield_statements:
            return node

        # every yielded value must be a literal 2-tuple (key, value)
        for yield_expression, _ in yield_statements:
            if not isinstance(yield_expression, ExprNodes.TupleNode):
                return node
            if len(yield_expression.args) != 2:
                return node

        result_node = ExprNodes.InlinedGeneratorExpressionNode(
            node.pos, gen_expr_node, orig_func='dict',
            comprehension_type=Builtin.dict_type)

        for yield_expression, yield_stat_node in yield_statements:
            append_node = ExprNodes.DictComprehensionAppendNode(
                yield_expression.pos,
                key_expr=yield_expression.args[0],
                value_expr=yield_expression.args[1],
                target=result_node.target)
            Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

        return result_node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        return kwargs

1980 1981

class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace calls of a name bound exactly once to a def function
    (PyCFunctionNode) by an InlinedDefNodeCallNode, when the
    'optimize.inline_defnode_calls' directive is enabled.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the RHS node of the single assignment to ``name_node``.

        Returns None if the name has no control-flow state, if that state
        marks it as null, or if its entry does not have exactly one
        assignment.
        """
        if name_node.cf_state is None:
            return None
        if name_node.cf_state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry or not entry.cf_assignments or len(entry.cf_assignments) != 1:
            # not just a single assignment in all closures
            return None
        return entry.cf_assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Inline ``f(args)`` if ``f`` is constantly bound to a def function
        and the resulting InlinedDefNodeCallNode reports it can be inlined."""
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        function_name = node.function
        if not function_name.is_name:
            return node
        function = self.get_constant_value_node(function_name)
        if not isinstance(function, ExprNodes.PyCFunctionNode):
            return node
        inlined = ExprNodes.InlinedDefNodeCallNode(
            node.pos, function_name=function_name,
            function=function, args=node.args)
        if inlined.can_be_inlined():
            return self.replace(node, inlined)
        return node

2013

2014 2015
class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
                           Visitor.MethodDispatcherTransform):
Stefan Behnel's avatar
Stefan Behnel committed
2016
    """Optimize some common methods calls and instantiation patterns
2017 2018 2019 2020 2021
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
2022
    """
2023 2024
    ### cleanup to avoid redundant coercions to/from Python types

2025 2026 2027
    def _visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        # disabled - appears to break assignments in some cases, and
        # also drops a None check, which might still be required
        previous_arg = node.arg
        self.visitchildren(node)
        if previous_arg is node.arg:
            # children were not replaced, nothing to flatten
            return node
        if node.arg.type != node.type:
            return node
        return node.arg

2036 2037 2038
    def _visit_TypecastNode(self, node):
        """
        Drop redundant type casts.
        """
        # disabled - the user may have had a reason to put a type
        # cast, even if it looks redundant to Cython
        self.visitchildren(node)
        operand = node.operand
        return operand if node.type == operand.type else node

2047 2048
    def visit_ExprStatNode(self, node):
        """
        Drop dead code and useless coercions.

        Returns None (removing the statement) when the expression is gone,
        is None/a literal, or is a bare reference to a local variable or
        argument; otherwise returns the statement node.
        """
        self.visitchildren(node)
        if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
            # the statement discards the value, so the coercion is useless
            node.expr = node.expr.arg
        expr = node.expr
        if expr is None or expr.is_none or expr.is_literal:
            # Expression was removed or is dead code => remove ExprStatNode as well.
            return None
        references_local = expr.is_name and expr.entry and (
            expr.entry.is_local or expr.entry.is_arg)
        if references_local:
            # Ignore dead references to local variables etc.
            return None
        return node

2063 2064 2065 2066 2067
    def visit_CoerceToBooleanNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        A boolean test of a C value that was coerced to a Python object
        (or bool), optionally behind a type test, is replaced by a boolean
        coercion of the C value itself.
        """
        self.visitchildren(node)
        inner = node.arg
        if isinstance(inner, ExprNodes.PyTypeTestNode):
            inner = inner.arg
        is_redundant_py_coercion = (
            isinstance(inner, ExprNodes.CoerceToPyTypeNode)
            and inner.type in (PyrexTypes.py_object_type, Builtin.bool_type))
        if is_redundant_py_coercion:
            return inner.arg.coerce_to_boolean(self.current_env())
        return node

2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101
    # signature of __Pyx_PyNumber_Float(o) -> object
    PyNumber_Float_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)])

    def visit_CoerceToPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Handles a Python coercion of a one-argument C-API call named 'float'
        (optionally behind a CoerceFromPyTypeNode): an argument already typed
        as Python float is returned directly (with a None check), any other
        Python object argument is passed to __Pyx_PyNumber_Float() instead.
        """
        self.visitchildren(node)
        call = node.arg
        if isinstance(call, ExprNodes.CoerceFromPyTypeNode):
            call = call.arg
        if not isinstance(call, ExprNodes.PythonCapiCallNode):
            return node
        if call.function.name != 'float' or len(call.args) != 1:
            return node

        # undo redundant Py->C->Py coercion
        float_arg = call.args[0]
        if float_arg.type is Builtin.float_type:
            return float_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
        if float_arg.type.is_pyobject:
            return ExprNodes.PythonCapiCallNode(
                node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
                args=[float_arg],
                py_name='float',
                is_temp=node.is_temp,
                result_is_used=node.result_is_used,
            ).coerce_to(node.type, self.current_env())
        return node

2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.
        """
        self.visitchildren(node)
        source = node.arg
        if not source.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type != source.type:
                source = source.coerce_to(node.type, self.current_env())
            return source

        if isinstance(source, ExprNodes.PyTypeTestNode):
            source = source.arg

        if source.is_literal:
            # int/bool literal -> C int, float literal -> C float
            literal_fits = (
                (node.type.is_int and isinstance(source, (ExprNodes.IntNode, ExprNodes.BoolNode)))
                or (node.type.is_float and isinstance(source, ExprNodes.FloatNode)))
            if literal_fits:
                return source.coerce_to(node.type, self.current_env())
        elif isinstance(source, ExprNodes.CoerceToPyTypeNode):
            c_value = source.arg
            if source.type is PyrexTypes.py_object_type:
                if node.type.assignable_from(c_value.type):
                    # completely redundant C->Py->C coercion
                    return c_value.coerce_to(node.type, self.current_env())
            elif source.type is Builtin.unicode_type:
                if c_value.type.is_unicode_char and node.type.is_unicode_char:
                    return c_value.coerce_to(node.type, self.current_env())
        elif isinstance(source, ExprNodes.SimpleCallNode):
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, source)
        elif source.is_subscript:
            index_node = source.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, source, index_node)
        return node

2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153
    # signature of __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds) -> char
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value="((char)-1)",
        exception_check=True)

    def _optimise_int_indexing(self, coerce_node, arg, index_node):
        """Replace ``bytes_obj[int_index]`` coerced to C char/unsigned char
        by a direct call to ``__Pyx_PyBytes_GetItemInt()``.

        coerce_node: the CoerceFromPyTypeNode wrapping the subscript.
        arg:         the subscript node (its ``base`` is the indexed object).
        index_node:  the C integer index (already unwrapped from any Python
                     coercion).

        Returns the replacement node, or ``coerce_node`` unchanged if the
        indexed object is not typed as bytes or the target is not a C char type.
        """
        if arg.base.type is not Builtin.bytes_type:
            return coerce_node
        if coerce_node.type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
            return coerce_node

        env = self.current_env()
        # The C helper takes the 'boundscheck' directive as a plain 0/1 int.
        bound_check_bool = 1 if env.directives['boundscheck'] else 0
        bound_check_node = ExprNodes.IntNode(
            coerce_node.pos, value=str(bound_check_bool),
            constant_result=bound_check_bool)
        # bytes[index] -> char
        node = ExprNodes.PythonCapiCallNode(
            coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
            self.PyBytes_GetItemInt_func_type,
            args=[
                arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
                index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
                bound_check_node,
                ],
            is_temp=True,
            utility_code=UtilityCode.load_cached(
                'bytes_index', 'StringTools.c'))
        if coerce_node.type is not PyrexTypes.c_char_type:
            # the helper returns 'char'; convert for 'unsigned char' targets
            node = node.coerce_to(coerce_node.type, env)
        return node

2176 2177 2178 2179 2180 2181 2182
    # C float type -> signature of a unary function taking and returning that
    # same type (used for the trunc*() calls in _optimise_numeric_cast_call).
    float_float_func_types = dict(
        (c_float_t, PyrexTypes.CFuncType(
            c_float_t, [PyrexTypes.CFuncTypeArg("arg", c_float_t, None)]))
        for c_float_t in (
            PyrexTypes.c_float_type,
            PyrexTypes.c_double_type,
            PyrexTypes.c_longdouble_type))

Stefan Behnel's avatar
Stefan Behnel committed
2183
    def _optimise_numeric_cast_call(self, node, arg):
        """Simplify an ``int(x)`` / ``float(x)`` call whose Python result is
        coerced straight back to the C type ``node.type``.

        Only single-argument calls whose argument is a C value (possibly
        wrapped in a Python coercion) are rewritten, into the argument itself,
        a C cast, or a C trunc*() call. Everything else returns ``node``.
        """
        function = arg.function
        if isinstance(arg, ExprNodes.PythonCapiCallNode):
            call_args = arg.args
        elif (isinstance(function, ExprNodes.NameNode)
                and function.type.is_builtin_type
                and isinstance(arg.arg_tuple, ExprNodes.TupleNode)):
            call_args = arg.arg_tuple.args
        else:
            call_args = None
        if call_args is None or len(call_args) != 1:
            return node

        func_arg = call_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play it safe: Python conversion might work on all sorts of things
            return node

        source_type = func_arg.type
        target_type = node.type
        if function.name == 'int':
            if source_type.is_int or target_type.is_int:
                if source_type == target_type:
                    return func_arg
                if target_type.assignable_from(source_type) or source_type.is_float:
                    return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type)
            elif source_type.is_float and target_type.is_numeric:
                if source_type.math_h_modifier == 'l':
                    # Work around missing Cygwin definition.
                    trunc_func = '__Pyx_truncl'
                else:
                    trunc_func = 'trunc' + source_type.math_h_modifier
                return ExprNodes.PythonCapiCallNode(
                    node.pos, trunc_func,
                    func_type=self.float_float_func_types[source_type],
                    args=[func_arg],
                    py_name='int',
                    is_temp=node.is_temp,
                    result_is_used=node.result_is_used,
                ).coerce_to(node.type, self.current_env())
        elif function.name == 'float':
            if source_type.is_float or target_type.is_float:
                if source_type == target_type:
                    return func_arg
                if target_type.assignable_from(source_type) or source_type.is_float:
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        return node

2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a compile-time error at ``node.pos`` that
        ``function_name`` was called with ``len(args)`` arguments.

        ``expected`` may be None, an int, or a descriptive string such as
        '0 or 1'; it only affects the wording of the message.
        """
        if expected and (isinstance(expected, basestring) or expected > 1):
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            # None, 0, or a non-positive count
            arg_str = ''
        expected_str = '' if expected is None else 'expected %s, ' % expected
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275
    ### generic fallbacks

    def _handle_function(self, node, function_name, function, arg_list, kwargs):
        """Fallback for functions without a dedicated handler: keep the call."""
        return node

    def _handle_method(self, node, type_name, attr_name, function,
                       arg_list, is_unbound_method, kwargs):
        """
        Try to inject C-API calls for unbound method calls to builtin types.
        While the method declarations in Builtin.py already handle this, we
        can additionally resolve bound and unbound methods here that were
        assigned to variables ahead of time.

        If the type attribute cannot be resolved, this falls back to
        _optimise_generic_builtin_method_call(). Calls with keyword
        arguments are left unchanged.
        """
        if kwargs:
            return node
        if not (function and function.is_attribute and function.obj.is_name):
            # cannot track unbound method calls over more than one indirection as
            # the names might have been reassigned in the meantime
            return node

        env = self.current_env()
        type_entry = env.lookup(type_name)
        if not type_entry:
            return node

        type_name_node = ExprNodes.NameNode(
            function.pos,
            name=type_name,
            entry=type_entry,
            type=type_entry.type)
        method = ExprNodes.AttributeNode(
            node.function.pos,
            obj=type_name_node,
            attribute=attr_name,
            is_called=True).analyse_as_type_attribute(env)
        if method is None:
            return self._optimise_generic_builtin_method_call(
                node, attr_name, function, arg_list, is_unbound_method)

        call_args = node.args
        if call_args is None and node.arg_tuple:
            call_args = node.arg_tuple.args
        call_node = ExprNodes.SimpleCallNode(node.pos, function=method, args=call_args)
        if not is_unbound_method:
            # bound method: the object becomes the implicit 'self' argument
            call_node.self = function.obj
        call_node.analyse_c_function_call(env)
        call_node.analysed = True
        return call_node.coerce_to(node.type, env)

2293 2294
    ### builtin types

2295 2296 2297 2298 2299 2300
    def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
        """
        Try to inject an unbound method call for a call to a method of a known builtin type.
        This enables caching the underlying C function of the method at runtime.

        Only bound Python attribute calls with at most two arguments on an
        object whose type is neither 'basestring' nor 'type' are rewritten.
        """
        arg_count = len(arg_list)
        if is_unbound_method or arg_count >= 3:
            return node
        if not (function.is_attribute and function.is_py_attr):
            return node
        receiver = function.obj
        if receiver.type.name in ('basestring', 'type'):
            # these allow different actual types => unsafe
            return node
        assert receiver.type.is_builtin_type
        return ExprNodes.CachedBuiltinMethodCallNode(
            node, receiver, attr_name, arg_list)

2310 2311 2312 2313 2314
    # signature of PyDict_Copy(dict) -> dict
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, function, pos_args):
        """Replace dict(some_dict) by PyDict_Copy(some_dict).

        Only applies when the single positional argument is typed as dict;
        all other calls are returned unchanged.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if arg.type is not Builtin.dict_type:
            return node
        # Use the same message that CPython raises for dict(None), and that
        # the tuple() handler in this class uses.
        arg = arg.as_none_safe_node("'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
            args=[arg],
            is_temp=node.is_temp)
2329

2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343
    # signature of PySequence_List(it) -> list
    PySequence_List_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

    def _handle_simple_function_list(self, node, function, pos_args):
        """Turn list(ob) into PySequence_List(ob).

        Only the single-argument form is rewritten; other calls are
        returned unchanged.
        """
        if len(pos_args) != 1:
            return node
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PySequence_List", self.PySequence_List_func_type,
            args=pos_args, is_temp=node.is_temp)

2344 2345 2346 2347 2348
    # signature of PyList_AsTuple(list) -> tuple
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type,
        [PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)])

    def _handle_simple_function_tuple(self, node, function, pos_args):
        """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple.

        A non-None argument typed as tuple is returned as is, a list argument
        goes through PyList_AsTuple() (with a None check), and anything else
        becomes an AsTupleNode. Only single-argument temp calls are handled.
        """
        if len(pos_args) != 1 or not node.is_temp:
            return node
        arg = pos_args[0]
        if arg.type is Builtin.tuple_type and not arg.may_be_none():
            return arg
        if arg.type is not Builtin.list_type:
            return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type)

        pos_args[0] = arg.as_none_safe_node(
            "'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args=pos_args, is_temp=node.is_temp)
2366

2367 2368 2369 2370 2371
    # signature of PySet_New(it) -> set
    PySet_New_func_type = PyrexTypes.CFuncType(
        Builtin.set_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

    def _handle_simple_function_set(self, node, function, pos_args):
        """Optimise single-argument set() calls.

        set() over a sequence literal becomes a set literal (SetNode); any
        other single argument becomes PySet_New(it).
        """
        if len(pos_args) != 1:
            return node
        iterable = pos_args[0]
        if not iterable.is_sequence_constructor:
            # PySet_New(it) is better than a generic Python call to set(it)
            return self.replace(node, ExprNodes.PythonCapiCallNode(
                node.pos, "PySet_New",
                self.PySet_New_func_type,
                args=pos_args,
                is_temp=node.is_temp,
                py_name="set"))

        # We can optimise set([x,y,z]) safely into a set literal,
        # but only if we create all items before adding them -
        # adding an item may raise an exception if it is not
        # hashable, but creating the later items may have
        # side-effects.
        temps = []
        items = []
        for item in iterable.args:
            if not item.is_simple():
                item = UtilNodes.LetRefNode(item)
                temps.append(item)
            items.append(item)
        result = ExprNodes.SetNode(node.pos, is_temp=1, args=items)
        self.replace(node, result)
        for temp in reversed(temps):
            result = UtilNodes.EvalWithTempExprNode(temp, result)
        return result
2401 2402 2403 2404 2405 2406 2407

    # signature of __Pyx_PyFrozenSet_New(it) -> frozenset
    PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
        Builtin.frozenset_type,
        [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

    def _handle_simple_function_frozenset(self, node, function, pos_args):
        """Optimise frozenset() calls with zero or one argument.

        frozenset() passes a NULL argument to __Pyx_PyFrozenSet_New(), a
        non-None argument already typed as frozenset is returned as is, and
        any other single argument is passed to __Pyx_PyFrozenSet_New().
        """
        if len(pos_args) > 1:
            return node
        if not pos_args:
            pos_args = [ExprNodes.NullNode(node.pos)]
        else:
            only_arg = pos_args[0]
            if only_arg.type is Builtin.frozenset_type and not only_arg.may_be_none():
                return only_arg
        # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyFrozenSet_New",
            self.PyFrozenSet_New_func_type,
            args=pos_args,
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
            py_name="frozenset")

    # signature of __Pyx_PyObject_AsDouble(obj) -> double, -1 signals a possible error
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)

    def _handle_simple_function_float(self, node, function, pos_args):
        """Transform float() into either a C type cast or a faster C
        function call.

        * ``float()``                        -> constant 0.0
        * argument typed as C double         -> the argument itself
        * numeric / assignable C argument    -> C cast to ``node.type``
        * anything else                      -> __Pyx_PyObject_AsDouble()

        More than one argument is reported as a compile error and the call
        is left unchanged.
        """
        # Note: this requires the float() function to be typed as
        # returning a C 'double'
        if len(pos_args) == 0:
            # Like every other node constructor in this transform, FloatNode
            # takes the source position first (not the node it replaces).
            return ExprNodes.FloatNode(
                node.pos, value="0.0", constant_result=0.0
                ).coerce_to(Builtin.float_type, self.current_env())
        if len(pos_args) != 1:
            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
            return node
        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        if func_arg.type is PyrexTypes.c_double_type:
            return func_arg
        if node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_AsDouble",
            self.PyObject_AsDouble_func_type,
            args = pos_args,
            is_temp = node.is_temp,
            utility_code = load_c_utility('pyobject_as_double'),
            py_name = "float")

2459 2460 2461 2462 2463
    # signature of __Pyx_PyNumber_Int(o) -> object
    PyNumber_Int_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)])

    # signature of __Pyx_PyInt_FromDouble(value) -> object
    PyInt_FromDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None)])

    def _handle_simple_function_int(self, node, function, pos_args):
        """Transform int() into a faster C function call.

        int() becomes the constant 0, int(<C float coerced to Python>) calls
        __Pyx_PyInt_FromDouble(), and int(obj) with a Python object result
        calls __Pyx_PyNumber_Int(). Two-argument calls are left unchanged.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
                                     type=PyrexTypes.py_object_type)
        if arg_count != 1:
            return node  # int(x, base)

        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            c_value = func_arg.arg
            if not c_value.type.is_float:
                return node  # handled in visit_CoerceFromPyTypeNode()
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type,
                args=[c_value], is_temp=True, py_name='int',
                utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c"))

        if not (func_arg.type.is_pyobject and node.type.is_pyobject):
            return node
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type,
            args=pos_args, is_temp=True, py_name='int')

2492
    def _handle_simple_function_bool(self, node, function, pos_args):
        """Transform bool(x) into a type coercion to a boolean.

        bool() becomes the constant False; more than one argument is
        reported as a compile error and the call is left unchanged.
        """
        if len(pos_args) == 0:
            return ExprNodes.BoolNode(
                node.pos, value=False, constant_result=False
                ).coerce_to(Builtin.bool_type, self.current_env())
        if len(pos_args) != 1:
            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
            return node

        # => !!<bint>(x)  to make sure it's exactly 0 or 1
        as_bint = pos_args[0].coerce_to_boolean(self.current_env())
        double_negated = ExprNodes.NotNode(
            node.pos, operand=ExprNodes.NotNode(node.pos, operand=as_bint))
        # coerce back to Python object as that's the result we are expecting
        return double_negated.coerce_to_pyobject(self.current_env())
2509

2510 2511
    ### builtin functions

2512 2513
    # signature of strlen(const char*) -> size_t
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type,
        [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)])

    # signature of __Pyx_Py_UNICODE_strlen(const Py_UNICODE*) -> size_t
    Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type,
        [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)])

    # signature shared by the C-API length functions/macros below, -1 on error
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value="-1")

    # builtin type -> name of the C-API function/macro computing its len();
    # bound .get() so that unknown types map to None
    _map_to_capi_len_function = {
        Builtin.unicode_type:    "__Pyx_PyUnicode_GET_LENGTH",
        Builtin.bytes_type:      "PyBytes_GET_SIZE",
        Builtin.bytearray_type:  'PyByteArray_GET_SIZE',
        Builtin.list_type:       "PyList_GET_SIZE",
        Builtin.tuple_type:      "PyTuple_GET_SIZE",
        Builtin.set_type:        "PySet_GET_SIZE",
        Builtin.frozenset_type:  "PySet_GET_SIZE",
        Builtin.dict_type:       "PyDict_Size",
    }.get

    # qualified names of extension types whose len() is taken via Py_SIZE()
    _ext_types_with_pysize = set(["cpython.array.array"])

2541
    def _handle_simple_function_len(self, node, function, pos_args):
        """Replace len(char*) by the equivalent call to strlen(),
        len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
        len(known_builtin_type) by an equivalent C-API call.

        Memoryview slices use __Pyx_MemoryView_Len(), a single unicode
        character has length 1, and other argument types are left unchanged.
        A wrong argument count is reported as a compile error.
        """
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            arg = arg.arg
        arg_type = arg.type

        if arg_type.is_string:
            len_node = ExprNodes.PythonCapiCallNode(
                node.pos, "strlen", self.Pyx_strlen_func_type,
                args=[arg],
                is_temp=node.is_temp,
                utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
        elif arg_type.is_pyunicode_ptr:
            len_node = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
                args=[arg],
                is_temp=node.is_temp)
        elif arg_type.is_memoryviewslice:
            memview_len_type = PyrexTypes.CFuncType(
                PyrexTypes.c_size_t_type,
                [PyrexTypes.CFuncTypeArg("memoryviewslice", arg_type, None)],
                nogil=True)
            len_node = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_MemoryView_Len", memview_len_type,
                args=[arg], is_temp=node.is_temp)
        elif arg_type.is_pyobject:
            cfunc_name = self._map_to_capi_len_function(arg_type)
            if cfunc_name is None:
                uses_py_size = (
                    (arg_type.is_extension_type or arg_type.is_builtin_type)
                    and arg_type.entry.qualified_name in self._ext_types_with_pysize)
                if not uses_py_size:
                    return node
                cfunc_name = 'Py_SIZE'
            arg = arg.as_none_safe_node(
                "object of type 'NoneType' has no len()")
            len_node = ExprNodes.PythonCapiCallNode(
                node.pos, cfunc_name, self.PyObject_Size_func_type,
                args=[arg], is_temp=node.is_temp)
        elif arg_type.is_unicode_char:
            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                     type=node.type)
        else:
            return node

        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
            len_node = len_node.coerce_to(node.type, self.current_env())
        return len_node
2593

2594 2595 2596 2597 2598
    # signature of Py_TYPE(object) -> type
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type,
        [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])

    def _handle_simple_function_type(self, node, function, pos_args):
        """Replace type(o) by a macro call to Py_TYPE(o).

        The result is cast to a generic Python object; calls with other
        than one argument are left unchanged.
        """
        if len(pos_args) != 1:
            return node
        py_type_call = ExprNodes.PythonCapiCallNode(
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args=pos_args,
            is_temp=False)
        return ExprNodes.CastNode(py_type_call, PyrexTypes.py_object_type)
2609

2610 2611 2612 2613 2614
    # signature of the generated type check calls: (arg) -> bint
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_isinstance(self, node, function, pos_args):
        """Replace isinstance() checks against builtin types by the
        corresponding C-API call.

        ``isinstance(x, (A, B, ...))`` becomes an ``or`` chain with one check
        per type:
        * the builtin type's own check function for known builtin types
          (each distinct check function is emitted only once),
        * ``__Pyx_TypeCheck()`` for other values typed as ``type``,
        * ``PyObject_IsInstance()`` for anything else.

        Calls with other than two arguments, a one-element tuple of a
        non-``type`` value, an empty tuple, or a second argument that is
        neither a tuple literal nor typed as ``type`` are left unchanged.
        """
        if len(pos_args) != 2:
            return node
        arg, types = pos_args
        temps = []
        if isinstance(types, ExprNodes.TupleNode):
            types = types.args
            if not types:
                # isinstance(x, ()) yields no checks to join below (reduce()
                # would fail on the empty list) - leave it to the runtime.
                return node
            if len(types) == 1 and types[0].type is not Builtin.type_type:
                return node  # nothing to improve here
            if arg.is_attribute or not arg.is_simple():
                # 'arg' is referenced by every generated check => evaluate once
                arg = UtilNodes.ResultRefNode(arg)
                temps.append(arg)
        elif types.type is Builtin.type_type:
            types = [types]
        else:
            return node

        tests = []
        test_nodes = []
        env = self.current_env()
        for test_type_node in types:
            builtin_type = None
            entry = None
            if test_type_node.is_name and test_type_node.entry:
                entry = env.lookup(test_type_node.entry.name)
                if entry and entry.type and entry.type.is_builtin_type:
                    builtin_type = entry.type
            if builtin_type is Builtin.type_type:
                # all types have type "type", but there's only one 'type'
                if entry.name != 'type' or not (
                        entry.scope and entry.scope.is_builtin_scope):
                    builtin_type = None
            if builtin_type is not None:
                type_check_function = entry.type.type_check_function(exact=False)
                if type_check_function in tests:
                    continue
                tests.append(type_check_function)
                type_check_args = [arg]
            elif test_type_node.type is Builtin.type_type:
                type_check_function = '__Pyx_TypeCheck'
                type_check_args = [arg, test_type_node]
            else:
                if not test_type_node.is_literal:
                    test_type_node = UtilNodes.ResultRefNode(test_type_node)
                    temps.append(test_type_node)
                type_check_function = 'PyObject_IsInstance'
                type_check_args = [arg, test_type_node]
            test_nodes.append(
                ExprNodes.PythonCapiCallNode(
                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                    args=type_check_args,
                    is_temp=True,
                ))

        def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
            or_node = make_binop_node(node.pos, 'or', a, b)
            or_node.type = PyrexTypes.c_bint_type
            or_node.wrap_operands(env)
            return or_node

        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
        for temp in temps[::-1]:
            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
        return test_node

2683
    def _handle_simple_function_ord(self, node, function, pos_args):
2684
        """Unpack ord(Py_UNICODE) and ord('X').
2685 2686 2687 2688 2689
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
Stefan Behnel's avatar
Stefan Behnel committed
2690
            if arg.arg.type.is_unicode_char:
2691
                return ExprNodes.TypecastNode(
2692
                    arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type
2693 2694 2695 2696
                    ).coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.UnicodeNode):
            if len(arg.value) == 1:
                return ExprNodes.IntNode(
2697
                    arg.pos, type=PyrexTypes.c_int_type,
2698 2699 2700 2701 2702
                    value=str(ord(arg.value)),
                    constant_result=ord(arg.value)
                    ).coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.StringNode):
            if arg.unicode_value and len(arg.unicode_value) == 1 \
2703
                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
2704
                return ExprNodes.IntNode(
2705
                    arg.pos, type=PyrexTypes.c_int_type,
2706 2707 2708
                    value=str(ord(arg.unicode_value)),
                    constant_result=ord(arg.unicode_value)
                    ).coerce_to(node.type, self.current_env())
2709 2710
        return node

2711 2712
    ### special methods

2713 2714
    # Signature of the __Pyx_tp_new(type, args) helper.
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
        ])

    # Signature of the __Pyx_tp_new_kwargs(type, args, kwargs) helper.
    Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
            PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
        ])

2726 2727
    def _handle_any_slot__new__(self, node, function, args,
                                is_unbound_method, kwargs=None):
2728
        """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()
2729
        """
2730
        obj = function.obj
2731
        if not is_unbound_method or len(args) < 1:
2732 2733 2734 2735 2736 2737 2738 2739 2740 2741 2742 2743 2744 2745
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
2746
            # different types - may or may not lead to an error at runtime
2747 2748
            return node

2749 2750 2751
        args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
        args_tuple = args_tuple.analyse_types(
            self.current_env(), skip_children=True)
2752

2753 2754
        if type_arg.type_entry:
            ext_type = type_arg.type_entry.type
2755 2756 2757
            if (ext_type.is_extension_type and ext_type.typeobj_cname and
                    ext_type.scope.global_scope() == self.current_env().global_scope()):
                # known type in current module
2758 2759 2760 2761 2762 2763 2764 2765 2766 2767 2768 2769 2770 2771 2772 2773 2774 2775 2776 2777 2778
                tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
                slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
                if slot_func_cname:
                    cython_scope = self.context.cython_scope
                    PyTypeObjectPtr = PyrexTypes.CPtrType(
                        cython_scope.lookup('PyTypeObject').type)
                    pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                        PyrexTypes.py_object_type, [
                            PyrexTypes.CFuncTypeArg("type",   PyTypeObjectPtr, None),
                            PyrexTypes.CFuncTypeArg("args",   PyrexTypes.py_object_type, None),
                            PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                            ])

                    type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                    if not kwargs:
                        kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, slot_func_cname,
                        pyx_tp_new_kwargs_func_type,
                        args=[type_arg, args_tuple, kwargs],
                        is_temp=True)
2779
        else:
2780
            # arbitrary variable, needs a None check for safety
2781
            type_arg = type_arg.as_none_safe_node(
2782 2783
                "object.__new__(X): X is not a type object (NoneType)")

2784 2785 2786 2787 2788 2789 2790 2791 2792 2793 2794 2795 2796 2797
        utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
        if kwargs:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
                args=[type_arg, args_tuple, kwargs],
                utility_code=utility_code,
                is_temp=node.is_temp
                )
        else:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
                args=[type_arg, args_tuple],
                utility_code=utility_code,
                is_temp=node.is_temp
2798 2799
            )

2800 2801 2802
    ### methods of builtin types

    # "int f(object list, object item)" returning a C return code, where -1
    # signals an error.
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ],
        exception_value="-1")
2808

2809
    def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
2810 2811 2812
        """Optimistic optimisation as X.append() is almost always
        referring to a list.
        """
2813
        if len(args) != 2 or node.result_is_used:
2814 2815
            return node

2816 2817
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2818 2819 2820 2821 2822 2823
            args=args,
            may_return_none=False,
            is_temp=node.is_temp,
            result_is_used=False,
            utility_code=load_c_utility('append')
        )
2824

2825 2826 2827 2828 2829 2830 2831 2832 2833 2834 2835 2836 2837 2838 2839 2840 2841 2842 2843 2844 2845 2846 2847 2848 2849 2850 2851 2852 2853 2854 2855 2856 2857 2858 2859 2860 2861 2862 2863 2864 2865 2866 2867 2868 2869 2870 2871 2872 2873 2874 2875 2876 2877 2878 2879 2880
    def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method):
        """Replace list.extend([...]) for short sequence literals values by sequential appends
        to avoid creating an intermediate sequence argument.
        """
        if len(args) != 2:
            return node
        obj, value = args
        if not value.is_sequence_constructor or value.mult_factor is not None:
            return node
        items = list(value.args)
        if len(items) > 4:
            # Appending wins for short sequences.
            # Ignorantly assume that this a good enough limit that avoids repeated resizing.
            return node
        wrapped_obj = self._wrap_self_arg(obj, function, is_unbound_method, 'extend')
        if not items:
            # Empty sequences are not likely to occur, but why waste a call to list.extend() for them?
            wrapped_obj.result_is_used = node.result_is_used
            return wrapped_obj
        cloned_obj = obj = wrapped_obj
        if len(items) > 1 and not obj.is_simple():
            cloned_obj = UtilNodes.LetRefNode(obj)
        # Use ListComp_Append() for all but the last item and finish with PyList_Append()
        # to shrink the list storage size at the very end if necessary.
        temps = []
        arg = items[-1]
        if not arg.is_simple():
            arg = UtilNodes.LetRefNode(arg)
            temps.append(arg)
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type,
            args=[cloned_obj, arg],
            is_temp=True,
            utility_code=load_c_utility("ListAppend"))
        for arg in items[-2::-1]:
            if not arg.is_simple():
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            new_node = ExprNodes.binop_node(
                node.pos, '|',
                ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type,
                    args=[cloned_obj, arg], py_name="extend",
                    is_temp=True,
                    utility_code=load_c_utility("ListCompAppend")),
                new_node,
                type=PyrexTypes.c_returncode_type,
            )
        new_node.result_is_used = node.result_is_used
        if cloned_obj is not obj:
            temps.append(cloned_obj)
        for temp in temps:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
            new_node.result_is_used = node.result_is_used
        return new_node

2881 2882 2883 2884 2885 2886 2887 2888 2889 2890 2891 2892 2893 2894 2895 2896 2897 2898 2899 2900 2901
    # bytearray.append() with a C int value; C return code, -1 on error.
    PyByteArray_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
        ],
        exception_value="-1")

    # bytearray.append() with a Python object value; C return code, -1 on error.
    PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
        ],
        exception_value="-1")

    def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
        if len(args) != 2:
            return node
        func_name = "__Pyx_PyByteArray_Append"
        func_type = self.PyByteArray_Append_func_type

        value = unwrap_coerced_node(args[1])
2902
        if value.type.is_int or isinstance(value, ExprNodes.IntNode):
2903 2904 2905 2906 2907 2908 2909 2910 2911 2912 2913 2914 2915 2916 2917 2918 2919 2920 2921 2922 2923 2924 2925 2926 2927
            value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
        elif value.is_string_literal:
            if not value.can_coerce_to_char_literal():
                return node
            value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
        elif value.type.is_pyobject:
            func_name = "__Pyx_PyByteArray_AppendObject"
            func_type = self.PyByteArray_AppendObject_func_type
            utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
        else:
            return node

        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, func_name, func_type,
            args=[args[0], value],
            may_return_none=False,
            is_temp=node.is_temp,
            utility_code=utility_code,
        )
        if node.result_is_used:
            new_node = new_node.coerce_to(node.type, self.current_env())
        return new_node

Robert Bradshaw's avatar
Robert Bradshaw committed
2928 2929 2930 2931 2932 2933 2934 2935
    # "object f(object list)" for pop() without an index.
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)])

    # pop(index): the index is passed both as object and as Py_ssize_t, plus
    # a signedness flag.
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None),
        ],
        has_varargs=True)  # to fake the additional macro args that lack a proper C type
Robert Bradshaw's avatar
Robert Bradshaw committed
2941

2942 2943 2944 2945 2946
    def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
        return self._handle_simple_method_object_pop(
            node, function, args, is_unbound_method, is_list=True)

    def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
Stefan Behnel's avatar
Stefan Behnel committed
2947 2948 2949
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.
        """
2950 2951
        if not args:
            return node
2952
        obj = args[0]
2953 2954
        if is_list:
            type_name = 'List'
2955
            obj = obj.as_none_safe_node(
2956
                "'NoneType' object has no attribute '%.30s'",
2957 2958 2959 2960
                error="PyExc_AttributeError",
                format_args=['pop'])
        else:
            type_name = 'Object'
Robert Bradshaw's avatar
Robert Bradshaw committed
2961 2962
        if len(args) == 1:
            return ExprNodes.PythonCapiCallNode(
2963 2964
                node.pos, "__Pyx_Py%s_Pop" % type_name,
                self.PyObject_Pop_func_type,
2965
                args=[obj],
2966 2967 2968 2969
                may_return_none=True,
                is_temp=node.is_temp,
                utility_code=load_c_utility('pop'),
            )
Robert Bradshaw's avatar
Robert Bradshaw committed
2970
        elif len(args) == 2:
2971
            index = unwrap_coerced_node(args[1])
Stefan Behnel's avatar
Stefan Behnel committed
2972
            py_index = ExprNodes.NoneNode(index.pos)
2973 2974
            orig_index_type = index.type
            if not index.type.is_int:
Stefan Behnel's avatar
Stefan Behnel committed
2975 2976 2977 2978 2979 2980 2981
                if isinstance(index, ExprNodes.IntNode):
                    py_index = index.coerce_to_pyobject(self.current_env())
                    index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                elif is_list:
                    if index.type.is_pyobject:
                        py_index = index.coerce_to_simple(self.current_env())
                        index = ExprNodes.CloneNode(py_index)
2982 2983 2984 2985 2986
                    index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                else:
                    return node
            elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type):
                return node
Stefan Behnel's avatar
Stefan Behnel committed
2987 2988
            elif isinstance(index, ExprNodes.IntNode):
                py_index = index.coerce_to_pyobject(self.current_env())
2989 2990 2991 2992 2993 2994 2995 2996 2997 2998 2999
            # real type might still be larger at runtime
            if not orig_index_type.is_int:
                orig_index_type = index.type
            if not orig_index_type.create_to_py_utility_code(self.current_env()):
                return node
            convert_func = orig_index_type.to_py_function
            conversion_type = PyrexTypes.CFuncType(
                PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)])
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py%s_PopIndex" % type_name,
                self.PyObject_PopIndex_func_type,
Stefan Behnel's avatar
Stefan Behnel committed
3000
                args=[obj, py_index, index,
3001 3002 3003 3004
                      ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0),
                                        constant_result=orig_index_type.signed and 1 or 0,
                                        type=PyrexTypes.c_int_type),
                      ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type,
3005
                                                 orig_index_type.empty_declaration_code()),
3006
                      ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)],
3007 3008 3009 3010
                may_return_none=True,
                is_temp=node.is_temp,
                utility_code=load_c_utility("pop_index"),
            )
3011

Robert Bradshaw's avatar
Robert Bradshaw committed
3012 3013
        return node

3014
    # "int f(object obj)" returning a C return code, -1 on error.
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value="-1")
3019

3020
    def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
3021 3022
        """Call PyList_Sort() instead of the 0-argument l.sort().
        """
3023
        if len(args) != 1:
3024
            return node
3025
        return self._substitute_method_call(
3026
            node, function, "PyList_Sort", self.single_param_func_type,
3027
            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
3028

3029 3030 3031 3032 3033
    # "object f(object dict, object key, object default)" for dict.get().
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])
3035

3036
    def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
        """Replace dict.get() by a call to PyDict_GetItem().

        args is (dict, key) or (dict, key, default).  A missing default is
        supplied as None.  Any other arity reports an error and keeps the
        original call.
        """
        if len(args) == 2:
            # Build a new list rather than appending: the caller may still
            # hold a reference to the list it passed in.
            args = list(args) + [ExprNodes.NoneNode(node.pos)]
        elif len(args) != 3:
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node

        return self._substitute_method_call(
            node, function,
            "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
            'get', is_unbound_method, args,
            may_return_none=True,
            utility_code=load_c_utility("dict_getitem_default"))
3051

3052 3053 3054 3055 3056
    # dict.setdefault() helper: dict, key, default plus a C int "is_safe_type"
    # flag computed at compile time from the key's type.
    Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
        ])

3060
    def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
        """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().

        args is (dict, key) or (dict, key, default).  A missing default
        becomes None.  An extra C int argument is appended:
          1  - the key has one of the whitelisted builtin types below,
          -1 - the key is an untyped Python object (unknown),
          0  - any other key type.
        How the helper uses this flag is defined by the 'dict_setdefault'
        utility code (not visible here).
        """
        if len(args) == 2:
            # Work on a new list: the caller may still hold the one it passed in.
            args = list(args) + [ExprNodes.NoneNode(node.pos)]
        elif len(args) == 3:
            args = list(args)
        else:
            self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
            return node
        key_type = args[1].type
        if key_type.is_builtin_type:
            # exact name match; a substring test against one long string would
            # also accept fragments such as 'byte' or 'in'
            is_safe_type = int(key_type.name in (
                'str', 'bytes', 'unicode', 'float', 'int', 'long', 'bool'))
        elif key_type is PyrexTypes.py_object_type:
            is_safe_type = -1  # don't know
        else:
            is_safe_type = 0   # definitely not
        args.append(ExprNodes.IntNode(
            node.pos, value=str(is_safe_type), constant_result=is_safe_type))

        return self._substitute_method_call(
            node, function,
            "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
            'setdefault', is_unbound_method, args,
            may_return_none=True,
            utility_code=load_c_utility('dict_setdefault'))
3085

3086 3087 3088 3089 3090 3091 3092 3093 3094 3095 3096 3097 3098 3099 3100 3101 3102 3103
    # "object f(object dict, object key, object default)" for dict.pop().
    PyDict_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])

    def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method):
        """Replace dict.pop() by a call to _PyDict_Pop().

        args is (dict, key) or (dict, key, default).  A missing default is
        passed as a C NULL rather than None, presumably so that the helper
        can tell "no default" apart from an explicit None.
        NOTE(review): confirm against the 'py_dict_pop' utility code.
        """
        if len(args) == 2:
            # Build a new list: the caller may still hold the one it passed in.
            args = list(args) + [ExprNodes.NullNode(node.pos)]
        elif len(args) != 3:
            self._error_wrong_arg_count('dict.pop', node, args, "2 or 3")
            return node

        return self._substitute_method_call(
            node, function,
            "__Pyx_PyDict_Pop", self.PyDict_Pop_func_type,
            'pop', is_unbound_method, args,
            utility_code=load_c_utility('py_dict_pop'))
3107

3108
    # Binary op between two objects where one is a known C long constant
    # ('intval'); 'inplace' marks augmented assignment.
    Pyx_PyInt_BinopInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("intval", PyrexTypes.c_long_type, None),
            PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
        ])

    # Same as above, with the constant passed as a C double ('fval').
    Pyx_PyFloat_BinopInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("fval", PyrexTypes.c_double_type, None),
            PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
        ])

    def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
3125
        return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
3126 3127

    def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
3128 3129
        return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)

3130 3131 3132 3133 3134 3135
    def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)

    def _handle_simple_method_object___neq__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)

3136 3137 3138 3139 3140 3141 3142 3143 3144
    def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('And', node, function, args, is_unbound_method)

    def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Or', node, function, args, is_unbound_method)

    def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Xor', node, function, args, is_unbound_method)

3145 3146 3147 3148 3149 3150 3151
    def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method):
        if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
            return node
        if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
            return node
        return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method)

3152 3153 3154 3155 3156 3157 3158
    def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method):
        if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
            return node
        if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
            return node
        return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method)

Stefan Behnel's avatar
Stefan Behnel committed
3159
    def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method):
3160
        return self._optimise_num_div('Remainder', node, function, args, is_unbound_method)
Stefan Behnel's avatar
Stefan Behnel committed
3161

3162 3163 3164 3165 3166 3167
    def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method):
        return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method)

    def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method):
        return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method)

3168 3169 3170
    def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method):
        return self._optimise_num_div('Divide', node, function, args, is_unbound_method)

3171 3172 3173 3174 3175 3176 3177 3178 3179 3180 3181 3182 3183
    def _optimise_num_div(self, operator, node, function, args, is_unbound_method):
        if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0:
            return node
        if isinstance(args[1], ExprNodes.IntNode):
            if not (-2**30 <= args[1].constant_result <= 2**30):
                return node
        elif isinstance(args[1], ExprNodes.FloatNode):
            if not (-2**53 <= args[1].constant_result <= 2**53):
                return node
        else:
            return node
        return self._optimise_num_binop(operator, node, function, args, is_unbound_method)

3184 3185 3186 3187 3188
    def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Add', node, function, args, is_unbound_method)

    def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
3189

3190 3191 3192 3193 3194 3195
    def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method)

    def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Divide', node, function, args, is_unbound_method)

3196 3197 3198
    def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method)

3199 3200 3201 3202 3203 3204
    def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)

    def _handle_simple_method_float___neq__(self, node, function, args, is_unbound_method):
        return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)

3205
    def _optimise_num_binop(self, operator, node, function, args, is_unbound_method):
        """
        Optimise math operators for (likely) float or small integer operations.

        One operand must be an int/float literal with a known constant value
        (the right-hand one is checked first), and the other must be a plain
        Python object.  The call is then replaced by a
        ``__Pyx_Py{Int,Float}_<operator><order>`` helper generated from the
        "PyIntBinop"/"PyFloatBinop" Tempita templates.  The helper receives
        both original operands, the literal again as a C long/double, and an
        "inplace" flag.  If any precondition fails, *node* is returned
        unchanged.
        """
        if len(args) != 2 or not node.type.is_pyobject:
            return node

        # When adding IntNode/FloatNode to something else, assume the other
        # operand is also numeric.  Constants on the RHS are preferred (they
        # are looked for first) as they allow better size control for some
        # operators.
        literal_types = (ExprNodes.IntNode, ExprNodes.FloatNode)
        for const_index, arg_order in ((1, 'ObjC'), (0, 'CObj')):
            if isinstance(args[const_index], literal_types):
                break
        else:
            return node
        if args[1 - const_index].type is not PyrexTypes.py_object_type:
            return node

        const_operand = args[const_index]
        if not const_operand.has_constant_result():
            return node

        is_float = isinstance(const_operand, ExprNodes.FloatNode)
        if is_float:
            if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
                return node
        elif operator == 'Divide' or abs(const_operand.constant_result) > 2**30:
            # mixed old-/new-style division is not currently optimised for
            # integers, and large integer constants are not handled either
            return node

        if is_float:
            type_label, template_name = 'Float', "PyFloatBinop"
            literal_class, c_type = ExprNodes.FloatNode, PyrexTypes.c_double_type
            helper_type = self.Pyx_PyFloat_BinopInt_func_type
        else:
            type_label, template_name = 'Int', "PyIntBinop"
            literal_class, c_type = ExprNodes.IntNode, PyrexTypes.c_long_type
            helper_type = self.Pyx_PyInt_BinopInt_func_type

        helper_args = list(args)
        helper_args.append(literal_class(
            const_operand.pos, value=const_operand.value,
            constant_result=const_operand.constant_result, type=c_type))
        inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
        helper_args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))

        template = TempitaUtilityCode.load_cached(
            template_name, "Optimize.c",
            context=dict(op=operator, order=arg_order))
        helper_name = "__Pyx_Py%s_%s%s" % (type_label, operator, arg_order)
        return self._substitute_method_call(
            node, function, helper_name, helper_type,
            '__%s__' % operator[:3].lower(), is_unbound_method, helper_args,
            may_return_none=True,
            with_none_check=False,
            utility_code=template)
3261 3262 3263

    ### unicode type methods

3264 3265
    # bint predicate(Py_UCS4 uchar)
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])

    def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
        """Replace ``isXYZ()`` calls on a single Py_UCS4 character by a call
        to the matching C predicate.

        This only applies to bound calls whose receiver is a C
        unicode-character value coerced to a Python object.  The character
        is then passed directly to ``Py_UNICODE_<METHOD>()``, or to the
        ``__Pyx_Py_UNICODE_ISTITLE()`` helper for ``istitle()``.  The C
        result is coerced back to a Python object if the original call
        produced one.
        """
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not (isinstance(ustring, ExprNodes.CoerceToPyTypeNode)
                and ustring.arg.type.is_unicode_char):
            return node
        uchar = ustring.arg
        method_name = function.attribute
        if method_name == 'istitle':
            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
            utility_code = UtilityCode.load_cached(
                "py_unicode_istitle", "StringTools.c")
            function_name = '__Pyx_Py_UNICODE_ISTITLE'
        else:
            utility_code = None
            function_name = 'Py_UNICODE_%s' % method_name.upper()
        func_call = self._substitute_method_call(
            node, function,
            function_name, self.PyUnicode_uchar_predicate_func_type,
            method_name, is_unbound_method, [uchar],
            utility_code=utility_code)
        if node.type.is_pyobject:
            # current_env is a method: it must be *called*, as everywhere else
            # in this transform, rather than passing the bound method itself.
            func_call = func_call.coerce_to_pyobject(self.current_env())
        return func_call

    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate

    # Py_UCS4 conversion(Py_UCS4 uchar)
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])

    def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
        """Replace ``lower()``/``upper()``/``title()`` on a single Py_UCS4
        character by a call to ``Py_UNICODE_TO<METHOD>()``.

        This only applies to bound calls whose receiver is a C
        unicode-character value coerced to a Python object.  The C result is
        coerced back to a Python object if the original call produced one.
        """
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not (isinstance(ustring, ExprNodes.CoerceToPyTypeNode)
                and ustring.arg.type.is_unicode_char):
            return node
        uchar = ustring.arg
        method_name = function.attribute
        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
        func_call = self._substitute_method_call(
            node, function,
            function_name, self.PyUnicode_uchar_conversion_func_type,
            method_name, is_unbound_method, [uchar])
        if node.type.is_pyobject:
            # current_env is a method: it must be *called*, as everywhere else
            # in this transform, rather than passing the bound method itself.
            func_call = func_call.coerce_to_pyobject(self.current_env())
        return func_call

    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion

3332 3333 3334 3335 3336 3337
    # list PyUnicode_Splitlines(unicode str, bint keepends)
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
        ])

    def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
        """Map ``unicode.splitlines([keepends])`` onto ``PyUnicode_Splitlines()``.

        If ``keepends`` (argument index 1) was not passed, a False default is
        injected.  A wrong argument count is reported and the call is left
        untouched.
        """
        if len(args) in (1, 2):
            self._inject_bint_default_argument(node, args, 1, False)
            return self._substitute_method_call(
                node, function,
                "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
                'splitlines', is_unbound_method, args)
        self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
        return node

    # list PyUnicode_Split(unicode str, object sep, Py_ssize_t maxsplit)
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
        ])

    def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
        """Map ``unicode.split([sep[, maxsplit]])`` onto ``PyUnicode_Split()``.

        A missing separator is passed as NULL, and a missing ``maxsplit``
        defaults to -1.
        """
        arg_count = len(args)
        if not 1 <= arg_count <= 3:
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node
        if arg_count == 1:
            # no separator given: pass a C NULL in its place
            args.append(ExprNodes.NullNode(node.pos))
        self._inject_int_default_argument(
            node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
        return self._substitute_method_call(
            node, function, "PyUnicode_Split", self.PyUnicode_Split_func_type,
            'split', is_unbound_method, args)

3377 3378 3379 3380 3381 3382 3383 3384 3385 3386 3387 3388 3389 3390 3391 3392 3393
    # unicode PyUnicode_Join(unicode str, object seq)
    PyUnicode_Join_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
        ])

    def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
        """Map ``unicode.join(seq)`` onto ``PyUnicode_Join()``.

        unicode.join() builds a list first anyway.  If *seq* is a generator
        expression, it is therefore inlined as a list comprehension: each
        yield statement is replaced by an append to the list being built.
        """
        if len(args) != 2:
            self._error_wrong_arg_count('unicode.join', node, args, "2")
            return node

        seq_arg = args[1]
        if isinstance(seq_arg, ExprNodes.GeneratorExpressionNode):
            found_yields = _find_yield_statements(seq_arg.loop)
            if found_yields:
                list_comp = ExprNodes.InlinedGeneratorExpressionNode(
                    node.pos, seq_arg, orig_func='list',
                    comprehension_type=Builtin.list_type)
                for yield_expr, yield_stat in found_yields:
                    appender = ExprNodes.ComprehensionAppendNode(
                        yield_expr.pos,
                        expr=yield_expr,
                        target=list_comp.target)
                    Visitor.recursively_replace_node(seq_arg, yield_stat, appender)
                args[1] = list_comp

        return self._substitute_method_call(
            node, function, "PyUnicode_Join", self.PyUnicode_Join_func_type,
            'join', is_unbound_method, args)

3415
    # int __Pyx_Py<Type>_Tailmatch(str, substring, start, end, direction), -1 on error
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [
            PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
        exception_value='-1')

    def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
        """Optimise ``unicode.endswith()`` (tail match direction +1)."""
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            type_name='unicode', method_name='endswith',
            utility_code=unicode_tailmatch_utility_code, direction=+1)

    def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
        """Optimise ``unicode.startswith()`` (tail match direction -1)."""
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            type_name='unicode', method_name='startswith',
            utility_code=unicode_tailmatch_utility_code, direction=-1)

    def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
                          method_name, utility_code, direction):
        """Replace ``<type>.startswith(...)``/``<type>.endswith(...)`` by a call
        to the ``__Pyx_Py<Type>_Tailmatch()`` helper provided by *utility_code*.

        Missing start/end arguments default to 0 and PY_SSIZE_T_MAX.
        *direction* is appended as a C int; the handlers pass -1 for
        startswith and +1 for endswith.  The C result is coerced to a Python
        bool.
        """
        if len(args) not in (2, 3, 4):
            self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
            return node
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, PyrexTypes.c_py_ssize_t_type, default)
        args.append(ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type))

        helper_name = "__Pyx_Py%s_Tailmatch" % type_name.capitalize()
        tailmatch_call = self._substitute_method_call(
            node, function, helper_name, self.PyString_Tailmatch_func_type,
            method_name, is_unbound_method, args,
            utility_code=utility_code)
        return tailmatch_call.coerce_to(Builtin.bool_type, self.current_env())
3457

3458 3459 3460 3461 3462 3463 3464 3465 3466 3467
    # Py_ssize_t PyUnicode_Find(str, substring, start, end, direction), -2 on error
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
        ],
        exception_value='-2')

    def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
        """Optimise ``unicode.find()`` (forward search, direction +1)."""
        return self._inject_unicode_find(
            node, function, args, is_unbound_method,
            method_name='find', direction=+1)

    def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
        """Optimise ``unicode.rfind()`` (backward search, direction -1)."""
        return self._inject_unicode_find(
            node, function, args, is_unbound_method,
            method_name='rfind', direction=-1)

    def _inject_unicode_find(self, node, function, args, is_unbound_method,
                             method_name, direction):
        """Replace unicode.find(...)/unicode.rfind(...) by a call to
        ``PyUnicode_Find()``.

        Missing start/end arguments default to 0 and PY_SSIZE_T_MAX, and
        *direction* is appended as a C int.  The C index result is coerced to
        a Python object.
        """
        if len(args) not in (2, 3, 4):
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, PyrexTypes.c_py_ssize_t_type, default)
        args.append(ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type))

        find_call = self._substitute_method_call(
            node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
            method_name, is_unbound_method, args)
        return find_call.coerce_to_pyobject(self.current_env())
3495

Stefan Behnel's avatar
Stefan Behnel committed
3496 3497 3498 3499 3500 3501 3502 3503 3504
    # Py_ssize_t PyUnicode_Count(str, substring, start, end), -1 on error
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        ],
        exception_value='-1')

    def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
        """Replace unicode.count(...) by a call to ``PyUnicode_Count()``.

        Missing start/end arguments default to 0 and PY_SSIZE_T_MAX.  The C
        count is coerced to a Python object.
        """
        if len(args) not in (2, 3, 4):
            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
            return node
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, PyrexTypes.c_py_ssize_t_type, default)

        count_call = self._substitute_method_call(
            node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
            'count', is_unbound_method, args)
        return count_call.coerce_to_pyobject(self.current_env())
Stefan Behnel's avatar
Stefan Behnel committed
3521

Stefan Behnel's avatar
Stefan Behnel committed
3522 3523 3524 3525 3526 3527 3528 3529
    # unicode PyUnicode_Replace(str, substring, replstr, maxcount)
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
        ])

    def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
        """Replace unicode.replace(old, new[, count]) by a call to
        ``PyUnicode_Replace()``; a missing count defaults to -1.
        """
        if not 3 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
            return node
        self._inject_int_default_argument(
            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
        return self._substitute_method_call(
            node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
            'replace', is_unbound_method, args)

3544 3545
    # bytes PyUnicode_AsEncodedString(unicode obj, const char *encoding, const char *errors)
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ])

    # bytes PyUnicode_As<Codec>String(unicode obj)
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs that can be mapped to dedicated C-API functions.  The spelling is
    # chosen so that the name can become the C function name suffix (see
    # _find_special_codec_name(), which CamelCases names containing '_').
    _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]

    def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.

        - Without arguments, ``PyUnicode_AsEncodedString(s, NULL, NULL)`` is used.
        - A constant unicode string with a known encoding is encoded at
          compile time if that succeeds.
        - With 'strict' error handling, a codec that has a dedicated
          ``PyUnicode_As<Codec>String()`` function uses it.
        - Everything else goes through ``PyUnicode_AsEncodedString()``.
        """
        if not 1 <= len(args) <= 3:
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
            return node

        string_node = args[0]

        if len(args) == 1:
            null_node = ExprNodes.NullNode(node.pos)
            return self._substitute_method_call(
                node, function, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except Exception:
                # Unknown/non-text codec, unknown error handler or unencodable
                # characters: keep the runtime call.  (A bare 'except:' would
                # also swallow KeyboardInterrupt/SystemExit during compilation.)
                pass
            else:
                value = bytes_literal(value, encoding)
                return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)

        if encoding and error_handling == 'strict':
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            # names with a '-' (UTF-16LE/BE) have no PyUnicode_As...String()
            if codec_name is not None and '-' not in codec_name:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, function, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, function, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

3611
    # Pointer to a C-API decoder with the PyUnicode_Decode<Codec>() signature:
    # unicode (*)(const char *string, Py_ssize_t size, const char *errors)
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(
        PyrexTypes.CFuncType(
            Builtin.unicode_type,
            [
                PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
                PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
                PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ]))

    # __Pyx_decode_c_string(char*, start, stop, encoding, errors, decode_func)
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

    # __Pyx_decode_bytes/bytearray(object, start, stop, encoding, errors, decode_func)
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
        ])

    _decode_cpp_string_func_type = None  # lazy init

    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
        """Replace ``.decode()`` on bytes/bytearray objects, C ``char*`` and C++
        strings by a call to a ``__Pyx_decode_*`` helper from StringTools.c.

        A slice of the input (``s[a:b].decode()``) is folded into the helper's
        start/stop arguments.  If the encoding maps to a codec with a
        dedicated ``PyUnicode_Decode<Codec>()`` function, that function is
        passed as a pointer and the encoding argument becomes NULL.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        # Normalise the input: strip a slice and a Python coercion.
        source = args[0]
        start = stop = None
        if isinstance(source, ExprNodes.SliceIndexNode):
            start, stop = source.start, source.stop
            source = source.base
            if not start or start.constant_result == 0:
                start = None
        if isinstance(source, ExprNodes.CoerceToPyTypeNode):
            source = source.arg

        source_type = source.type
        is_py_bytes = source_type in (Builtin.bytes_type, Builtin.bytearray_type)
        if not is_py_bytes and not (source_type.is_string or source_type.is_cpp_string):
            # nothing to optimise here
            return node
        if is_py_bytes:
            if is_unbound_method:
                source = source.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=['decode', source_type.name])
            else:
                source = source.as_none_safe_node(
                    "'NoneType' object has no attribute '%.30s'",
                    error="PyExc_AttributeError",
                    format_args=['decode'])

        unpacked = self._unpack_encoding_and_error_mode(node.pos, args)
        if unpacked is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = unpacked

        if not start:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        elif not start.type.is_int:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop and not stop.type.is_int:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

        # Look for a codec that has a dedicated C-API decoder function.
        codec_name = self._find_special_codec_name(encoding) if encoding is not None else None
        if codec_name is None:
            decode_function = ExprNodes.NullNode(node.pos)
        else:
            if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
                codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
            else:
                codec_cname = "PyUnicode_Decode%s" % codec_name
            decode_function = ExprNodes.RawCNameExprNode(
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
            encoding_node = ExprNodes.NullNode(node.pos)

        # Pick the helper for the input kind and complete the stop index.
        let_ref = None
        if source_type.is_string:
            # C string: without an explicit end, use strlen() just as CPython would
            if not stop:
                if not source.is_name:
                    source = let_ref = UtilNodes.LetRefNode(source)  # used twice
                stop = ExprNodes.PythonCapiCallNode(
                    source.pos, "strlen", self.Pyx_strlen_func_type,
                    args=[source],
                    is_temp=False,
                    utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            helper_func_type = self._decode_c_string_func_type
            helper_name = 'decode_c_string'
        elif source_type.is_cpp_string:
            # C++ std::string
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            if self._decode_cpp_string_func_type is None:
                # built on first use from the C++ string type seen then, and
                # cached on the instance for later calls
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                    Builtin.unicode_type,
                    [
                        PyrexTypes.CFuncTypeArg("string", source_type, None),
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
            helper_func_type = self._decode_cpp_string_func_type
            helper_name = 'decode_cpp_string'
        else:
            # Python bytes/bytearray object
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            helper_func_type = self._decode_bytes_func_type
            helper_name = 'decode_bytes' if source_type is Builtin.bytes_type else 'decode_bytearray'

        decode_call = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_%s' % helper_name, helper_func_type,
            args=[source, start, stop, encoding_node, error_handling_node, decode_function],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(helper_name, 'StringTools.c'),
        )
        if let_ref is not None:
            decode_call = UtilNodes.EvalWithTempExprNode(let_ref, decode_call)
        return decode_call

    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode

3761 3762 3763
    def _find_special_codec_name(self, encoding):
        try:
            requested_codec = codecs.getencoder(encoding)
Stefan Behnel's avatar
Stefan Behnel committed
3764
        except LookupError:
3765 3766 3767 3768
            return None
        for name, codec in self._special_codecs:
            if codec == requested_codec:
                if '_' in name:
Stefan Behnel's avatar
Stefan Behnel committed
3769 3770
                    name = ''.join([s.capitalize()
                                    for s in name.split('_')])
3771 3772 3773 3774
                return name
        return None

    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the encoding (``args[1]``) and error mode (``args[2]``) of a
        decode/encode call.

        Returns ``(encoding, encoding_node, error_handling, error_handling_node)``.
        Missing arguments map to a NULL node, and the error mode defaults to
        'strict'; an explicit 'strict' is also passed as NULL.  Returns None
        if either argument cannot be turned into a C string node.
        """
        null_node = ExprNodes.NullNode(pos)

        encoding, encoding_node = None, null_node
        if len(args) >= 2:
            encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
            if encoding_node is None:
                return None

        error_handling, error_handling_node = 'strict', null_node
        if len(args) == 3:
            error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
            if error_handling_node is None:
                return None
            if error_handling == 'strict':
                error_handling_node = null_node

        return encoding, encoding_node, error_handling, error_handling_node
3796

3797 3798 3799 3800 3801 3802
    def _unpack_string_and_cstring_node(self, node):
        """Return ``(python_value, c_string_node)`` for an encoding/errors argument.

        - String literals give their text value and a ``const char*`` BytesNode.
          Unicode literals are stored as UTF-8; byte/str literals are decoded
          as ISO-8859-1 for the Python value.
        - Non-literal bytes objects are coerced to ``const char*``.
        - C strings are passed through unchanged.
        In the last two cases the Python value is None.  Any other node
        gives ``(None, None)``.
        """
        if isinstance(node, ExprNodes.CoerceToPyTypeNode):
            node = node.arg

        if isinstance(node, ExprNodes.UnicodeNode):
            text = node.value
            return text, ExprNodes.BytesNode(
                node.pos, value=text.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type)
        if isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
            return node.value.decode('ISO-8859-1'), ExprNodes.BytesNode(
                node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type)
        if node.type is Builtin.bytes_type:
            return None, node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env())
        if node.type.is_string:
            return None, node
        return None, None

3817
    def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
        """Optimise ``str.endswith()`` through the shared tail-match helper."""
        utility_code = str_tailmatch_utility_code
        direction = +1  # the endswith() handlers pass +1, startswith() ones pass -1
        return self._inject_tailmatch(node, function, args, is_unbound_method,
                                      'str', 'endswith', utility_code, direction)
3821

3822
    def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
        """Optimise ``str.startswith()`` through the shared tail-match helper."""
        utility_code = str_tailmatch_utility_code
        direction = -1  # the startswith() handlers pass -1, endswith() ones pass +1
        return self._inject_tailmatch(node, function, args, is_unbound_method,
                                      'str', 'startswith', utility_code, direction)
3826

3827
    def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
        """Optimise ``bytes.endswith()`` through the shared tail-match helper."""
        utility_code = bytes_tailmatch_utility_code
        direction = +1  # the endswith() handlers pass +1, startswith() ones pass -1
        return self._inject_tailmatch(node, function, args, is_unbound_method,
                                      'bytes', 'endswith', utility_code, direction)
3831

3832
    def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
        """Optimise ``bytes.startswith()`` through the shared tail-match helper."""
        utility_code = bytes_tailmatch_utility_code
        direction = -1  # the startswith() handlers pass -1, endswith() ones pass +1
        return self._inject_tailmatch(node, function, args, is_unbound_method,
                                      'bytes', 'startswith', utility_code, direction)
3836

Stefan Behnel's avatar
Stefan Behnel committed
3837 3838 3839 3840 3841 3842 3843 3844 3845 3846 3847 3848
    # NOTE: the two bytearray endswith()/startswith() handlers below sit inside
    # a string literal, so they are never defined as methods of this class.
    # They are kept for reference until enabling them is considered worthwhile.
    '''   # disabled for now, enable when we consider it worth it (see StringTools.c)
    def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'endswith',
            bytes_tailmatch_utility_code, +1)

    def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'startswith',
            bytes_tailmatch_utility_code, -1)
    '''

3849 3850
    ### helpers

3851
    def _substitute_method_call(self, node, function, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None, is_temp=None,
                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
                                with_none_check=True):
        """
        Build a PythonCapiCallNode that calls the C function ``name`` (of type
        ``func_type``) with ``args`` as a replacement for ``node``.

        When ``with_none_check`` is true and there is at least one argument,
        the first argument is passed through ``_wrap_self_arg`` so that a
        None value raises the appropriate error.  ``is_temp`` falls back to
        ``node.is_temp`` when not given; ``result_is_used`` is copied from
        ``node``.
        """
        call_args = list(args)
        if with_none_check and call_args:
            call_args[0] = self._wrap_self_arg(
                call_args[0], function, is_unbound_method, attr_name)
        temp_flag = node.is_temp if is_temp is None else is_temp
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args=call_args,
            is_temp=temp_flag,
            utility_code=utility_code,
            may_return_none=may_return_none,
            result_is_used=node.result_is_used,
        )

3870 3871 3872 3873 3874 3875 3876 3877 3878 3879 3880 3881 3882 3883
    def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name):
        if self_arg.is_literal:
            return self_arg
        if is_unbound_method:
            self_arg = self_arg.as_none_safe_node(
                "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                format_args=[attr_name, self_arg.type.name])
        else:
            self_arg = self_arg.as_none_safe_node(
                "'NoneType' object has no attribute '%{0}s'".format('.30' if len(attr_name) <= 30 else ''),
                error="PyExc_AttributeError",
                format_args=[attr_name])
        return self_arg

3884 3885 3886
    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
        assert len(args) >= arg_index
        if len(args) == arg_index:
3887 3888
            args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
                                          type=type, constant_result=default_value))
3889
        else:
3890
            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
3891 3892 3893 3894

    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
        assert len(args) >= arg_index
        if len(args) == arg_index:
3895 3896 3897
            default_value = bool(default_value)
            args.append(ExprNodes.BoolNode(node.pos, value=default_value,
                                           constant_result=default_value))
3898
        else:
3899
            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
3900

3901

3902 3903 3904
# Tail-match utility code from StringTools.c, used by the startswith()/
# endswith() method optimisations.  Loaded once at import time, in the
# order unicode, bytes, str.
(unicode_tailmatch_utility_code,
 bytes_tailmatch_utility_code,
 str_tailmatch_utility_code) = tuple(
    UtilityCode.load_cached(kind + '_tailmatch', 'StringTools.c')
    for kind in ('unicode', 'bytes', 'str'))
3905

3906

3907 3908 3909 3910
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """

    # Binary operators for which C arithmetic turns a folded 'bool' / 'char'
    # literal result into (at least) an int.  Explicit sets replace the old
    # substring test on an operator string, which also matched non-operators.
    _BOOL_TO_INT_OPERATORS = frozenset([
        '+', '-', '*', '/', '//', '%', '**', '<<', '>>'])
    _CHAR_TO_INT_OPERATORS = _BOOL_TO_INT_OPERATORS | frozenset(['&', '|', '^'])

    def __init__(self, reevaluate=False):
        """
        The reevaluate argument specifies whether constant values that were
        previously computed should be recomputed.
        """
        super(ConstantFolding, self).__init__()
        self.reevaluate = reevaluate

    def _calculate_const(self, node):
        """Visit the children of ``node`` and, if all of them are constant,
        store the node's own constant value in ``node.constant_result``.

        Unless ``reevaluate`` is set, a node whose constant result was already
        set is left alone (its children are not visited either).  Otherwise
        ``constant_result`` is always (re)set, to ``not_a_constant`` if no
        value can be computed.
        """
        if (not self.reevaluate and
                node.constant_result is not ExprNodes.constant_value_not_set):
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback
            traceback.print_exc(file=sys.stdout)

    NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                       ExprNodes.IntNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest literal node class of ``nodes`` according to
        NODE_TYPE_ORDER, or None if any node's class is not listed there."""
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def _bool_node(self, node, value):
        """Return a constant BoolNode for ``bool(value)`` at ``node``'s position."""
        value = bool(value)
        return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold unary operators applied to literal operands."""
        self._calculate_const(node)
        if not node.has_constant_result():
            if node.operator == '!':
                return self._handle_NotNode(node)
            return node
        if not node.operand.is_literal:
            return node
        if node.operator == '!':
            return self._bool_node(node, node.constant_result)
        elif isinstance(node.operand, ExprNodes.BoolNode):
            return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)),
                                     type=PyrexTypes.c_int_type,
                                     constant_result=int(node.constant_result))
        elif node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    # maps a membership/identity operator to its negation, or None
    _negate_operator = {
        'in': 'not_in',
        'not_in': 'in',
        'is': 'is_not',
        'is_not': 'is'
    }.get

    def _handle_NotNode(self, node):
        """Turn ``not (a in b)`` / ``not (a is b)`` into the negated comparison."""
        operand = node.operand
        if isinstance(operand, ExprNodes.PrimaryCmpNode):
            operator = self._negate_operator(operand.operator)
            if operator:
                node = copy.copy(operand)
                node.operator = operator
                node = self.visit_PrimaryCmpNode(node)
        return node

    def _handle_UnaryMinusNode(self, node):
        """Fold ``-literal`` into a new FloatNode/IntNode by negating the
        literal's source text."""
        def _negate(value):
            if value.startswith('-'):
                value = value[1:]
            else:
                value = '-' + value
            return value

        operand = node.operand
        node_type = operand.type
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value=_negate(operand.value),
                                       type=node_type,
                                       constant_result=node.constant_result)
        # Only IntNode literals can be negated textually.  Other literals with
        # a signed C integer type (presumably e.g. CharNode) need not have an
        # IntNode-style numeric 'value' or a 'longness' attribute (cf. the
        # getattr() fallbacks in visit_BinopNode), so leave those to runtime.
        if isinstance(operand, ExprNodes.IntNode) and (
                (node_type.is_int and node_type.signed) or node_type.is_pyobject):
            return ExprNodes.IntNode(node.pos, value=_negate(operand.value),
                                     type=node_type,
                                     longness=operand.longness,
                                     constant_result=node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary '+' whose result equals the operand's constant value."""
        if (node.operand.has_constant_result() and
                    node.constant_result == node.operand.constant_result):
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Short-circuit 'and'/'or' when the left operand is constant."""
        self._calculate_const(node)
        if not node.operand1.has_constant_result():
            return node
        if node.operand1.constant_result:
            if node.operator == 'and':
                return node.operand2
            else:
                return node.operand1
        else:
            if node.operator == 'and':
                return node.operand1
            else:
                return node.operand2

    def visit_BinopNode(self, node):
        """Replace a binary operation on two literals by a literal node
        holding the (non-float) constant result."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type

        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.BoolNode and node.operator in self._BOOL_TO_INT_OPERATORS:
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode
        elif target_class is ExprNodes.CharNode and node.operator in self._CHAR_TO_INT_OPERATORS:
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode

        if target_class is ExprNodes.IntNode:
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned=unsigned, longness=longness,
                                         value=str(int(node.constant_result)),
                                         constant_result=int(node.constant_result))
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_AddNode(self, node):
        """Merge concatenated unicode or bytes literals of matching encoding."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if node.operand1.is_string_literal and node.operand2.is_string_literal:
            # some people combine string literals with a '+'
            str1, str2 = node.operand1, node.operand2
            if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode):
                bytes_value = None
                if str1.bytes_value is not None and str2.bytes_value is not None:
                    if str1.bytes_value.encoding == str2.bytes_value.encoding:
                        bytes_value = bytes_literal(
                            str1.bytes_value + str2.bytes_value,
                            str1.bytes_value.encoding)
                string_value = EncodedString(node.constant_result)
                return ExprNodes.UnicodeNode(
                    str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
            elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
                if str1.value.encoding == str2.value.encoding:
                    bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
                    return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
            # all other combinations are rather complicated
            # to get right in Py2/3: encodings, unicode escapes, ...
        return self.visit_BinopNode(node)

    def visit_MulNode(self, node):
        """Fold ``[seq] * n`` / ``n * [seq]`` into the sequence's mult_factor."""
        self._calculate_const(node)
        if node.operand1.is_sequence_constructor:
            return self._calculate_constant_seq(node, node.operand1, node.operand2)
        if isinstance(node.operand1, ExprNodes.IntNode) and \
                node.operand2.is_sequence_constructor:
            return self._calculate_constant_seq(node, node.operand2, node.operand1)
        return self.visit_BinopNode(node)

    def _calculate_constant_seq(self, node, sequence_node, factor):
        """Apply multiplication ``factor`` to ``sequence_node`` in place
        (emptying it for constant int factors <= 0, combining constant int
        factors) and return the sequence node, or fall back to
        visit_BinopNode when the factors cannot be combined."""
        if factor.constant_result != 1 and sequence_node.args:
            if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0:
                del sequence_node.args[:]
                sequence_node.mult_factor = None
            elif sequence_node.mult_factor is not None:
                if (isinstance(factor.constant_result, _py_int_types) and
                        isinstance(sequence_node.mult_factor.constant_result, _py_int_types)):
                    value = sequence_node.mult_factor.constant_result * factor.constant_result
                    sequence_node.mult_factor = ExprNodes.IntNode(
                        sequence_node.mult_factor.pos,
                        value=str(value), constant_result=value)
                else:
                    # don't know if we can combine the factors, so don't
                    return self.visit_BinopNode(node)
            else:
                sequence_node.mult_factor = factor
        return sequence_node

    def visit_FormattedValueNode(self, node):
        """Replace f-string parts that format digit-only int literals or
        string literals (without format spec) by unicode literals."""
        self.visitchildren(node)
        conversion_char = node.conversion_char or 's'
        if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value:
            node.format_spec = None
        if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode):
            value = EncodedString(node.value.value)
            if value.isdigit():
                return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
        if node.format_spec is None and conversion_char == 's':
            value = None
            if isinstance(node.value, ExprNodes.UnicodeNode):
                value = node.value.value
            elif isinstance(node.value, ExprNodes.StringNode):
                value = node.value.unicode_value
            if value is not None:
                return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
        return node

    def visit_JoinedStrNode(self, node):
        """
        Clean up after the parser by discarding empty Unicode strings and merging
        substring sequences.  Empty or single-value join lists are not uncommon
        because f-string format specs are always parsed into JoinedStrNodes.
        """
        self.visitchildren(node)
        unicode_node = ExprNodes.UnicodeNode

        values = []
        for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)):
            if is_unode_group:
                substrings = list(substrings)
                unode = substrings[0]
                if len(substrings) > 1:
                    value = EncodedString(u''.join(substring.value for substring in substrings))
                    unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value)
                # ignore empty Unicode strings
                if unode.value:
                    values.append(unode)
            else:
                values.extend(substrings)

        if not values:
            value = EncodedString('')
            node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value)
        elif len(values) == 1:
            node = values[0]
        elif len(values) == 2:
            # reduce to string concatenation
            node = ExprNodes.binop_node(node.pos, '+', *values)
        else:
            node.values = values
        return node

    def visit_MergedDictNode(self, node):
        """Unpack **args in place if we can."""
        self.visitchildren(node)
        args = []
        items = []

        def add(arg):
            if arg.is_dict_literal:
                if items:
                    items[0].key_value_pairs.extend(arg.key_value_pairs)
                else:
                    items.append(arg)
            elif isinstance(arg, ExprNodes.MergedDictNode):
                for child_arg in arg.keyword_args:
                    add(child_arg)
            else:
                if items:
                    args.append(items[0])
                    del items[:]
                args.append(arg)

        for arg in node.keyword_args:
            add(arg)
        if items:
            args.append(items[0])

        if len(args) == 1:
            arg = args[0]
            if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode):
                return arg
        node.keyword_args[:] = args
        self._calculate_const(node)
        return node

    def visit_MergedSequenceNode(self, node):
        """Unpack *args in place if we can."""
        self.visitchildren(node)

        is_set = node.type is Builtin.set_type
        args = []
        values = []

        def add(arg):
            if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor):
                if values:
                    values[0].args.extend(arg.args)
                else:
                    values.append(arg)
            elif isinstance(arg, ExprNodes.MergedSequenceNode):
                for child_arg in arg.args:
                    add(child_arg)
            else:
                if values:
                    args.append(values[0])
                    del values[:]
                args.append(arg)

        for arg in node.args:
            add(arg)
        if values:
            args.append(values[0])

        if len(args) == 1:
            arg = args[0]
            if ((is_set and arg.is_set_literal) or
                    (arg.is_sequence_constructor and arg.type is node.type) or
                    isinstance(arg, ExprNodes.MergedSequenceNode)):
                return arg
        node.args[:] = args
        self._calculate_const(node)
        return node

    def visit_SequenceNode(self, node):
        """Unpack *args in place if we can."""
        self.visitchildren(node)
        args = []
        for arg in node.args:
            if not arg.is_starred:
                args.append(arg)
            elif arg.target.is_sequence_constructor and not arg.target.mult_factor:
                args.extend(arg.target.args)
            else:
                args.append(arg)
        node.args[:] = args
        self._calculate_const(node)
        return node

    def visit_PrimaryCmpNode(self, node):
        """Fold (cascaded) comparisons: constant-True links are dropped, a
        constant-False link short-circuits the rest, and the remaining
        partial cascades are joined with 'and'."""
        # calculate constant partial results in the comparison cascade
        self.visitchildren(node, ['operand1'])
        left_node = node.operand1
        cmp_node = node
        while cmp_node is not None:
            self.visitchildren(cmp_node, ['operand2'])
            right_node = cmp_node.operand2
            cmp_node.constant_result = not_a_constant
            if left_node.has_constant_result() and right_node.has_constant_result():
                try:
                    cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
                except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                    pass  # ignore all 'normal' errors here => no constant result
            left_node = right_node
            cmp_node = cmp_node.cascade

        if not node.cascade:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
            return node

        # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
        cascades = [[node.operand1]]
        final_false_result = []

        def split_cascades(cmp_node):
            if cmp_node.has_constant_result():
                if not cmp_node.constant_result:
                    # False => short-circuit
                    final_false_result.append(self._bool_node(cmp_node, False))
                    return
                else:
                    # True => discard and start new cascade
                    cascades.append([cmp_node.operand2])
            else:
                # not constant => append to current cascade
                cascades[-1].append(cmp_node)
            if cmp_node.cascade:
                split_cascades(cmp_node.cascade)

        split_cascades(node)

        cmp_nodes = []
        for cascade in cascades:
            if len(cascade) < 2:
                continue
            cmp_node = cascade[1]
            pcmp_node = ExprNodes.PrimaryCmpNode(
                cmp_node.pos,
                operand1=cascade[0],
                operator=cmp_node.operator,
                operand2=cmp_node.operand2,
                constant_result=not_a_constant)
            cmp_nodes.append(pcmp_node)

            last_cmp_node = pcmp_node
            for cmp_node in cascade[2:]:
                last_cmp_node.cascade = cmp_node
                last_cmp_node = cmp_node
            last_cmp_node.cascade = None

        if final_false_result:
            # last cascade was constant False
            cmp_nodes.append(final_false_result[0])
        elif not cmp_nodes:
            # only constants, but no False result
            return self._bool_node(node, True)
        node = cmp_nodes[0]
        if len(cmp_nodes) == 1:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
        else:
            for cmp_node in cmp_nodes[1:]:
                node = ExprNodes.BoolBinopNode(
                    node.pos,
                    operand1=node,
                    operator='and',
                    operand2=cmp_node,
                    constant_result=not_a_constant)
        return node

    def visit_CondExprNode(self, node):
        """Select the branch of ``x if test else y`` when ``test`` is constant."""
        self._calculate_const(node)
        if not node.test.has_constant_result():
            return node
        if node.test.constant_result:
            return node.true_val
        else:
            return node.false_val

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition = if_clause.condition
            if condition.has_constant_result():
                if condition.constant_result:
                    # always true => subsequent clauses can safely be dropped
                    node.else_clause = if_clause.body
                    break
                # else: false => drop clause
            else:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        elif node.else_clause:
            return node.else_clause
        else:
            return Nodes.StatListNode(node.pos, stats=[])

    def visit_SliceIndexNode(self, node):
        """Normalise constant None slice bounds and cut down constant
        slices of sequence constructors and string literals."""
        self._calculate_const(node)
        # normalise start/stop values
        if node.start is None or node.start.constant_result is None:
            start = node.start = None
        else:
            start = node.start.constant_result
        if node.stop is None or node.stop.constant_result is None:
            stop = node.stop = None
        else:
            stop = node.stop.constant_result
        # cut down sliced constant sequences
        if node.constant_result is not not_a_constant:
            base = node.base
            if base.is_sequence_constructor and base.mult_factor is None:
                base.args = base.args[start:stop]
                # The reused base node now represents the slice, so its cached
                # constant must match, not keep the unsliced value.
                base.constant_result = node.constant_result
                return base
            elif base.is_string_literal:
                base = base.as_sliced_node(start, stop)
                if base is not None:
                    return base
        return node

    def visit_ComprehensionNode(self, node):
        self.visitchildren(node)
        if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
            # loop was pruned already => transform into literal
            if node.type is Builtin.list_type:
                return ExprNodes.ListNode(
                    node.pos, args=[], constant_result=[])
            elif node.type is Builtin.set_type:
                return ExprNodes.SetNode(
                    node.pos, args=[], constant_result=set())
            elif node.type is Builtin.dict_type:
                return ExprNodes.DictNode(
                    node.pos, key_value_pairs=[], constant_result={})
        return node

    def visit_ForInStatNode(self, node):
        self.visitchildren(node)
        sequence = node.iterator.sequence
        if isinstance(sequence, ExprNodes.SequenceNode):
            if not sequence.args:
                if node.else_clause:
                    return node.else_clause
                else:
                    # don't break list comprehensions
                    return Nodes.StatListNode(node.pos, stats=[])
            # iterating over a list literal? => tuples are more efficient
            if isinstance(sequence, ExprNodes.ListNode):
                node.iterator.sequence = sequence.as_tuple()
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        if node.condition and node.condition.has_constant_result():
            if node.condition.constant_result:
                node.condition = None
                node.else_clause = None
            else:
                return node.else_clause
        return node

    def visit_ExprStatNode(self, node):
        self.visitchildren(node)
        if not isinstance(node.expr, ExprNodes.ExprNode):
            # ParallelRangeTransform does this ...
            return node
        # drop unused constant expressions
        if node.expr.has_constant_result():
            return None
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
4505 4506


4507
class FinalOptimizePhase(Visitor.CythonTransform, Visitor.NodeRefCleanupMixin):
    """
    Late optimisations that commute with each other and run just before
    the C code generation phase.

    Currently implemented:
        - skip the None assignment and refcounting before a variable's
          first assignment,
        - isinstance(x, type) -> direct type check for cdef types,
        - remove None checks and None-allowing type tests that became
          redundant after tree changes,
        - replace Python function calls that look like method calls by a
          faster PyMethodCallNode.
    """

    def visit_SingleAssignmentNode(self, node):
        """Flag the target of a first assignment so that no redundant
        initialisation of the local variable is generated before it.
        """
        self.visitchildren(node)
        if node.first:
            node.lhs.lhs_of_first_assignment = True
        return node

    def visit_SimpleCallNode(self, node):
        """
        Rewrite two kinds of calls into faster forms:
          * isinstance(x, t), with t of the builtin type 'type', becomes a
            PyObject_TypeCheck() call with t cast to a PyTypeObject pointer;
          * a Python call that may be a bound method call becomes a
            specialised PyMethodCallNode.
        """
        self.visitchildren(node)
        func = node.function

        if func.type.is_cfunction and func.is_name:
            if func.name == 'isinstance' and len(node.args) == 2:
                second_arg = node.args[1]
                arg_type = second_arg.type
                if arg_type.is_builtin_type and arg_type.name == 'type':
                    scope = self.context.cython_scope
                    func.entry = scope.lookup('PyObject_TypeCheck')
                    func.type = func.entry.type
                    type_ptr = PyrexTypes.CPtrType(scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(second_arg, type_ptr)
            return node

        if not (self.current_directives.get("optimize.unpack_method_calls")
                and node.is_temp and func.type.is_pyobject):
            return node

        # Only simple calls qualify: the arguments must form a TupleNode that
        # has no repetition factor and is not a non-empty literal tuple.
        arg_tuple = node.arg_tuple
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        if arg_tuple.mult_factor or (arg_tuple.is_literal and arg_tuple.args):
            return node

        # Rule out callables that are definitely not methods.
        if func.type is Builtin.type_type:
            may_be_a_method = False
        elif func.is_attribute:
            # an optimised builtin method is a C function
            may_be_a_method = not (func.entry and func.entry.type.is_cfunction)
        elif func.is_name:
            entry = func.entry
            if entry.is_builtin or entry.type.is_cfunction:
                may_be_a_method = False
            elif entry.cf_assignments:
                # locally defined functions/classes are never methods, so
                # the name only qualifies if some assignment binds another value
                non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode)
                may_be_a_method = any(
                    assignment.rhs and not isinstance(assignment.rhs, non_method_nodes)
                    for assignment in entry.cf_assignments)
            else:
                may_be_a_method = True
        else:
            may_be_a_method = True

        if not may_be_a_method:
            return node

        if (node.self and func.is_attribute and
                isinstance(func.obj, ExprNodes.CloneNode) and func.obj.arg is node.self):
            # the self object was moved into a CloneNode => undo that
            func.obj = func.obj.arg
        return self.replace(node, ExprNodes.PyMethodCallNode.from_node(
            node, function=func, arg_tuple=arg_tuple, type=node.type))

    def visit_PyTypeTestNode(self, node):
        """Stop a type test from also accepting None when its argument is
        known never to be None.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Replace a None check by its argument when that argument can
        definitely never be None.
        """
        self.visitchildren(node)
        return node if node.arg.may_be_none() else node.arg
4595 4596 4597 4598 4599 4600 4601 4602 4603

class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    This class facilitates the sharing of overflow checking among all nodes
    of a nested arithmetic expression.  For example, given the expression
    a*b + c, where a, b, and c are all possibly overflowing ints, the entire
    sequence will be evaluated and the overflow bit checked only at the end.
    """
    # The outermost foldable NumBinopNode of the arithmetic expression that is
    # currently being traversed (its overflow bit is shared by the nested
    # foldable binops), or None when not inside such an expression.
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit a node that does not take part in overflow-bit sharing.

        While its children are visited, any shared overflow bit is hidden, so
        arithmetic nested below this node starts its own chain.  The previous
        state is restored afterwards, even if traversal raises.
        """
        saved = self.overflow_bit_node
        if saved is None:
            self.visitchildren(node)
            return node
        self.overflow_bit_node = None
        try:
            self.visitchildren(node)
        finally:
            self.overflow_bit_node = saved
        return node

    def visit_NumBinopNode(self, node):
        """Let nested foldable binops share the outermost one's overflow bit.

        Only nodes with both ``overflow_check`` and ``overflow_fold`` set take
        part.  The outermost such node keeps its own check and becomes the
        shared ``overflow_bit_node``; nested ones point at it and disable
        their own check.
        """
        if not (node.overflow_check and node.overflow_fold):
            self.visitchildren(node)
            return node

        if self.overflow_bit_node is not None:
            # inside a foldable chain: reuse the outer node's overflow bit
            node.overflow_bit_node = self.overflow_bit_node
            node.overflow_check = False
            self.visitchildren(node)
            return node

        # top of a new chain: this node owns the shared overflow bit
        self.overflow_bit_node = node
        try:
            self.visitchildren(node)
        finally:
            # reset even on error so no stale chain leaks into later visits
            self.overflow_bit_node = None
        return node