Commit a5e1aea2 authored by Robert Bradshaw's avatar Robert Bradshaw

merge

parents 123003db 5ecfc44d
...@@ -113,8 +113,8 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -113,8 +113,8 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
# #
# Analysis # Analysis
# #
buffer_options = ("dtype", "ndim", "mode") # ordered! buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
buffer_defaults = {"ndim": 1, "mode": "full"} buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
buffer_positional_options_count = 1 # anything beyond this needs keyword argument buffer_positional_options_count = 1 # anything beyond this needs keyword argument
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option' ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
...@@ -124,6 +124,7 @@ ERR_BUF_MISSING = '"%s" missing' ...@@ -124,6 +124,7 @@ ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)' ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer' ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct' ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True): def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
""" """
...@@ -178,6 +179,14 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee ...@@ -178,6 +179,14 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee
if mode and not (mode in ('full', 'strided', 'c', 'fortran')): if mode and not (mode in ('full', 'strided', 'c', 'fortran')):
raise CompileError(globalpos, ERR_BUF_MODE) raise CompileError(globalpos, ERR_BUF_MODE)
def assert_bool(name):
x = options.get(name)
if not isinstance(x, bool):
raise CompileError(globalpos, ERR_BUF_BOOL % name)
assert_bool('negative_indices')
assert_bool('cast')
return options return options
...@@ -229,13 +238,15 @@ def put_acquire_arg_buffer(entry, code, pos): ...@@ -229,13 +238,15 @@ def put_acquire_arg_buffer(entry, code, pos):
code.globalstate.use_utility_code(acquire_utility_code) code.globalstate.use_utility_code(acquire_utility_code)
buffer_aux = entry.buffer_aux buffer_aux = entry.buffer_aux
getbuffer_cname = get_getbuffer_code(entry.type.dtype, code) getbuffer_cname = get_getbuffer_code(entry.type.dtype, code)
# Acquire any new buffer # Acquire any new buffer
code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d) == -1" % ( code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d, %d) == -1" % (
getbuffer_cname, getbuffer_cname,
entry.cname, entry.cname,
entry.buffer_aux.buffer_info_var.cname, entry.buffer_aux.buffer_info_var.cname,
get_flags(buffer_aux, entry.type), get_flags(buffer_aux, entry.type),
entry.type.ndim), pos)) entry.type.ndim,
int(entry.type.cast)), pos))
# An exception raised in arg parsing cannot be catched, so no # An exception raised in arg parsing cannot be catched, so no
# need to care about the buffer then. # need to care about the buffer then.
put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code) put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
...@@ -269,11 +280,12 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -269,11 +280,12 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
bufstruct = buffer_aux.buffer_info_var.cname bufstruct = buffer_aux.buffer_info_var.cname
flags = get_flags(buffer_aux, buffer_type) flags = get_flags(buffer_aux, buffer_type)
getbuffer = "%s((PyObject*)%%s, &%s, %s, %d)" % (get_getbuffer_code(buffer_type.dtype, code), getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
# note: object is filled in later (%%s) # note: object is filled in later (%%s)
bufstruct, bufstruct,
flags, flags,
buffer_type.ndim) buffer_type.ndim,
int(buffer_type.cast))
if is_initialized: if is_initialized:
# Release any existing buffer # Release any existing buffer
...@@ -336,6 +348,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod ...@@ -336,6 +348,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
""" """
bufaux = entry.buffer_aux bufaux = entry.buffer_aux
bufstruct = bufaux.buffer_info_var.cname bufstruct = bufaux.buffer_info_var.cname
negative_indices = entry.type.negative_indices
if options['boundscheck']: if options['boundscheck']:
# Check bounds and fix negative indices. # Check bounds and fix negative indices.
...@@ -349,9 +362,12 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod ...@@ -349,9 +362,12 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
if signed != 0: if signed != 0:
# not unsigned, deal with negative index # not unsigned, deal with negative index
code.putln("if (%s < 0) {" % cname) code.putln("if (%s < 0) {" % cname)
if negative_indices:
code.putln("%s += %s;" % (cname, shape.cname)) code.putln("%s += %s;" % (cname, shape.cname))
code.putln("if (%s) %s = %d;" % ( code.putln("if (%s) %s = %d;" % (
code.unlikely("%s < 0" % cname), tmp_cname, dim)) code.unlikely("%s < 0" % cname), tmp_cname, dim))
else:
code.putln("%s = %d;" % (tmp_cname, dim))
code.put("} else ") code.put("} else ")
# check bounds in positive direction # check bounds in positive direction
code.putln("if (%s) %s = %d;" % ( code.putln("if (%s) %s = %d;" % (
...@@ -364,7 +380,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod ...@@ -364,7 +380,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
code.putln(code.error_goto(pos)) code.putln(code.error_goto(pos))
code.end_block() code.end_block()
code.funcstate.release_temp(tmp_cname) code.funcstate.release_temp(tmp_cname)
else: elif negative_indices:
# Only fix negative indices. # Only fix negative indices.
for signed, cname, shape in zip(index_signeds, index_cnames, for signed, cname, shape in zip(index_signeds, index_cnames,
bufaux.shapevars): bufaux.shapevars):
...@@ -563,10 +579,11 @@ def get_getbuffer_code(dtype, code): ...@@ -563,10 +579,11 @@ def get_getbuffer_code(dtype, code):
if not code.globalstate.has_utility_code(name): if not code.globalstate.has_utility_code(name):
code.globalstate.use_utility_code(acquire_utility_code) code.globalstate.use_utility_code(acquire_utility_code)
itemchecker = get_ts_check_item(dtype, code) itemchecker = get_ts_check_item(dtype, code)
dtype_cname = dtype.declaration_code("")
utilcode = [dedent(""" utilcode = [dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/ static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
""") % name, dedent(""" """) % name, dedent("""
static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd) { static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
const char* ts; const char* ts;
if (obj == Py_None) { if (obj == Py_None) {
__Pyx_ZeroBuffer(buf); __Pyx_ZeroBuffer(buf);
...@@ -578,6 +595,7 @@ def get_getbuffer_code(dtype, code): ...@@ -578,6 +595,7 @@ def get_getbuffer_code(dtype, code):
__Pyx_BufferNdimError(buf, nd); __Pyx_BufferNdimError(buf, nd);
goto fail; goto fail;
} }
if (!cast) {
ts = buf->format; ts = buf->format;
ts = __Pyx_ConsumeWhitespace(ts); ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail; if (!ts) goto fail;
...@@ -590,6 +608,13 @@ def get_getbuffer_code(dtype, code): ...@@ -590,6 +608,13 @@ def get_getbuffer_code(dtype, code):
"Expected non-struct buffer data type (expected end, got '%%s')", ts); "Expected non-struct buffer data type (expected end, got '%%s')", ts);
goto fail; goto fail;
} }
} else {
if (buf->itemsize != sizeof(%(dtype_cname)s)) {
PyErr_SetString(PyExc_ValueError,
"Attempted cast of buffer to datatype of different size.");
goto fail;
}
}
if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones; if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
return 0; return 0;
fail:; fail:;
......
...@@ -200,17 +200,21 @@ class BufferType(BaseType): ...@@ -200,17 +200,21 @@ class BufferType(BaseType):
# dtype PyrexType # dtype PyrexType
# ndim int # ndim int
# mode str # mode str
# is_buffer boolean # negative_indices bool
# writable boolean # cast bool
# is_buffer bool
# writable bool
is_buffer = 1 is_buffer = 1
writable = True writable = True
def __init__(self, base, dtype, ndim, mode): def __init__(self, base, dtype, ndim, mode, negative_indices, cast):
self.base = base self.base = base
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype) self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode self.mode = mode
self.negative_indices = negative_indices
self.cast = cast
def as_argument_type(self): def as_argument_type(self):
return self return self
......
...@@ -473,6 +473,25 @@ def list_comprehension(object[int] buf, len): ...@@ -473,6 +473,25 @@ def list_comprehension(object[int] buf, len):
cdef int i cdef int i
print u"|".join([unicode(buf[i]) for i in range(len)]) print u"|".join([unicode(buf[i]) for i in range(len)])
#
# The negative_indices buffer option
#
@testcase
def no_negative_indices(object[int, negative_indices=False] buf, int idx):
"""
The most interesting thing here is to inspect the C source and
make sure optimal code is produced.
>>> A = IntMockBuffer(None, range(6))
>>> no_negative_indices(A, 3)
3
>>> no_negative_indices(A, -1)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
"""
return buf[idx]
# #
# Buffer type mismatch examples. Varying the type and access # Buffer type mismatch examples. Varying the type and access
# method simultaneously, the odds of an interaction is virtually # method simultaneously, the odds of an interaction is virtually
...@@ -635,7 +654,7 @@ def safe_get(object[int] buf, int idx): ...@@ -635,7 +654,7 @@ def safe_get(object[int] buf, int idx):
return buf[idx] return buf[idx]
@testcase @testcase
@cython.boundscheck(False) @cython.boundscheck(False) # outer decorators should take precedence
@cython.boundscheck(True) @cython.boundscheck(True)
def unsafe_get(object[int] buf, int idx): def unsafe_get(object[int] buf, int idx):
""" """
...@@ -650,6 +669,18 @@ def unsafe_get(object[int] buf, int idx): ...@@ -650,6 +669,18 @@ def unsafe_get(object[int] buf, int idx):
""" """
return buf[idx] return buf[idx]
@testcase
@cython.boundscheck(False)
def unsafe_get_nonegative(object[int, negative_indices=False] buf, int idx):
"""
Also inspect the C source to see that it is optimal...
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
>>> unsafe_get_nonegative(A, -2)
3
"""
return buf[idx]
@testcase @testcase
def mixed_get(object[int] buf, int unsafe_idx, int safe_idx): def mixed_get(object[int] buf, int unsafe_idx, int safe_idx):
""" """
...@@ -912,7 +943,32 @@ def assign_to_object(object[object] buf, int idx, obj): ...@@ -912,7 +943,32 @@ def assign_to_object(object[object] buf, int idx, obj):
""" """
buf[idx] = obj buf[idx] = obj
#
# cast option
#
@testcase
def buffer_cast(object[unsigned int, cast=True] buf, int idx):
"""
Round-trip a signed int through unsigned int buffer access.
>>> A = IntMockBuffer(None, [-100])
>>> buffer_cast(A, 0)
-100
"""
cdef unsigned int data = buf[idx]
return <int>data
@testcase
def buffer_cast_fails(object[char, cast=True] buf):
"""
Cannot cast between datatype of different sizes.
>>> buffer_cast_fails(IntMockBuffer(None, [0]))
Traceback (most recent call last):
...
ValueError: Attempted cast of buffer to datatype of different size.
"""
return buf[0]
# #
...@@ -1070,6 +1126,13 @@ cdef class MockBuffer: ...@@ -1070,6 +1126,13 @@ cdef class MockBuffer:
cdef get_default_format(self): cdef get_default_format(self):
print "ERROR, not subclassed", self.__class__ print "ERROR, not subclassed", self.__class__
cdef class CharMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<char*>buf)[0] = <int>value
return 0
cdef get_itemsize(self): return sizeof(char)
cdef get_default_format(self): return b"@b"
cdef class IntMockBuffer(MockBuffer): cdef class IntMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1: cdef int write(self, char* buf, object value) except -1:
(<int*>buf)[0] = <int>value (<int*>buf)[0] = <int>value
...@@ -1077,6 +1140,13 @@ cdef class IntMockBuffer(MockBuffer): ...@@ -1077,6 +1140,13 @@ cdef class IntMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(int) cdef get_itemsize(self): return sizeof(int)
cdef get_default_format(self): return b"@i" cdef get_default_format(self): return b"@i"
cdef class UnsignedIntMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<unsigned int*>buf)[0] = <unsigned int>value
return 0
cdef get_itemsize(self): return sizeof(unsigned int)
cdef get_default_format(self): return b"@I"
cdef class ShortMockBuffer(MockBuffer): cdef class ShortMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1: cdef int write(self, char* buf, object value) except -1:
(<short*>buf)[0] = <short>value (<short*>buf)[0] = <short>value
......
...@@ -138,6 +138,13 @@ try: ...@@ -138,6 +138,13 @@ try:
... ...
ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20) ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20)
>>> test_good_cast()
True
>>> test_bad_cast()
Traceback (most recent call last):
...
ValueError: Attempted cast of buffer to datatype of different size.
""" """
except: except:
__doc__ = "" __doc__ = ""
...@@ -225,3 +232,15 @@ def test_dtype(dtype, inc1): ...@@ -225,3 +232,15 @@ def test_dtype(dtype, inc1):
a = np.array([0, 10], dtype=dtype) a = np.array([0, 10], dtype=dtype)
inc1(a) inc1(a)
if a[1] != 11: print "failed!" if a[1] != 11: print "failed!"
def test_good_cast():
# Check that a signed int can round-trip through casted unsigned int access
cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')
cdef unsigned int data = arr[0]
return -100 == <int>data
def test_bad_cast():
# This should raise an exception
cdef np.ndarray[long, cast=True] arr = np.array([1], dtype='b')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment