Commit 118a0239 authored by Stefan Behnel's avatar Stefan Behnel

speed up adding/subtracting small integer constants

parent 054db41a
......@@ -15,6 +15,8 @@ Features added
* Tracing is supported in ``nogil`` functions/sections and module init code.
* Adding/subtracting small constant Python integers is faster.
Bugs fixed
----------
......
......@@ -15,7 +15,7 @@ from . import Builtin
from . import UtilNodes
from . import Options
from .Code import UtilityCode
from .Code import UtilityCode, TempitaUtilityCode
from .StringEncoding import EncodedString, BytesLiteral
from .Errors import error
from .ParseTreeTransforms import SkipDeclarations
......@@ -2780,6 +2780,60 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
may_return_none=True,
utility_code=load_c_utility('dict_setdefault'))
Pyx_PyNumber_BinopInt_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("int_op", PyrexTypes.c_long_type, None),
PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_int_type, None),
])
def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
return self._optimise_int_binop('Add', node, function, args, is_unbound_method)
def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
return self._optimise_int_binop('Subtract', node, function, args, is_unbound_method)
def _optimise_int_binop(self, operator, node, function, args, is_unbound_method):
"""
Optimise '+' / '-' operator for (likely) small integer operations.
"""
if len(args) != 2:
return node
if not node.type.is_pyobject:
return node
# when adding IntNode to something else, assume other operand is also numeric
if isinstance(args[0], ExprNodes.IntNode):
if args[1].type is not PyrexTypes.py_object_type:
return node
intval = args[0]
arg_order = 'IntObj'
elif isinstance(args[1], ExprNodes.IntNode):
if args[0].type is not PyrexTypes.py_object_type:
return node
intval = args[1]
arg_order = 'ObjInt'
else:
return node
if not intval.has_constant_result() or abs(intval.constant_result) > 2**30:
return node
args = list(args)
self._inject_int_default_argument(intval, args, len(args), PyrexTypes.c_long_type, intval.constant_result)
self._inject_int_default_argument(node, args, len(args), PyrexTypes.c_long_type, int(node.inplace))
utility_code = TempitaUtilityCode.load_cached(
"PyNumberBinopWithInt", "Optimize.c",
context=dict(op=operator, order=arg_order))
return self._substitute_method_call(
node, function, "__Pyx_PyNumber_%s%s" % (operator, arg_order),
self.Pyx_PyNumber_BinopInt_func_type,
'__%s__' % operator[:3].lower(), is_unbound_method, args,
may_return_none=True,
with_none_check=False,
utility_code=utility_code)
### unicode type methods
......@@ -3335,9 +3389,10 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
def _substitute_method_call(self, node, function, name, func_type,
attr_name, is_unbound_method, args=(),
utility_code=None, is_temp=None,
may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
with_none_check=True):
args = list(args)
if args and not args[0].is_literal:
if with_none_check and args and not args[0].is_literal:
self_arg = args[0]
if is_unbound_method:
self_arg = self_arg.as_none_safe_node(
......
......@@ -477,3 +477,51 @@ fallback:
#endif
return (inplace ? PyNumber_InPlacePower : PyNumber_Power)(two, exp, none);
}
/////////////// PyNumberBinopWithInt.proto ///////////////
static PyObject* __Pyx_PyNumber_{{op}}{{order}}(PyObject *op1, PyObject *op2, long intval, int inplace); /*proto*/
/////////////// PyNumberBinopWithInt ///////////////
//@requires: TypeConversion.c::PyLongInternals
{{py: pyval, ival = ('op2', 'b') if order == 'IntObj' else ('op1', 'a') }}
static PyObject* __Pyx_PyNumber_{{op}}{{order}}(PyObject *op1, PyObject *op2, CYTHON_UNUSED long intval, int inplace) {
#if CYTHON_COMPILING_IN_CPYTHON
const long {{'a' if order == 'IntObj' else 'b'}} = intval;
#if PY_MAJOR_VERSION < 3
if (likely(PyInt_CheckExact({{pyval}}))) {
long x, {{ival}};
{{ival}} = PyInt_AS_LONG({{pyval}});
// copied from intobject.c in Py2.7:
// casts in the line below avoid undefined behaviour on overflow
x = (long)((unsigned long)a {{ '+' if op == 'Add' else '-' }} b);
if ((x^a) >= 0 || (x^{{ '~' if op == 'Subtract' else '' }}b) >= 0)
return PyInt_FromLong(x);
return PyLong_Type.tp_as_number->nb_{{op.lower()}}(op1, op2);
}
#endif
#if PY_MAJOR_VERSION >= 3 && CYTHON_USE_PYLONG_INTERNALS
if (likely(PyLong_CheckExact({{pyval}}))) {
long {{ival}};
switch (Py_SIZE({{pyval}})) {
case -1: {{ival}} = -(sdigit)((PyLongObject*){{pyval}})->ob_digit[0]; break;
case 0: {{ival}} = 0; break;
case 1: {{ival}} = ((PyLongObject*){{pyval}})->ob_digit[0]; break;
default: return PyLong_Type.tp_as_number->nb_{{op.lower()}}(op1, op2);
}
return PyLong_FromLong(a {{ '+' if op == 'Add' else '-' }} b);
}
#endif
if (PyFloat_CheckExact({{pyval}})) {
double {{ival}} = PyFloat_AS_DOUBLE({{pyval}});
return PyFloat_FromDouble(((double)a) {{ '+' if op == 'Add' else '-' }} (double)b);
}
#endif
return (inplace ? PyNumber_InPlace{{op}} : PyNumber_{{op}})(op1, op2);
}
def f():
cimport cython
def mixed_test():
"""
>>> f()
>>> mixed_test()
(30, 22)
"""
cdef int int1, int2, int3
......@@ -15,3 +18,91 @@ def f():
ptr1 = int2 + ptr3
obj1 = obj2 + int3
return int1, obj1
@cython.test_fail_if_path_exists('//AddNode')
def add_x_1(x):
"""
>>> add_x_1(0)
1
>>> add_x_1(1)
2
>>> add_x_1(-1)
0
>>> add_x_1(1.5)
2.5
>>> add_x_1(-1.5)
-0.5
>>> try: add_x_1("abc")
... except TypeError: pass
"""
return x + 1
@cython.test_fail_if_path_exists('//AddNode')
def add_x_large(x):
"""
>>> add_x_large(0)
1073741824
>>> add_x_large(1)
1073741825
>>> add_x_large(-1)
1073741823
>>> add_x_large(1.5)
1073741825.5
>>> add_x_large(-2.0**31)
-1073741824.0
>>> add_x_large(2**30 + 1)
2147483649
>>> 2**31 + 2**30
3221225472
>>> add_x_large(2**31)
3221225472
>>> print(2**66 + 2**30)
73786976295911948288
>>> print(add_x_large(2**66))
73786976295911948288
>>> try: add_x_large("abc")
... except TypeError: pass
"""
return x + 2**30
@cython.test_fail_if_path_exists('//AddNode')
def add_1_x(x):
"""
>>> add_1_x(0)
1
>>> add_1_x(1)
2
>>> add_1_x(-1)
0
>>> add_1_x(1.5)
2.5
>>> add_1_x(-1.5)
-0.5
>>> try: add_1_x("abc")
... except TypeError: pass
"""
return 1 + x
@cython.test_fail_if_path_exists('//AddNode')
def add_large_x(x):
"""
>>> add_large_x(0)
1073741824
>>> add_large_x(1)
1073741825
>>> add_large_x(-1)
1073741823
>>> add_large_x(1.5)
1073741825.5
>>> add_large_x(-2.0**30)
0.0
>>> add_large_x(-2.0**31)
-1073741824.0
>>> try: add_large_x("abc")
... except TypeError: pass
"""
return 2**30 + x
def f():
cimport cython
def mixed_test():
"""
>>> f()
>>> mixed_test()
(-1, -1)
"""
cdef int int1, int2, int3
......@@ -14,9 +17,10 @@ def f():
obj1 = obj2 - int3
return int1, obj1
def p():
def pointer_test():
"""
>>> p()
>>> pointer_test()
0
"""
cdef int int1, int2, int3
......@@ -29,3 +33,95 @@ def p():
ptr1 = ptr2 - int3
int1 = ptr2 - ptr3
return int1
@cython.test_fail_if_path_exists('//SubNode')
def sub_x_1(x):
"""
>>> sub_x_1(0)
-1
>>> sub_x_1(1)
0
>>> sub_x_1(-1)
-2
>>> sub_x_1(1.5)
0.5
>>> sub_x_1(-1.5)
-2.5
>>> try: sub_x_1("abc")
... except TypeError: pass
"""
return x - 1
@cython.test_fail_if_path_exists('//SubNode')
def sub_x_large(x):
"""
>>> sub_x_large(0)
-1073741824
>>> sub_x_large(1)
-1073741823
>>> sub_x_large(-1)
-1073741825
>>> sub_x_large(2.0**30)
0.0
>>> sub_x_large(2.0**30 + 1)
1.0
>>> sub_x_large(2.0**30 - 1)
-1.0
>>> 2.0 ** 31 - 2**30
1073741824.0
>>> sub_x_large(2.0**31)
1073741824.0
>>> try: sub_x_large("abc")
... except TypeError: pass
"""
return x - 2**30
@cython.test_fail_if_path_exists('//SubNode')
def sub_1_x(x):
"""
>>> sub_1_x(0)
1
>>> sub_1_x(-1)
2
>>> sub_1_x(1)
0
>>> sub_1_x(1.5)
-0.5
>>> sub_1_x(-1.5)
2.5
>>> try: sub_1_x("abc")
... except TypeError: pass
"""
return 1 - x
@cython.test_fail_if_path_exists('//SubNode')
def sub_large_x(x):
"""
>>> sub_large_x(0)
1073741824
>>> sub_large_x(-1)
1073741825
>>> sub_large_x(1)
1073741823
>>> sub_large_x(2**30)
0
>>> 2**30 - 2**31
-1073741824
>>> sub_large_x(2**31)
-1073741824
>>> sub_large_x(2.0**30)
0.0
>>> sub_large_x(2.0**31)
-1073741824.0
>>> sub_large_x(2.0**30 + 1)
-1.0
>>> sub_large_x(2.0**30 - 1)
1.0
>>> try: sub_large_x("abc")
... except TypeError: pass
"""
return 2**30 - x
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment