Commit 33e7929b authored by Robert Bradshaw's avatar Robert Bradshaw

Merge branch 'overflow'

parents 8d2a7185 c93a2314
......@@ -68,10 +68,18 @@ class UtilityCodeBase(object):
Code sections in the file can be specified as follows:
##### MyUtility.proto #####
[proto declarations]
##### MyUtility.init #####
[code run at module initialization]
##### MyUtility #####
#@requires: MyOtherUtility
#@substitute: naming
[definitions]
for prototypes and implementation respectively. For non-python or
-cython files backslashes should be used instead. 5 to 30 comment
......@@ -374,10 +382,13 @@ class UtilityCode(UtilityCodeBase):
output['utility_code_def'].put(self.format_code(self.impl))
if self.init:
writer = output['init_globals']
writer.putln("/* %s.init */" % self.name)
if isinstance(self.init, basestring):
writer.put(self.format_code(self.init))
else:
self.init(writer, output.module_pos)
writer.putln(writer.error_goto_if_PyErr(output.module_pos))
writer.putln()
if self.cleanup and Options.generate_cleanup_code:
writer = output['cleanup_globals']
if isinstance(self.cleanup, basestring):
......@@ -400,13 +411,14 @@ def sub_tempita(s, context, file=None, name=None):
return sub(s, **context)
class TempitaUtilityCode(UtilityCode):
def __init__(self, name=None, proto=None, impl=None, file=None, context=None, **kwargs):
def __init__(self, name=None, proto=None, impl=None, init=None, file=None, context=None, **kwargs):
if context is None:
context = {}
proto = sub_tempita(proto, context, file, name)
impl = sub_tempita(impl, context, file, name)
init = sub_tempita(init, context, file, name)
super(TempitaUtilityCode, self).__init__(
proto, impl, name=name, file=file, **kwargs)
proto, impl, init=init, name=name, file=file, **kwargs)
def none_or_sub(self, s, context):
"""
......
......@@ -7341,6 +7341,9 @@ class TypecastNode(ExprNode):
if self.type is None:
base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env)
if self.operand.has_constant_result():
# Must be done after self.type is resolved.
self.calculate_constant_result()
if self.type.is_cfunction:
error(self.pos,
"Cannot cast to a function type")
......@@ -7400,11 +7403,11 @@ class TypecastNode(ExprNode):
return self.operand.check_const()
def calculate_constant_result(self):
# we usually do not know the result of a type cast at code
# generation time
pass
self.constant_result = self.calculate_result_code(self.operand.constant_result)
def calculate_result_code(self):
def calculate_result_code(self, operand_result = None):
if operand_result is None:
operand_result = self.operand.result()
if self.type.is_complex:
operand_result = self.operand.result()
if self.operand.type.is_complex:
......@@ -7418,7 +7421,7 @@ class TypecastNode(ExprNode):
real_part,
imag_part)
else:
return self.type.cast_code(self.operand.result())
return self.type.cast_code(operand_result)
def get_constant_c_result_code(self):
operand_result = self.operand.get_constant_c_result_code()
......@@ -7997,6 +8000,7 @@ class NumBinopNode(BinopNode):
# Binary operation taking numeric arguments.
infix = True
overflow_check = False
def analyse_c_operation(self, env):
type1 = self.operand1.type
......@@ -8007,6 +8011,13 @@ class NumBinopNode(BinopNode):
return
if self.type.is_complex:
self.infix = False
if self.type.is_int and env.directives['overflowcheck'] and self.operator in self.overflow_op_names:
self.overflow_check = True
self.func = self.type.overflow_check_binop(
self.overflow_op_names[self.operator],
env,
const_rhs = self.operand2.has_constant_result())
self.is_temp = True
if not self.infix or (type1.is_numeric and type2.is_numeric):
self.operand1 = self.operand1.coerce_to(self.type, env)
self.operand2 = self.operand2.coerce_to(self.type, env)
......@@ -8048,8 +8059,26 @@ class NumBinopNode(BinopNode):
return (type1.is_numeric or type1.is_enum) \
and (type2.is_numeric or type2.is_enum)
def generate_result_code(self, code):
super(NumBinopNode, self).generate_result_code(code)
if self.overflow_check:
self.overflow_bit = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
code.putln("%s = 0;" % self.overflow_bit);
code.putln("%s = %s;" % (self.result(), self.calculate_result_code()))
code.putln("if (unlikely(%s)) {" % self.overflow_bit)
code.putln('PyErr_Format(PyExc_OverflowError, "value too large");')
code.putln(code.error_goto(self.pos))
code.putln("}")
code.funcstate.release_temp(self.overflow_bit)
def calculate_result_code(self):
if self.infix:
if self.overflow_check:
return "%s(%s, %s, &%s)" % (
self.func,
self.operand1.result(),
self.operand2.result(),
self.overflow_bit)
elif self.infix:
return "(%s %s %s)" % (
self.operand1.result(),
self.operator,
......@@ -8088,6 +8117,13 @@ class NumBinopNode(BinopNode):
"%": "PyNumber_Remainder",
"**": "PyNumber_Power"
}
overflow_op_names = {
"+": "add",
"-": "sub",
"*": "mul",
"<<": "lshift",
}
class IntBinopNode(NumBinopNode):
# Binary operation taking integer arguments.
......
......@@ -81,6 +81,7 @@ directive_defaults = {
'auto_cpdef': False,
'cdivision': False, # was True before 0.12
'cdivision_warnings': False,
'overflowcheck': False,
'always_allow_keywords': False,
'allow_none_for_extension_args': True,
'wraparound' : True,
......
......@@ -25,7 +25,7 @@ class BaseType(object):
# This is not entirely robust.
safe = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_0123456789'
all = []
for c in self.declaration_code("").replace(" ", "__"):
for c in self.declaration_code("").replace("unsigned ", "unsigned_").replace("long long", "long_long").replace(" ", "__"):
if c in safe:
all.append(c)
else:
......@@ -402,6 +402,26 @@ class CTypedefType(BaseType):
# delegation
return self.typedef_base_type.create_from_py_utility_code(env)
def overflow_check_binop(self, binop, env, const_rhs=False):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
type = self.declaration_code("")
name = self.specialization_name()
if binop == "lshift":
env.use_utility_code(TempitaUtilityCode.load(
"LeftShift", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'SIGNED': self.signed}))
else:
if const_rhs:
binop += "_const"
_load_overflow_base(env)
env.use_utility_code(TempitaUtilityCode.load(
"SizeCheck", "Overflow.c",
context={'TYPE': type, 'NAME': name}))
env.use_utility_code(TempitaUtilityCode.load(
"Binop", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'BINOP': binop}))
return "__Pyx_%s_%s_checking_overflow" % (binop, name)
def error_condition(self, result_code):
if self.typedef_is_external:
if self.exception_value:
......@@ -1546,7 +1566,51 @@ class CIntType(CNumericType):
# We do not really know the size of the type, so return
# a 32-bit literal and rely on casting to final type. It will
# be negative for signed ints, which is good.
return "0xbad0bad0";
return "0xbad0bad0"
def overflow_check_binop(self, binop, env, const_rhs=False):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
type = self.declaration_code("")
name = self.specialization_name()
if binop == "lshift":
env.use_utility_code(TempitaUtilityCode.load(
"LeftShift", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'SIGNED': not self.signed}))
else:
if const_rhs:
binop += "_const"
if type in ('int', 'long', 'long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseSigned", "Overflow.c",
context={'INT': type, 'NAME': name}))
elif type in ('unsigned int', 'unsigned long', 'unsigned long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseUnsigned", "Overflow.c",
context={'UINT': type, 'NAME': name}))
elif self.rank <= 1:
# sizeof(short) < sizeof(int)
return "__Pyx_%s_%s_no_overflow" % (binop, name)
else:
_load_overflow_base(env)
env.use_utility_code(TempitaUtilityCode.load(
"SizeCheck", "Overflow.c",
context={'TYPE': type, 'NAME': name}))
env.use_utility_code(TempitaUtilityCode.load(
"Binop", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'BINOP': binop}))
return "__Pyx_%s_%s_checking_overflow" % (binop, name)
def _load_overflow_base(env):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
for type in ('int', 'long', 'long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseSigned", "Overflow.c",
context={'INT': type, 'NAME': type.replace(' ', '_')}))
for type in ('unsigned int', 'unsigned long', 'unsigned long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseUnsigned", "Overflow.c",
context={'UINT': type, 'NAME': type.replace(' ', '_')}))
class CAnonEnumType(CIntType):
......
This diff is collapsed.
......@@ -148,6 +148,11 @@ Cython code. Here is the list of currently supported directives:
set to ``None``. Otherwise a check is inserted and the
appropriate exception is raised. This is off by default for
performance reasons. Default is False.
``overflowcheck`` (True / False)
If set to True, raise errors on overflowing C integer arithmetic
operations. Incurs a slight runtime penalty, but much faster than
using Python ints. Default is False.
``embedsignature`` (True / False)
If set to True, Cython will embed a textual copy of the call
......
cimport cython
cdef object two = 2
cdef int size_in_bits = sizeof(INT) * 8
cdef bint is_signed_ = (<INT>-1 < 0)
cdef INT max_value_ = <INT>(two ** (size_in_bits - is_signed_) - 1)
cdef INT min_value_ = ~max_value_
cdef INT half_ = max_value_ // 2
# Python visible.
is_signed = is_signed_
max_value = max_value_
min_value = min_value_
half = half_
import operator
from libc.math cimport sqrt
cpdef check(func, op, a, b):
cdef INT res, op_res
cdef bint func_overflow = False
cdef bint assign_overflow = False
try:
res = func(a, b)
except OverflowError:
func_overflow = True
try:
op_res = op(a, b)
except OverflowError:
assign_overflow = True
assert func_overflow == assign_overflow, "Inconsistant overflow: %s(%s, %s)" % (func, a, b)
if not func_overflow:
assert res == op_res, "Inconsistant values: %s(%s, %s) == %s != %s" % (func, a, b, res, op_res)
medium_values = (max_value_ / 2, max_value_ / 3, min_value_ / 2, <INT>sqrt(max_value_) - 1, <INT>sqrt(max_value_) + 1)
def run_test(func, op):
cdef INT offset, b
check(func, op, 300, 200)
check(func, op, max_value_, max_value_)
check(func, op, max_value_, min_value_)
if not is_signed_ or not func is test_sub:
check(func, op, min_value_, min_value_)
for offset in range(5):
check(func, op, max_value_ - 1, offset)
check(func, op, min_value_ + 1, offset)
if is_signed_:
check(func, op, max_value_ - 1, 2 - offset)
check(func, op, min_value_ + 1, 2 - offset)
for offset in range(9):
check(func, op, max_value_ / 2, offset)
check(func, op, min_value_ / 3, offset)
check(func, op, max_value_ / 4, offset)
check(func, op, min_value_ / 5, offset)
if is_signed_:
check(func, op, max_value_ / 2, 4 - offset)
check(func, op, min_value_ / 3, 4 - offset)
check(func, op, max_value_ / -4, 3 - offset)
check(func, op, min_value_ / -5, 3 - offset)
for offset in range(-3, 4):
for a in medium_values:
for b in medium_values:
check(func, op, a, b + offset)
@cython.overflowcheck(True)
def test_add(INT a, INT b):
"""
>>> test_add(1, 2)
3
>>> test_add(max_value, max_value) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
OverflowError: value too large
>>> run_test(test_add, operator.add)
"""
return int(a + b)
@cython.overflowcheck(True)
def test_sub(INT a, INT b):
"""
>>> test_sub(10, 1)
9
>>> test_sub(min_value, 1) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
OverflowError: value too large
>>> run_test(test_sub, operator.sub)
"""
return int(a - b)
@cython.overflowcheck(True)
def test_mul(INT a, INT b):
"""
>>> test_mul(11, 13)
143
>>> test_mul(max_value / 2, max_value / 2) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
OverflowError: value too large
>>> run_test(test_mul, operator.mul)
"""
return int(a * b)
@cython.overflowcheck(True)
def test_nested(INT a, INT b, INT c):
"""
>>> test_nested(1, 2, 3)
6
>>> expect_overflow(test_nested, half + 1, half + 1, half + 1)
>>> expect_overflow(test_nested, half - 1, half - 1, half - 1)
"""
return int(a + b + c)
def expect_overflow(func, *args):
try:
res = func(*args)
except OverflowError:
return
assert False, "Expected OverflowError, got %s" % res
cpdef format(INT value):
"""
>>> format(1)
'1'
>>> format(half - 1)
'half - 1'
>>> format(half)
'half'
>>> format(half + 2)
'half + 2'
>>> format(half + half - 3)
'half + half - 3'
>>> format(max_value)
'max_value'
"""
if value == max_value_:
return "max_value"
elif value == half_:
return "half"
elif max_value_ - value <= max_value_ // 4:
return "half + half - %s" % (half_ + half_ - value)
elif max_value_ - value <= half_:
return "half + %s" % (value - half_)
elif max_value_ - value <= half_ + max_value_ // 4:
return "half - %s" % (half_ - value)
else:
return "%s" % value
cdef INT called(INT value):
print("called(%s)" % format(value))
return value
@cython.overflowcheck(True)
def test_nested_func(INT a, INT b, INT c):
"""
>>> test_nested_func(1, 2, 3)
called(5)
6
>>> expect_overflow(test_nested_func, half + 1, half + 1, half + 1)
>>> expect_overflow(test_nested_func, half - 1, half - 1, half - 1)
called(half + half - 2)
>>> print(format(test_nested_func(1, half - 1, half - 1)))
called(half + half - 2)
half + half - 1
>>>
"""
return int(a + called(b + c))
@cython.overflowcheck(True)
def test_add_const(INT a):
"""
>>> test_add_const(1)
101
>>> expect_overflow(test_add_const, max_value)
>>> expect_overflow(test_add_const , max_value - 99)
>>> test_add_const(max_value - 100) == max_value
True
"""
return int(a + <INT>100)
@cython.overflowcheck(True)
def test_sub_const(INT a):
"""
>>> test_sub_const(101)
1
>>> expect_overflow(test_sub_const, min_value)
>>> expect_overflow(test_sub_const, min_value + 99)
>>> test_sub_const(min_value + 100) == min_value
True
"""
return int(a - <INT>100)
@cython.overflowcheck(True)
def test_mul_const(INT a):
"""
>>> test_mul_const(2)
200
>>> expect_overflow(test_mul_const, max_value)
>>> expect_overflow(test_mul_const, max_value // 99)
>>> test_mul_const(max_value // 100) == max_value - max_value % 100
True
"""
return int(a * <INT>100)
@cython.overflowcheck(True)
def test_lshift(INT a, int b):
"""
>>> test_lshift(1, 10)
1024
>>> expect_overflow(test_lshift, 1, 100)
>>> expect_overflow(test_lshift, max_value, 1)
>>> test_lshift(max_value, 0) == max_value
True
>>> check(test_lshift, operator.lshift, 10, 15)
>>> check(test_lshift, operator.lshift, 10, 30)
>>> check(test_lshift, operator.lshift, 100, 60)
"""
return int(a << b)
ctypedef int INT
include "overflow_check.pxi"
ctypedef long long INT
include "overflow_check.pxi"
ctypedef unsigned int INT
include "overflow_check.pxi"
ctypedef unsigned long long INT
include "overflow_check.pxi"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment