Commit 9b19cce7 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffers: Better error messages + bugfix

parent 59cfae38
...@@ -520,8 +520,7 @@ def create_typestringchecker(protocode, defcode, name, dtype): ...@@ -520,8 +520,7 @@ def create_typestringchecker(protocode, defcode, name, dtype):
def put_assert(cond, msg): def put_assert(cond, msg):
defcode.putln("if (!(%s)) {" % cond) defcode.putln("if (!(%s)) {" % cond)
msg += ", got '%s'" defcode.putln('PyErr_Format(PyExc_ValueError, "Buffer dtype mismatch (%s)", __Pyx_DescribeTokenInFormatString(ts));' % msg)
defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg)
defcode.putln("return NULL;") defcode.putln("return NULL;")
defcode.putln("}") defcode.putln("}")
...@@ -583,14 +582,7 @@ def create_typestringchecker(protocode, defcode, name, dtype): ...@@ -583,14 +582,7 @@ def create_typestringchecker(protocode, defcode, name, dtype):
(char, ctype, against, ctype)) (char, ctype, against, ctype))
defcode.putln("default: ok = 0;") defcode.putln("default: ok = 0;")
defcode.putln("}") defcode.putln("}")
defcode.putln("if (!ok) {") put_assert("ok", "expected %s, got %%s" % dtype)
if dtype.typestring is not None:
errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring
else:
errmsg = "Buffer datatype mismatch (rejecting on '%s')"
defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg)
defcode.putln("return NULL;");
defcode.putln("}")
defcode.putln("++ts;") defcode.putln("++ts;")
elif complex_possible: elif complex_possible:
# Could be a struct representing a complex number, so allow # Could be a struct representing a complex number, so allow
...@@ -605,9 +597,7 @@ def create_typestringchecker(protocode, defcode, name, dtype): ...@@ -605,9 +597,7 @@ def create_typestringchecker(protocode, defcode, name, dtype):
real_t.declaration_code(""), real_t.declaration_code(""),
imag_t.declaration_code(""))) imag_t.declaration_code("")))
defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % ( defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
dtype.declaration_code("", for_display=True), dtype, real_t, imag_t))
real_t.declaration_code("", for_display=True),
imag_t.declaration_code("", for_display=True)))
defcode.putln("return NULL;") defcode.putln("return NULL;")
defcode.putln("}") defcode.putln("}")
check_real, check_imag = [x[2] for x in field_blocks] check_real, check_imag = [x[2] for x in field_blocks]
...@@ -624,21 +614,23 @@ def create_typestringchecker(protocode, defcode, name, dtype): ...@@ -624,21 +614,23 @@ def create_typestringchecker(protocode, defcode, name, dtype):
defcode.putln("int n, count;") defcode.putln("int n, count;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
for n, type, checker in field_blocks: next_types = [x[1] for x in field_blocks[1:]] + ["end"]
for (n, type, checker), next_type in zip(field_blocks, next_types):
if n == 1: if n == 1:
defcode.putln("if (*ts == '1') ++ts;") defcode.putln("if (*ts == '1') ++ts;")
else: else:
defcode.putln("n = %d;" % n); defcode.putln("n = %d;" % n);
defcode.putln("do {") defcode.putln("do {")
defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;") defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
put_assert("n >= 0", "expected %s, got %%s" % next_type)
simple = type.is_simple_buffer_dtype() simple = type.is_simple_buffer_dtype()
if not simple: if not simple:
put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True)) put_assert("*ts == 'T' && *(ts+1) == '{'", "expected %s, got %%s" % type)
defcode.putln("ts += 2;") defcode.putln("ts += 2;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
if not simple: if not simple:
put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True)) put_assert("*ts == '}'", "expected end of %s struct, got %%s" % type)
defcode.putln("++ts;") defcode.putln("++ts;")
if n > 1: if n > 1:
...@@ -689,7 +681,8 @@ def get_getbuffer_code(dtype, code): ...@@ -689,7 +681,8 @@ def get_getbuffer_code(dtype, code):
if (!ts) goto fail; if (!ts) goto fail;
if (*ts != 0) { if (*ts != 0) {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Buffer format string specifies more data than '%(dtype_name)s' can hold (expected end, got '%%s')", ts); "Buffer dtype mismatch (expected end, got %%s)",
__Pyx_DescribeTokenInFormatString(ts));
goto fail; goto fail;
} }
} else { } else {
...@@ -822,6 +815,7 @@ static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info); ...@@ -822,6 +815,7 @@ static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/ static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/ static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static const char* __Pyx_DescribeTokenInFormatString(const char* ts); /*proto*/
""", """ """, """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) { static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
if (info->buf == NULL) return; if (info->buf == NULL) return;
...@@ -864,6 +858,34 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) { ...@@ -864,6 +858,34 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
expected_ndim, buffer->ndim); expected_ndim, buffer->ndim);
} }
static const char* __Pyx_DescribeTokenInFormatString(const char* ts) {
switch (*ts) {
case 'b': return "char";
case 'B': return "unsigned char";
case 'h': return "short";
case 'H': return "unsigned short";
case 'i': return "int";
case 'I': return "unsigned int";
case 'l': return "long";
case 'L': return "unsigned long";
case 'q': return "long long";
case 'Q': return "unsigned long long";
case 'f': return "float";
case 'd': return "double";
case 'g': return "long double";
case 'Z': switch (*(ts+1)) {
case 'f': return "complex float";
case 'd': return "complex double";
case 'g': return "complex long double";
default: return "unparseable format string";
}
case 'T': return "a struct";
case 'O': return "Python object";
case 'P': return "a pointer";
default: return "unparseable format string";
}
}
"""] """]
......
...@@ -373,14 +373,14 @@ def alignment_string(object[int] buf): ...@@ -373,14 +373,14 @@ def alignment_string(object[int] buf):
@testcase @testcase
def wrong_string(object[int] buf): def wrong_string(object[int] buf):
""" """
>>> wrong_string(IntMockBuffer(None, [1,2], format="iasdf")) >>> wrong_string(IntMockBuffer(None, [1,2], format="if"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer format string specifies more data than 'int' can hold (expected end, got 'asdf') ValueError: Buffer dtype mismatch (expected end, got float)
>>> wrong_string(IntMockBuffer(None, [1,2], format="$$")) >>> wrong_string(IntMockBuffer(None, [1,2], format="$$"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (expected 'i', got '$$') ValueError: Buffer dtype mismatch (expected int, got unparseable format string)
""" """
print buf[1] print buf[1]
...@@ -532,7 +532,7 @@ def fmtst1(buf): ...@@ -532,7 +532,7 @@ def fmtst1(buf):
>>> fmtst1(IntMockBuffer("A", range(3))) >>> fmtst1(IntMockBuffer("A", range(3)))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (expected 'f', got 'i') ValueError: Buffer dtype mismatch (expected float, got int)
""" """
cdef object[float] a = buf cdef object[float] a = buf
...@@ -542,7 +542,7 @@ def fmtst2(object[int] buf): ...@@ -542,7 +542,7 @@ def fmtst2(object[int] buf):
>>> fmtst2(FloatMockBuffer("A", range(3))) >>> fmtst2(FloatMockBuffer("A", range(3)))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (expected 'i', got 'f') ValueError: Buffer dtype mismatch (expected int, got float)
""" """
@testcase @testcase
...@@ -849,7 +849,7 @@ def printbuf_td_cy_int(object[td_cy_int] buf, shape): ...@@ -849,7 +849,7 @@ def printbuf_td_cy_int(object[td_cy_int] buf, shape):
>>> printbuf_td_cy_int(ShortMockBuffer(None, range(3)), (3,)) >>> printbuf_td_cy_int(ShortMockBuffer(None, range(3)), (3,))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (rejecting on 'h') ValueError: Buffer dtype mismatch (expected bufaccess.td_cy_int, got short)
""" """
cdef int i cdef int i
...@@ -865,7 +865,7 @@ def printbuf_td_h_short(object[td_h_short] buf, shape): ...@@ -865,7 +865,7 @@ def printbuf_td_h_short(object[td_h_short] buf, shape):
>>> printbuf_td_h_short(IntMockBuffer(None, range(3)), (3,)) >>> printbuf_td_h_short(IntMockBuffer(None, range(3)), (3,))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (rejecting on 'i') ValueError: Buffer dtype mismatch (expected bufaccess.td_h_short, got int)
""" """
cdef int i cdef int i
for i in range(shape[0]): for i in range(shape[0]):
...@@ -880,7 +880,7 @@ def printbuf_td_h_cy_short(object[td_h_cy_short] buf, shape): ...@@ -880,7 +880,7 @@ def printbuf_td_h_cy_short(object[td_h_cy_short] buf, shape):
>>> printbuf_td_h_cy_short(IntMockBuffer(None, range(3)), (3,)) >>> printbuf_td_h_cy_short(IntMockBuffer(None, range(3)), (3,))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (rejecting on 'i') ValueError: Buffer dtype mismatch (expected bufaccess.td_h_cy_short, got int)
""" """
cdef int i cdef int i
for i in range(shape[0]): for i in range(shape[0]):
...@@ -895,7 +895,7 @@ def printbuf_td_h_ushort(object[td_h_ushort] buf, shape): ...@@ -895,7 +895,7 @@ def printbuf_td_h_ushort(object[td_h_ushort] buf, shape):
>>> printbuf_td_h_ushort(ShortMockBuffer(None, range(3)), (3,)) >>> printbuf_td_h_ushort(ShortMockBuffer(None, range(3)), (3,))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (rejecting on 'h') ValueError: Buffer dtype mismatch (expected bufaccess.td_h_ushort, got short)
""" """
cdef int i cdef int i
for i in range(shape[0]): for i in range(shape[0]):
...@@ -910,7 +910,7 @@ def printbuf_td_h_double(object[td_h_double] buf, shape): ...@@ -910,7 +910,7 @@ def printbuf_td_h_double(object[td_h_double] buf, shape):
>>> printbuf_td_h_double(FloatMockBuffer(None, [0.25, 1, 3.125]), (3,)) >>> printbuf_td_h_double(FloatMockBuffer(None, [0.25, 1, 3.125]), (3,))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (rejecting on 'f') ValueError: Buffer dtype mismatch (expected bufaccess.td_h_double, got float)
""" """
cdef int i cdef int i
for i in range(shape[0]): for i in range(shape[0]):
...@@ -1328,10 +1328,14 @@ def basic_struct(object[MyStruct] buf): ...@@ -1328,10 +1328,14 @@ def basic_struct(object[MyStruct] buf):
1 2 3 4 5 1 2 3 4 5
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="bbqii")) >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="bbqii"))
1 2 3 4 5 1 2 3 4 5
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="23bqii"))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch (expected long long, got char)
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="i")) >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="i"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (expected 'b', got 'i') ValueError: Buffer dtype mismatch (expected char, got int)
""" """
print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
...@@ -1345,7 +1349,11 @@ def nested_struct(object[NestedStruct] buf): ...@@ -1345,7 +1349,11 @@ def nested_struct(object[NestedStruct] buf):
>>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii")) >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Expected start of SmallStruct, got 'iiiii' ValueError: Buffer dtype mismatch (expected SmallStruct, got int)
>>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{iii}T{ii}i"))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch (expected end of SmallStruct struct, got int)
""" """
print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z
......
...@@ -144,7 +144,7 @@ try: ...@@ -144,7 +144,7 @@ try:
]))) ])))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}') ValueError: Buffer dtype mismatch (expected int, got float)
>>> test_good_cast() >>> test_good_cast()
True True
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment