Optimize.py 77.9 KB
Newer Older
1 2
import Nodes
import ExprNodes
3
import PyrexTypes
4
import Visitor
5 6 7 8
import Builtin
import UtilNodes
import TypeSlots
import Symtab
9
import Options
10
import Naming
11

12
from Code import UtilityCode
13
from StringEncoding import EncodedString, BytesLiteral
14
from Errors import error
15 16
from ParseTreeTransforms import SkipDeclarations

17 18
import codecs

19 20 21 22 23
try:
    reduce
except NameError:
    from functools import reduce

24 25 26 27 28
try:
    set
except NameError:
    from sets import Set as set

29 30 31 32
class FakePythonEnv(object):
    """A fake environment for creating type test nodes etc.

    It only provides the ``nogil`` flag, which is always False.
    """
    nogil = False

33
def unwrap_node(node):
    """Return *node* with all ResultRefNode wrappers stripped off.

    Each ResultRefNode is followed through its ``expression`` attribute
    until something that is not a ResultRefNode is reached.
    """
    current = node
    while True:
        if not isinstance(current, UtilNodes.ResultRefNode):
            return current
        current = current.expression
37 38

def is_common_value(a, b):
    """Return True if *a* and *b* refer to the same simple value.

    After stripping ResultRefNode wrappers:
      - two NameNodes match if their names are equal;
      - two AttributeNodes match if *a* is not a Python attribute lookup,
        their base objects match (checked recursively) and the attribute
        names are equal.
    Anything else never matches.  Calling this with the same node twice
    therefore tests whether it is a plain name or a chain of non-Python
    attributes on a name.
    """
    first = unwrap_node(a)
    second = unwrap_node(b)

    if isinstance(first, ExprNodes.NameNode) and \
           isinstance(second, ExprNodes.NameNode):
        return first.name == second.name

    if isinstance(first, ExprNodes.AttributeNode) and \
           isinstance(second, ExprNodes.AttributeNode):
        if first.is_py_attr:
            return False
        if not is_common_value(first.obj, second.obj):
            return False
        return first.attribute == second.attribute

    return False

47 48 49 50
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a C array, or over a slice of a C array or pointer,
      becomes a C loop that advances a pointer
    """
    # C signature of PyDict_Next(), used to build the loop condition
    # in _transform_dict_iteration()
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # the current scope is needed to create coercion nodes later on
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # visit the function body within the function's own scope
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        # transform nested loops first, then this one
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Return an optimised replacement for the ForInStatNode *node*,
        or *node* itself if none of the supported patterns applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array, or over a slice of a C array or
        pointer, into a C for-loop that steps a pointer temp from the
        start address up to (excluding) the stop address.

        Returns *node* unchanged when the end of the iteration is unknown,
        i.e. for a slice without a stop index or an array of unknown size.
        """
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                return node
        elif slice_node.type.is_array and slice_node.type.size is not None:
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size))
            step = None
        else:
            return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes, each item is counter_temp[0:1]
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            # the item is the element the pointer currently points to
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        # assign the loop target at the top of every iteration
        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Rewrite ``for i, x in enumerate(seq)`` into a loop over *seq*
        that maintains the counter ``i`` in a separate temp.

        Reports an error if enumerate() is not called with exactly one
        argument.  Leaves *node* untouched unless the loop target is a
        two-item sequence whose first item is a name of C integer or
        Python object type.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        # counter temp, starting at 0
        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # at the top of each iteration: i = counter; counter = counter + 1
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        # iterate directly over the enumerate() argument
        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Rewrite ``for i in range(...)`` / ``xrange(...)`` into a C
        ForFromStatNode.

        Only applies when the call has one to three arguments and the
        step (if given) is a non-zero compile-time integer constant, since
        the loop direction must be known.  A non-literal stop bound is
        evaluated once and kept in a temp.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # range() takes one to three arguments.  Indexing args below
            # would raise an IndexError for range(), and extra arguments
            # would be silently dropped => leave the call untouched.
            return node

        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the direction is encoded in the relations below;
            # the step node carries the magnitude
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Rewrite iteration over a dict (its keys, values or items) into
        a while loop that calls PyDict_Next() on temps holding the dict,
        the iteration position and the key/value pointers.

        Returns *node* unchanged for item iteration into a target sequence
        that does not have exactly two entries.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # PyDict_Next() gets NULL for the unused output argument
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (node to assign from, coercion statement or None).
            # Object targets get a cast (plus a type test for extension
            # and builtin types); C targets get a temp that is filled by a
            # CoerceFromPyTypeNode executed before the assignments.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            # 'for t in d.iteritems()' => build a (key, value) tuple
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))


504 505 506 507
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch
    statements.  Every clause must be a comparison 'var == value' (or an
    'or' of several such comparisons), where 'var' is common to all
    clauses and both 'var' and every 'value' are C integers.
    """
    def _strip_condition_wrappers(self, cond):
        # Look through coercions, casts and temp-evaluation wrappers.
        while 1:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                return cond

    def _is_case_value(self, value_node):
        # A usable case value is a literal or a reference to a constant.
        if value_node.is_literal:
            return True
        return hasattr(value_node, 'entry') and value_node.entry \
               and value_node.entry.is_const

    def extract_conditions(self, cond):
        """Return (var, [value, ...]) if *cond* has the switchable form,
        otherwise (None, None).
        """
        cond = self._strip_condition_wrappers(cond)

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            is_candidate = (cond.cascade is None
                            and cond.operator == '=='
                            and not cond.is_python_comparison())
            if is_candidate:
                # try 'var == value' first, then 'value == var'
                pairs = ((cond.operand1, cond.operand2),
                         (cond.operand2, cond.operand1))
                for var_node, value_node in pairs:
                    if is_common_value(var_node, var_node) and \
                           self._is_case_value(value_node):
                        return var_node, [value_node]
        elif isinstance(cond, ExprNodes.BoolBinopNode) \
                 and cond.operator == 'or':
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
        return None, None

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        switch_var = None
        total_values = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, values = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            non_int_values = [value for value in values
                              if not value.type.is_int]
            if not var.type.is_int or non_int_values:
                return node
            switch_var = var
            total_values = total_values + len(values)
            switch_cases.append(Nodes.SwitchCaseNode(pos = clause.pos,
                                                     conditions = values,
                                                     body = clause.body))
        # a switch with fewer than two case values is not worth it
        if total_values < 2:
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
573
                              
574

575
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential
    list of comparisons: "x == val1 or ... or x == valn" (and the negated
    form "x != val1 and ... and x != valn" for "not in").  The left operand
    is evaluated only once through a ResultRefNode.

    NOTE(review): the items are combined with and/or BoolBinopNodes, so
    items after a decisive comparison are presumably not evaluated, unlike
    plain Python, which builds the whole sequence first.  Confirm this is
    acceptable for items with side effects.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The result would be a constant, but the left operand may have
            # side effects (e.g. "f() in ()") that must still be evaluated,
            # so do not replace it by a plain constant.
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
626 627


628 629 630 631 632 633
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Only parallel assignments that permute plain names (or chains of
    non-Python attributes on a name) are currently handled.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        If all statements are simple assignments between supported
        operands, and the left and right hand sides form a permutation
        without duplicates, the involved nodes get use_managed_ref = False.
        In every other case *node* is left untouched.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        coerced_temps = []

        for stat in node.stats:
            # FIXME: CascadedAssignmentNode (like anything else that is
            # not a SingleAssignmentNode) is not supported
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, coerced_temps):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, coerced_temps):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = [entry[0] for entry in lhs_names]
            rhs_paths = [entry[0] for entry in rhs_names]
            if not self._is_unique_permutation(lhs_paths, rhs_paths):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = self._collect_index_ids(lhs_indices)
            if lhs_ids is None:
                return node
            rhs_ids = self._collect_index_ids(rhs_indices)
            if rhs_ids is None:
                return node
            if not self._is_unique_permutation(lhs_ids, rhs_ids):
                return node
            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [coerced.arg for coerced in coerced_temps]
        for coerced in coerced_temps:
            coerced.use_managed_ref = False

        for path, name_node in lhs_names + rhs_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False

        return node

    def _is_unique_permutation(self, left_keys, right_keys):
        # Both sides must use the same set of keys, and there must be as
        # many distinct left keys as there are right hand operands.
        if set(left_keys) != set(right_keys):
            return False
        return len(set(left_keys)) == len(right_keys)

    def _collect_index_ids(self, index_nodes):
        # Map each IndexNode to its id; None if any id cannot be determined.
        ids = []
        for index_node in index_nodes:
            index_id = self._extract_index_id(index_node)
            if not index_id:
                return None
            ids.append(index_id)
        return ids

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment.

        Appends (dotted path, node) to *names* for a name or a chain of
        non-Python attributes on a name, or appends the node to *indices*
        for an int-indexed access to a list-typed name.  A CoerceToTempNode
        wrapper is recorded in *temps* and looked through.  Returns False
        for anything else, including operands that are not Python objects.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path = []
        current = node
        while isinstance(current, ExprNodes.AttributeNode):
            if current.is_py_attr:
                return False
            path.insert(0, current.member)
            current = current.obj
        if isinstance(current, ExprNodes.NameNode):
            path.insert(0, current.name)
            names.append(('.'.join(path), node))
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type:
            return False
        if not node.index.type.is_int:
            return False
        if not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        # Return (base name, index name), or None if the index is not a
        # plain name.
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ConstNode) are not supported yet
            return None
        return (index_node.base.name, index.name)

class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Pass a call of a builtin name to its
        ``_handle_simple_function_<name>`` handler, if one exists.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        return self._dispatch_to_handler(node, function, node.args)

    def visit_GeneralCallNode(self, node):
        """Pass a general call of a builtin name to a handler.

        Only calls whose positional arguments form a literal TupleNode
        are considered.  The keyword arguments (possibly None) are
        forwarded, which selects the 'general' or 'simple' handler.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)

    def _function_is_builtin_name(self, function):
        """Return True if *function* is a name that the current
        environment resolves to an entry of the builtin scope.
        """
        if not function.is_name:
            return False
        entry = self.env_stack[-1].lookup(function.name)
        if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope:
            return False
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call ``_handle_simple_function_<name>(node, args)`` when there are
        no keyword arguments, or
        ``_handle_general_function_<name>(node, args, kwargs)`` otherwise.

        Returns the handler's result, or *node* unchanged if no such
        handler is defined.
        """
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the called function of *node* by the C-API function
        *cname* of type *func_type*, keeping the original Python name.
        """
        node.function = ExprNodes.PythonCapiFunctionNode(
            node.function.pos, node.function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with an unexpected number of arguments.

        *expected* may be None (unknown), an int, or a descriptive
        string such as '2 or 3'.
        """
        if not expected: # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.

        set() without arguments becomes an empty set literal.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type)
        if arg_count > 1:
            return node
        iterable = pos_args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args)
        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
                 isinstance(iterable.target, (ExprNodes.ListNode,
                                              ExprNodes.SetNode)):
            # let the comprehension build a set directly
            iterable.target = ExprNodes.SetNode(node.pos, args=[])
            iterable.pos = node.pos
            return iterable
        else:
            return node

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.

        Only comprehensions whose appended item is a literal 2-tuple or
        2-element list are converted.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and \
               isinstance(arg.target, (ExprNodes.ListNode,
                                       ExprNodes.SetNode)):
            append_node = arg.append
            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
                   len(append_node.expr.args) == 2:
                key_node, value_node = append_node.expr.args
                target_node = ExprNodes.DictNode(
                    pos=arg.target.pos, key_value_pairs=[])
                new_append_node = ExprNodes.DictComprehensionAppendNode(
                    append_node.pos, target=target_node,
                    key_expr=key_node, value_expr=value_node)
                arg.target = target_node
                arg.type = target_node.type
                # swap the list/set append for a dict item assignment
                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
                return replace_in(arg)
        return node

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by the literal 0.0 and report calls with
        more than one argument.
        """
        if len(pos_args) == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs


class OptimizeBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common method calls and instantiation patterns
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_GeneralCallNode(self, node):
        """Pass a general call of a Python object to a specific optimiser.

        Only calls whose positional arguments form a literal TupleNode
        are considered; the keyword arguments are forwarded as-is.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)

    def visit_SimpleCallNode(self, node):
        """Pass a call without keyword arguments to a specific optimiser.

        For calls of Python objects the arguments are taken from the
        argument tuple, which must be a literal TupleNode; for C function
        calls the plain argument list is used.
        """
        self.visitchildren(node)
        function = node.function
        if function.type.is_pyobject:
            arg_tuple = node.arg_tuple
            if not isinstance(arg_tuple, ExprNodes.TupleNode):
                return node
            args = arg_tuple.args
        else:
            args = node.args
        return self._dispatch_to_handler(
            node, function, args)

    ### cleanup to avoid redundant coercions to/from Python types

    def _visit_PyTypeTestNode(self, node):
        # disabled - appears to break assignments in some cases, and
        # also drops a None check, which might still be required
        # (disabled via the leading '_' of the method name, presumably so
        # that the visitor no longer dispatches PyTypeTestNode here)
        """Flatten redundant type checks after tree changes.

        If visiting the children replaced the argument by one that
        already has the tested type, the type test itself is dropped.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.  The call is then replaced by its (unwrapped)
        argument or by a C cast of it.
        """
        self.visitchildren(node)
        arg = node.arg
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type == arg.type:
                return arg
            else:
                return arg.coerce_to(node.type, self.env_stack[-1])
        if not isinstance(arg, ExprNodes.SimpleCallNode):
            return node
        if not (node.type.is_int or node.type.is_float):
            return node
        function = arg.function
        if not isinstance(function, ExprNodes.NameNode) \
               or not function.type.is_builtin_type \
               or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg.arg_tuple.args
        if len(args) != 1:
            return node
        func_arg = args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            # the argument started out as a C value: use that directly
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node
        if function.name == 'int':
            if func_arg.type.is_int or node.type.is_int:
                if func_arg.type == node.type:
                    return func_arg
                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        elif function.name == 'float':
            if func_arg.type.is_float or node.type.is_float:
                if func_arg.type == node.type:
                    return func_arg
                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        return node

    ### dispatch to specific optimisers

    def _find_handler(self, match_name, has_kwargs):
        """Return the optimiser method for *match_name*, or None.

        Looks up ``_handle_general_<match_name>`` if *has_kwargs* is
        true and ``_handle_simple_<match_name>`` otherwise, then falls
        back to ``_handle_any_<match_name>``.
        """
        # 'x and a or b' rather than a conditional expression, presumably
        # for old Python versions (cf. the compatibility imports at the
        # top of the file); safe here since both strings are non-empty
        call_type = has_kwargs and 'general' or 'simple'
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
        """Look up and call a specific optimiser for a call node.

        - Calls of builtin names go to ``_handle_{simple,general}_function_<name>``
          (or ``_handle_any_function_<name>``), called as
          ``handler(node, arg_list[, kwargs])``.
        - Calls of Python attributes go to
          ``_handle_{simple,general}_method_<type>_<attr>``, falling back
          to ``..._slot<attr>`` handlers for special method names and
          ``__new__``.  They are called as
          ``handler(node, args[, kwargs], is_unbound_method)``, where
          *args* starts with the object the method is called on, unless
          the method is called unbound (e.g. ``list.append(L, x)``).

        Returns the handler's result, or *node* unchanged if no handler
        applies.
        """
        if function.is_name:
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
            if not function.entry:
                return node
            is_builtin = function.entry.is_builtin \
                         or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
            if not is_builtin:
                return node
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
                return function_handler(node, arg_list, kwargs)
            else:
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
            attr_name = function.attribute
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                         arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    type_name = getattr(function.obj, 'name', None)
                    if type_name is None:
                        # the type object is not a named node (e.g. an
                        # attribute typed as 'type'), so we cannot tell
                        # whose method is called - leave the call alone
                        return node
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, attr_name), kwargs)
            if method_handler is None:
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
                if method_handler is None:
                    return node
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with an unexpected number of arguments.

        *expected* may be None (unknown), an int, or a descriptive
        string such as '2 or 3'; it is only used to build the message.
        """
        if not expected: # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    ### builtin types

    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict(some_dict) by PyDict_Copy(some_dict).

        Only applies when the single argument is statically typed as a
        dict; a None check guards the C-API call.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if arg.type is Builtin.dict_type:
            # a dict-typed value may still be None at runtime; use the
            # same message as dict(None) in Python (and the tuple() handler)
            arg = ExprNodes.NoneCheckNode(
                arg, "PyExc_TypeError", "'NoneType' object is not iterable")
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
                args = [arg],
                is_temp = node.is_temp
                )
        return node

    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

    def _handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple([...]) by a call to PyList_AsTuple.

        Only applies when the single argument is typed as a list.  List
        literals and comprehensions are used directly; any other list
        expression is wrapped in a None check.  The caller's *pos_args*
        list is not modified.
        """
        if len(pos_args) != 1:
            return node
        list_arg = pos_args[0]
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
            # an arbitrary list-typed expression may be None at runtime
            list_arg = ExprNodes.NoneCheckNode(
                list_arg, "PyExc_TypeError",
                "'NoneType' object is not iterable")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args = [list_arg],
            is_temp = node.is_temp
            )

    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)

    def _handle_simple_function_float(self, node, pos_args):
        """Transform float() into either a C type cast or a faster C
        function call.

        - an argument that is (a Python coercion of) a C double is
          returned as is;
        - a C numeric argument, or one assignable to the call's result
          type, becomes a C cast to that type;
        - anything else becomes a call to __Pyx_PyObject_AsDouble().
        """
        # Note: this requires the float() function to be typed as
        # returning a C 'double'
        if len(pos_args) != 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
            return node
        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        if func_arg.type is PyrexTypes.c_double_type:
            return func_arg
        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_AsDouble",
            self.PyObject_AsDouble_func_type,
            args = pos_args,
            is_temp = node.is_temp,
            utility_code = pyobject_as_double_utility_code,
            py_name = "float")

    ### builtin functions

    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            ])

    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_getattr(self, node, pos_args):
        """Replace 2/3 argument forms of getattr() by C-API calls.

        getattr(o, name) becomes PyObject_GetAttr(); getattr(o, name,
        default) becomes __Pyx_GetAttr3().  Other argument counts are
        reported as errors.
        """
        if len(pos_args) == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
                args = pos_args,
                is_temp = node.is_temp)
        elif len(pos_args) == 3:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
                args = pos_args,
                is_temp = node.is_temp,
                utility_code = Builtin.getattr3_utility_code)
        else:
            self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
        return node

    PyObject_GetIter_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            ])

    PyCallIter_New_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("sentinel", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_iter(self, node, pos_args):
        """Replace 1/2 argument forms of iter() by C-API calls.

        iter(o) becomes PyObject_GetIter(); iter(callable, sentinel)
        becomes PyCallIter_New().  Other argument counts are reported
        as errors.
        """
        if len(pos_args) == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetIter", self.PyObject_GetIter_func_type,
                args = pos_args,
                is_temp = node.is_temp)
        elif len(pos_args) == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyCallIter_New", self.PyCallIter_New_func_type,
                args = pos_args,
                is_temp = node.is_temp)
        else:
            self._error_wrong_arg_count('iter', node, pos_args, '1 or 2')
        return node

    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

    def _handle_simple_function_len(self, node, pos_args):
        """Replace len(char*) by the equivalent call to strlen().

        Applies only if the argument is a C string (possibly wrapped in a
        coercion to a Python object) and the call's result type is
        already numeric.
        """
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            arg = arg.arg
        if not arg.type.is_string:
            return node
        if not node.type.is_numeric:
            # this optimisation only works when we already replaced
            # len() by PyObject_Length() which returns a Py_ssize_t
            # instead of a Python object, so we can return a plain
            # size_t instead without caring about Python object
            # conversion etc.
            return node
        return ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [arg],
            is_temp = node.is_temp,
            utility_code = include_string_h_utility_code
            )

    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_type(self, node, pos_args):
        """Replace type(o) by a macro call to Py_TYPE(o).

        The result is cast to a generic Python object type.  Calls with
        any other number of arguments are left unchanged.
        """
        if len(pos_args) != 1:
            return node
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(node, PyrexTypes.py_object_type)

    ### special methods

    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
            ])

    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

        Only applies to an unbound call with exactly one argument, where
        the called type and the argument are names of type 'type' that
        refer to the same type (same type entry, or same name if a type
        entry is missing).
        """
        obj = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not type_arg.type_entry:
            # arbitrary variable, needs a None check for safety
            type_arg = ExprNodes.NoneCheckNode(
                type_arg, "PyExc_TypeError",
                "object.__new__(X): X is not a type object (NoneType)")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [type_arg],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )

    ### methods of builtin types

    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.append() is almost always
        referring to a list.

        *args* holds the target object followed by the single item;
        other argument counts are left unchanged.
        """
        if len(args) != 2:
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code
            )

    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            ])

    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
            ])

    def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.

        X.pop() becomes __Pyx_PyObject_Pop(X).  X.pop(n) becomes
        __Pyx_PyObject_PopIndex(X, n) only if n is a C integer (coerced
        to a Python object) whose type fits into the helper's C 'long'
        index parameter; otherwise the call is left unchanged.
        """
        if len(args) == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
                args = args,
                is_temp = node.is_temp,
                utility_code = pop_utility_code
                )
        elif len(args) == 2:
            index_arg = args[1]
            if isinstance(index_arg, ExprNodes.CoerceToPyTypeNode) and index_arg.arg.type.is_int:
                original_type = index_arg.arg.type
                # The index is passed as a C 'long' (see
                # PyObject_PopIndex_func_type), so only accept index types
                # that are no wider than long.  Checking against
                # Py_ssize_t is not enough: it may be wider than long
                # (e.g. on 64-bit Windows), which would truncate the index.
                if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_long_type) == PyrexTypes.c_long_type:
                    args[1] = index_arg.arg
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
                        args = args,
                        is_temp = node.is_temp,
                        utility_code = pop_index_utility_code
                        )

        return node

    PyList_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_method_list_append(self, node, args, is_unbound_method):
        """Call PyList_Append() instead of l.append().

        *args* must hold the list followed by the item; other argument
        counts are reported as errors.  The replacement is delegated to
        self._substitute_method_call().
        """
        if len(args) != 2:
            self._error_wrong_arg_count('list.append', node, args, 2)
            return node
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)

    # C signature 'int f(PyObject*)' with -1 signalling an error; shared
    # by the list.sort() and list.reverse() handlers
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
        """Call PyList_Sort() instead of the 0-argument l.sort().

        Calls that pass any arguments besides the list itself are left
        unchanged.
        """
        if len(args) != 1:
            return node
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)

    def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
        """Call PyList_Reverse() instead of l.reverse().

        *args* must hold only the list; other argument counts are
        reported as errors.
        """
        if len(args) != 1:
            self._error_wrong_arg_count('list.reverse', node, args, 1)
            return node
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)

    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
        """Replace dict.get() by a call to PyDict_GetItem().

        The call is routed through the __Pyx_PyDict_GetItemDefault
        helper.  A missing default argument is filled in with None;
        argument counts other than dict+key(+default) are reported as
        errors.
        """
        if len(args) == 2:
            args.append(ExprNodes.NoneNode(node.pos))
        elif len(args) != 3:
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node

        return self._substitute_method_call(
            node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
            'get', is_unbound_method, args,
            utility_code = dict_getitem_default_utility_code)

    # unicode.splitlines(keepends) -> list, mapped onto PyUnicode_Splitlines()
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None)])

    def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
        """Map ``u.splitlines([keepends])`` onto PyUnicode_Splitlines().

        A missing ``keepends`` argument becomes a False literal; an
        explicit one is coerced to a C boolean.
        """
        arg_count = len(args)
        if arg_count == 1:
            args.append(ExprNodes.BoolNode(node.pos, value=False))
        elif arg_count == 2:
            current_env = self.env_stack[-1]
            args[1] = args[1].coerce_to(PyrexTypes.c_bint_type, current_env)
        else:
            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
            return node

        return self._substitute_method_call(
            node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
            'splitlines', is_unbound_method, args)

    # unicode.split(sep, maxsplit) -> list, mapped onto PyUnicode_Split();
    # the split handler passes NULL for a missing separator
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None)])

    def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
        """Replace unicode.split(...) by a direct call to the
        corresponding C-API function PyUnicode_Split().

        Missing arguments are filled in as a NULL separator and
        ``maxsplit=-1``.  A literal ``None`` separator is also passed as
        NULL: PyUnicode_Split() rejects Py_None as separator, whereas
        Python's ``u.split(None, ...)`` accepts it.
        """
        if len(args) not in (1,2,3):
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node
        if len(args) < 2:
            args.append(ExprNodes.NullNode(node.pos))
        elif isinstance(args[1], ExprNodes.NoneNode):
            # u.split(None, ...) means "split on whitespace" in Python,
            # which the C-API only understands as a NULL separator
            args[1] = ExprNodes.NullNode(args[1].pos)
        if len(args) < 3:
            args.append(ExprNodes.IntNode(
                node.pos, value="-1", type=PyrexTypes.c_py_ssize_t_type))
        else:
            args[2] = args[2].coerce_to(PyrexTypes.c_py_ssize_t_type,
                                        self.env_stack[-1])

        return self._substitute_method_call(
            node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
            'split', is_unbound_method, args)

    # PyUnicode_AsEncodedString(obj, encoding, errors) -> bytes;
    # encoding/errors are passed as C strings (or NULL)
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Signature shared by the codec-specific encoders
    # PyUnicode_As<Codec>String(obj) -> bytes used for unicode.encode()
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs for which a dedicated PyUnicode_As<Name>String() /
    # PyUnicode_Decode<Name>() C-API function is used.  Names containing
    # '_' are turned into CamelCase by _find_special_codec_name()
    # (e.g. 'unicode_escape' -> 'UnicodeEscape').
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs.  Requested encodings are matched by
    # comparing encoder functions, so different spellings of the same codec
    # can map onto the same entry.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]

    def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.

        - ``u.encode()`` becomes PyUnicode_AsEncodedString(u, NULL, NULL).
        - A unicode literal with literal encoding/errors arguments is
          encoded at compile time into a bytes literal when possible.
        - With 'strict' error handling and a codec that has a dedicated
          C-API encoder, PyUnicode_As<Codec>String(u) is called.
        - Otherwise PyUnicode_AsEncodedString(u, encoding, errors) is used.

        If the encoding/errors arguments are neither string literals nor C
        strings, the original call node is returned unchanged.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
            return node

        string_node = args[0]

        if len(args) == 1:
            null_node = ExprNodes.NullNode(node.pos)
            return self._substitute_method_call(
                node, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # 'encoding' / 'error_handling' are None when the argument is a
        # non-literal C string, i.e. its value is only known at runtime
        if isinstance(string_node, ExprNodes.UnicodeNode) \
                and encoding is not None and error_handling is not None:
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except Exception:
                # unknown codec or unencodable text: encode at runtime instead
                pass
            else:
                value = BytesLiteral(value)
                value.encoding = encoding
                return ExprNodes.BytesNode(
                    string_node.pos, value=value, type=Builtin.bytes_type)

        if error_handling == 'strict' and encoding is not None:
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            if codec_name is not None:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

    # Signature shared by the codec-specific decoders
    # PyUnicode_Decode<Codec>(string, size, errors) -> unicode
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Generic decoder PyUnicode_Decode(string, size, encoding, errors) -> unicode
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handled forms of the object being decoded:
        - ``cstr[start:stop]`` on a C string: decodes ``stop - start`` bytes
          starting at ``cstr + start``; a missing ``stop`` is computed with
          strlen().
        - a C string coerced to a Python object: the length is taken from
          strlen(), just as CPython would.
        Anything else is left to Python.  Without an encoding argument,
        NULL is passed as encoding and errors.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
                if stop:
                    # 'start' is used twice (pointer offset and length),
                    # so evaluate it only once
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        if not stop:
            if start or not string_node.is_name:
                # the string is used twice (strlen() and the decode call),
                # so evaluate it only once
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
        elif start:
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        if len(args) < 2:
            # plain .decode(): pass NULL encoding and errors and leave the
            # default to PyUnicode_Decode() (previously crashed on args[1])
            encoding = None
            encoding_node = error_handling_node = ExprNodes.NullNode(node.pos)
        else:
            parameters = self._unpack_encoding_and_error_mode(node.pos, args)
            if parameters is None:
                return node
            encoding, encoding_node, _, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # wrap the call so that each temp is evaluated once, outermost first
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node

    def _find_special_codec_name(self, encoding):
        """Return the C-API name fragment (e.g. 'UTF8', 'UnicodeEscape') of
        the special codec matching ``encoding``, or None if the encoding is
        unknown/invalid or has no dedicated C-API function.
        """
        try:
            requested_codec = codecs.getencoder(encoding)
        except Exception:
            # unknown or invalid encoding name (LookupError, TypeError, ...);
            # don't swallow KeyboardInterrupt/SystemExit like a bare except
            return None
        for name, codec in self._special_codecs:
            if codec == requested_codec:
                if '_' in name:
                    name = ''.join([ s.capitalize()
                                     for s in name.split('_')])
                return name
        return None

    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the ``encoding`` (args[1]) and ``errors`` (args[2])
        arguments of an encode()/decode() call.

        Returns ``(encoding, encoding_node, error_handling,
        error_handling_node)``.  ``encoding``/``error_handling`` are the
        compile-time string values, or None when only a C string is known.
        The ``*_node`` entries are char* nodes.  NULL is used for a missing
        encoding and for 'strict' error handling.

        Returns None if an argument is neither a string literal nor a C
        string.
        """
        null_node = ExprNodes.NullNode(pos)

        if len(args) >= 2:
            encoding_node = args[1]
            if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
                encoding_node = encoding_node.arg
            if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                          ExprNodes.BytesNode)):
                encoding = encoding_node.value
                encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                     type=PyrexTypes.c_char_ptr_type)
            elif encoding_node.type.is_string:
                encoding = None
            else:
                return None
        else:
            # no encoding argument (previously an IndexError): pass NULL and
            # let the C-API function use its default
            encoding = None
            encoding_node = null_node

        if len(args) == 3:
            error_handling_node = args[2]
            if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
                error_handling_node = error_handling_node.arg
            if isinstance(error_handling_node,
                          (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                           ExprNodes.BytesNode)):
                error_handling = error_handling_node.value
                if error_handling == 'strict':
                    error_handling_node = null_node
                else:
                    error_handling_node = ExprNodes.BytesNode(
                        error_handling_node.pos, value=error_handling,
                        type=PyrexTypes.c_char_ptr_type)
            elif error_handling_node.type.is_string:
                error_handling = None
            else:
                return None
        else:
            error_handling = 'strict'
            error_handling_node = null_node

        return (encoding, encoding_node, error_handling, error_handling_node)

    def _substitute_method_call(self, node, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None):
        """Build a PythonCapiCallNode calling the C function ``name`` in
        place of the method call ``node``.

        The first argument (the object the method is looked up on) is
        wrapped in a NoneCheckNode.  It raises TypeError for an unbound
        call such as ``list.reverse(x)`` and AttributeError for a bound
        call ``x.reverse()``, with messages modelled on CPython's.
        ``args`` is copied, so the caller's sequence is not modified.
        """
        args = list(args)
        if args:
            self_arg = args[0]
            if is_unbound_method:
                self_arg = ExprNodes.NoneCheckNode(
                    self_arg, "PyExc_TypeError",
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                    attr_name, node.function.obj.name))
            else:
                self_arg = ExprNodes.NoneCheckNode(
                    self_arg, "PyExc_AttributeError",
                    "'NoneType' object has no attribute '%s'" % attr_name)
            args[0] = self_arg
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = utility_code
            )


# __Pyx_PyDict_GetItemDefault(d, key, default): new reference to d[key], or
# to ``default`` if the key is missing; NULL on error.  In Py2, only
# str/unicode/int keys use PyDict_GetItem() directly.  Other keys go through
# d.get(), presumably so that errors raised while hashing/comparing them are
# not swallowed.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)


# __Pyx_PyObject_Append(L, x): PyList_Append() for exact lists (returning a
# new reference to None), otherwise calls L.append(x); NULL on error.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)


# __Pyx_PyObject_Pop(L): for an exact list whose size is above half of its
# allocation, removes and returns the last item by shrinking the size in
# place.  Otherwise calls L.pop().  Local declarations come first so that
# the generated C stays valid C89.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)

# __Pyx_PyObject_PopIndex(L, ix): for an exact list whose size is above half
# of its allocation and a valid (possibly negative) index, removes item ix in
# place by shifting the tail down.  Otherwise calls L.pop(ix); NULL on error.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)


# __Pyx_PyObject_AsDouble(obj): C double value of ``obj``.  Exact floats are
# read directly by the macro, which evaluates ``obj`` more than once.
# Everything else goes through nb_float, PyFloat_FromString(), or a call to
# float(obj).  The fallback tuple borrows ``obj`` and is cleared before
# release.  Returns -1.0 on error.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)


# Pulls in <string.h>, e.g. for the strlen() calls generated by the
# bytes.decode() optimisation.
include_string_h_utility_code = UtilityCode(
proto = """
#include <string.h>
"""
)


# __Pyx_tp_new(type_obj): create an instance by calling the type's tp_new
# slot directly, passing Naming.empty_tuple as args and NULL as kwargs.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)


class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Compute ``node.constant_result`` once, after visiting children.

        The result stays ``ExprNodes.not_a_constant`` if any child is not a
        constant or if the calculation raises.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if child.constant_result is not_a_constant:
                        return
            elif child_result.constant_result is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest literal node class among ``nodes``, or None if
        any node is not one of NODE_TYPE_ORDER.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Compute the constant result of a binary operation and, if both
        operands are (non-float) literals, replace it by a literal node.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = node.operand1.type, node.operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1 is type2:
            new_node = node.operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(node.operand1) is type(node.operand2):
                new_node = node.operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = node.operand1
            elif type2 is widest_type:
                new_node = node.operand2
            else:
                target_class = self._widest_node_class(
                    node.operand1, node.operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        # note: an operand node may be reused (and updated) as the result
        new_node.constant_result = node.constant_result
        new_node.value = str(node.constant_result)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children


class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only two-argument calls are considered, so a C function that
        happens to be named 'isinstance' with another arity is left alone
        instead of raising IndexError.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    from CythonScope import utility_scope
                    node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node