Commit bd9d0283 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffers: Initial support for structs. Inplace operators broken.

parent 111f3d42
...@@ -385,7 +385,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod ...@@ -385,7 +385,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
funcgen = buf_lookup_strided_code funcgen = buf_lookup_strided_code
# Make sure the utility code is available # Make sure the utility code is available
code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd) code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)
ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params)) ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params))
return entry.type.buffer_ptr_type.cast_code(ptrcode) return entry.type.buffer_ptr_type.cast_code(ptrcode)
...@@ -446,14 +446,14 @@ def mangle_dtype_name(dtype): ...@@ -446,14 +446,14 @@ def mangle_dtype_name(dtype):
def get_ts_check_item(dtype, writer): def get_ts_check_item(dtype, writer):
# See if we can consume one (unnamed) dtype as next item # See if we can consume one (unnamed) dtype as next item
# Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...) # Put native and custom types in seperate namespaces (as one could create a type named unsigned_int...)
name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype) name = "__Pyx_CheckTypestringItem_%s" % mangle_dtype_name(dtype)
if not writer.globalstate.has_utility_code(name): if not writer.globalstate.has_code(name):
char = dtype.typestring char = dtype.typestring
if char is not None: if char is not None:
assert len(char) == 1
# Can use direct comparison # Can use direct comparison
code = dedent("""\ code = dedent("""\
if (*ts == '1') ++ts;
if (*ts != '%s') { if (*ts != '%s') {
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts); PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts);
return NULL; return NULL;
...@@ -465,7 +465,6 @@ def get_ts_check_item(dtype, writer): ...@@ -465,7 +465,6 @@ def get_ts_check_item(dtype, writer):
ctype = dtype.declaration_code("") ctype = dtype.declaration_code("")
code = dedent("""\ code = dedent("""\
int ok; int ok;
if (*ts == '1') ++ts;
switch (*ts) {""", 2) switch (*ts) {""", 2)
if dtype.is_int: if dtype.is_int:
types = [ types = [
...@@ -475,8 +474,7 @@ def get_ts_check_item(dtype, writer): ...@@ -475,8 +474,7 @@ def get_ts_check_item(dtype, writer):
elif dtype.is_float: elif dtype.is_float:
types = [('f', 'float'), ('d', 'double'), ('g', 'long double')] types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
else: else:
assert dtype.is_error assert False
return name
if dtype.signed == 0: if dtype.signed == 0:
code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" % code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
(char.upper(), ctype, against, ctype) for char, against in types]) (char.upper(), ctype, against, ctype) for char, against in types])
...@@ -503,6 +501,51 @@ def get_ts_check_item(dtype, writer): ...@@ -503,6 +501,51 @@ def get_ts_check_item(dtype, writer):
return name return name
def create_typestringchecker(protocode, defcode, name, dtype):
if dtype.is_error: return
simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
# Cannot add utility code recursively...
if simple:
itemchecker = get_ts_check_item(dtype, protocode)
else:
protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
fields = dtype.scope.var_entries
field_checkers = [get_ts_check_item(x.type, protocode) for x in fields]
protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
defcode.putln("static const char* %s(const char* ts) {" % name)
if simple:
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("if (*ts == '1') ++ts;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
else:
defcode.putln("int repeat; char type;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
if complex_possible:
# Could be a struct representing a complex number, so allow
# for parsing a "Zf" spec.
real_t, imag_t = [x.type.declaration_code("") for x in fields]
defcode.putln("if (*ts == 'Z' && sizeof(%s) == sizeof(%s)) {" % (real_t, imag_t))
defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % field_checkers[0])
defcode.putln("} else {")
defcode.putln('PyErr_SetString(PyExc_ValueError, "Struct buffer dtypes not implemented yet!");')
defcode.putln('return NULL;')
# Code for parsing as a struct.
# for field, checker in zip(fields, field_checkers):
# defcode.put(dedent("""\
# if (repeat == 0) {
# ts = __Pyx_ParseTypestringRepeat(ts, &repeat); if (!ts) return NULL;
# ts = %s(ts); if (!ts) return NULL;
# }
# """) % checker)
if complex_possible:
defcode.putln("}")
defcode.putln("return ts;")
defcode.putln("}")
def get_getbuffer_code(dtype, code): def get_getbuffer_code(dtype, code):
""" """
Generate a utility function for getting a buffer for the given dtype. Generate a utility function for getting a buffer for the given dtype.
...@@ -514,9 +557,15 @@ def get_getbuffer_code(dtype, code): ...@@ -514,9 +557,15 @@ def get_getbuffer_code(dtype, code):
""" """
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype) name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not code.globalstate.has_utility_code(name): if not code.globalstate.has_code(name):
code.globalstate.use_utility_code(acquire_utility_code) code.globalstate.use_utility_code(acquire_utility_code)
itemchecker = get_ts_check_item(dtype, code)
typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
code.globalstate.use_code_from(create_typestringchecker,
typestringchecker,
dtype=dtype)
dtype_name = str(dtype)
utilcode = [dedent(""" utilcode = [dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/ static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/
""") % name, dedent(""" """) % name, dedent("""
...@@ -533,15 +582,11 @@ def get_getbuffer_code(dtype, code): ...@@ -533,15 +582,11 @@ def get_getbuffer_code(dtype, code):
goto fail; goto fail;
} }
ts = buf->format; ts = buf->format;
ts = %(typestringchecker)s(ts); if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts); ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail;
ts = %(itemchecker)s(ts);
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail;
if (*ts != 0) { if (*ts != 0) {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Expected non-struct buffer data type (expected end, got '%%s')", ts); "Buffer format string specifies more data than '%(dtype_name)s' can hold (expected end, got '%%s')", ts);
goto fail; goto fail;
} }
if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones; if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
...@@ -711,6 +756,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) { ...@@ -711,6 +756,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
"""] """]
parse_typestring_repeat_code = ["""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""","""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
int count;
if (*ts < '0' || *ts > '9') {
count = 1;
} else {
count = *ts++ - '0';
while (*ts >= '0' && *ts < '9') {
count *= 10;
count += *ts++ - '0';
}
}
*out_count = count;
return ts;
}
"""]
raise_buffer_fallback_code = [""" raise_buffer_fallback_code = ["""
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/ static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""",""" ""","""
......
...@@ -168,6 +168,7 @@ class GlobalState(object): ...@@ -168,6 +168,7 @@ class GlobalState(object):
self.used_utility_code = set() self.used_utility_code = set()
self.declared_cnames = {} self.declared_cnames = {}
self.pystring_table_needed = False self.pystring_table_needed = False
self.in_utility_code_generation = False
def initwriters(self, rootwriter): def initwriters(self, rootwriter):
self.utilprotowriter = rootwriter.new_writer() self.utilprotowriter = rootwriter.new_writer()
...@@ -344,10 +345,10 @@ class GlobalState(object): ...@@ -344,10 +345,10 @@ class GlobalState(object):
self.utilprotowriter.put(proto) self.utilprotowriter.put(proto)
self.utildefwriter.put(_def) self.utildefwriter.put(_def)
def has_utility_code(self, name): def has_code(self, name):
return name in self.used_utility_code return name in self.used_utility_code
def use_generated_code(self, func, name, *args, **kw): def use_code_from(self, func, name, *args, **kw):
""" """
Requests that the utility code that func can generate is used in the C Requests that the utility code that func can generate is used in the C
file. func is called like this: file. func is called like this:
......
...@@ -1412,6 +1412,7 @@ class IndexNode(ExprNode): ...@@ -1412,6 +1412,7 @@ class IndexNode(ExprNode):
# we only need a temp because result_code isn't refactored to # we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take # generation time, but this seems an ok shortcut to take
self.is_temp = True self.is_temp = True
self.result_ctype = PyrexTypes.c_ptr_type(self.type)
if setting: if setting:
if not self.base.entry.type.writable: if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer") error(self.pos, "Writing to readonly buffer")
......
...@@ -99,6 +99,7 @@ class PyrexType(BaseType): ...@@ -99,6 +99,7 @@ class PyrexType(BaseType):
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
typestring = None
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -138,7 +139,6 @@ class PyrexType(BaseType): ...@@ -138,7 +139,6 @@ class PyrexType(BaseType):
# a struct whose attributes are not defined, etc. # a struct whose attributes are not defined, etc.
return 1 return 1
class CTypedefType(BaseType): class CTypedefType(BaseType):
# #
# Pseudo-type defined with a ctypedef statement in a # Pseudo-type defined with a ctypedef statement in a
...@@ -955,6 +955,11 @@ class CStructOrUnionType(CType): ...@@ -955,6 +955,11 @@ class CStructOrUnionType(CType):
def attributes_known(self): def attributes_known(self):
return self.is_complete() return self.is_complete()
def can_be_complex(self):
# Does the struct consist of exactly two floats?
fields = self.scope.var_entries
return len(fields) == 2 and fields[0].type.is_float and fields[1].type.is_float
class CEnumType(CType): class CEnumType(CType):
# name string # name string
......
...@@ -55,20 +55,23 @@ cdef extern from "numpy/arrayobject.h": ...@@ -55,20 +55,23 @@ cdef extern from "numpy/arrayobject.h":
# made available from this pxd file yet. # made available from this pxd file yet.
cdef int t = PyArray_TYPE(self) cdef int t = PyArray_TYPE(self)
cdef char* f = NULL cdef char* f = NULL
if t == NPY_BYTE: f = "b" if t == NPY_BYTE: f = "b"
elif t == NPY_UBYTE: f = "B" elif t == NPY_UBYTE: f = "B"
elif t == NPY_SHORT: f = "h" elif t == NPY_SHORT: f = "h"
elif t == NPY_USHORT: f = "H" elif t == NPY_USHORT: f = "H"
elif t == NPY_INT: f = "i" elif t == NPY_INT: f = "i"
elif t == NPY_UINT: f = "I" elif t == NPY_UINT: f = "I"
elif t == NPY_LONG: f = "l" elif t == NPY_LONG: f = "l"
elif t == NPY_ULONG: f = "L" elif t == NPY_ULONG: f = "L"
elif t == NPY_LONGLONG: f = "q" elif t == NPY_LONGLONG: f = "q"
elif t == NPY_ULONGLONG: f = "Q" elif t == NPY_ULONGLONG: f = "Q"
elif t == NPY_FLOAT: f = "f" elif t == NPY_FLOAT: f = "f"
elif t == NPY_DOUBLE: f = "d" elif t == NPY_DOUBLE: f = "d"
elif t == NPY_LONGDOUBLE: f = "g" elif t == NPY_LONGDOUBLE: f = "g"
elif t == NPY_OBJECT: f = "O" elif t == NPY_CFLOAT: f = "Zf"
elif t == NPY_CDOUBLE: f = "Zd"
elif t == NPY_CLONGDOUBLE: f = "Zg"
elif t == NPY_OBJECT: f = "O"
if f == NULL: if f == NULL:
raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t) raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
......
...@@ -358,6 +358,20 @@ def alignment_string(object[int] buf): ...@@ -358,6 +358,20 @@ def alignment_string(object[int] buf):
""" """
print buf[1] print buf[1]
@testcase
def wrong_string(object[int] buf):
"""
>>> wrong_string(IntMockBuffer(None, [1,2], format="iasdf"))
Traceback (most recent call last):
...
ValueError: Buffer format string specifies more data than 'int' can hold (expected end, got 'asdf')
>>> wrong_string(IntMockBuffer(None, [1,2], format="$$"))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (expected 'i', got '$$')
"""
print buf[1]
# #
# Getting items and index bounds checking # Getting items and index bounds checking
# #
...@@ -1056,7 +1070,6 @@ cdef class DoubleMockBuffer(MockBuffer): ...@@ -1056,7 +1070,6 @@ cdef class DoubleMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(double) cdef get_itemsize(self): return sizeof(double)
cdef get_default_format(self): return b"d" cdef get_default_format(self): return b"d"
cdef extern from *: cdef extern from *:
void* addr_of_pyobject "(void*)"(object) void* addr_of_pyobject "(void*)"(object)
...@@ -1135,3 +1148,69 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf): ...@@ -1135,3 +1148,69 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf):
pass pass
#
# Structs
#
cdef struct MyStruct:
char a
char b
long long int c
int d
int e
cdef class MyStructMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
cdef MyStruct* s
s = <MyStruct*>buf;
s.a, s.b, s.c, s.d, s.e = value
return 0
cdef get_itemsize(self): return sizeof(MyStruct)
cdef get_default_format(self): return b"2bq2i"
@testcase
def basic_struct(object[MyStruct] buf):
"""
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
Traceback (most recent call last):
...
ValueError: Struct buffer dtypes not implemented yet!
# 1 2 3 4 5
"""
print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
cdef struct LongComplex:
long double real
long double imag
cdef class LongComplexMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
cdef LongComplex* s
s = <LongComplex*>buf;
s.real, s.imag = value
return 0
cdef get_itemsize(self): return sizeof(LongComplex)
cdef get_default_format(self): return b"Zg"
@testcase
def complex_struct_dtype(object[LongComplex] buf):
"""
Note that the format string is "Zg" rather than "2g"...
>>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
0.0 -1.0
"""
print buf[0].real, buf[0].imag
@testcase
def complex_struct_inplace(object[LongComplex] buf):
"""
>>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
1.0 1.0
"""
buf[0].real += 1
buf[0].imag += 2
print buf[0].real, buf[0].imag
...@@ -91,6 +91,9 @@ try: ...@@ -91,6 +91,9 @@ try:
>>> test_dtype('d', inc1_double) >>> test_dtype('d', inc1_double)
>>> test_dtype('g', inc1_longdouble) >>> test_dtype('g', inc1_longdouble)
>>> test_dtype('O', inc1_object) >>> test_dtype('O', inc1_object)
>>> test_dtype('F', inc1_cfloat) # numpy format codes differ from buffer ones here
>>> test_dtype('D', inc1_cdouble)
>>> test_dtype('G', inc1_clongdouble)
>>> test_dtype(np.int, inc1_int_t) >>> test_dtype(np.int, inc1_int_t)
>>> test_dtype(np.long, inc1_long_t) >>> test_dtype(np.long, inc1_long_t)
...@@ -103,11 +106,6 @@ try: ...@@ -103,11 +106,6 @@ try:
>>> test_dtype(np.float64, inc1_float64_t) >>> test_dtype(np.float64, inc1_float64_t)
Unsupported types: Unsupported types:
>>> test_dtype(np.complex, inc1_byte)
Traceback (most recent call last):
...
ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 15)
>>> a = np.zeros((10,), dtype=np.dtype('i4,i4')) >>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
>>> inc1_byte(a) >>> inc1_byte(a)
Traceback (most recent call last): Traceback (most recent call last):
...@@ -154,7 +152,19 @@ def put_range_long_1d(np.ndarray[long] arr): ...@@ -154,7 +152,19 @@ def put_range_long_1d(np.ndarray[long] arr):
value += 1 value += 1
# Exhaustive dtype tests -- increments element [1] by 1 for all dtypes cdef struct cfloat:
float real
float imag
cdef struct cdouble:
double real
double imag
cdef struct clongdouble:
long double real
long double imag
# Exhaustive dtype tests -- increments element [1] by 1 (or 1+1j) for all dtypes
def inc1_byte(np.ndarray[char] arr): arr[1] += 1 def inc1_byte(np.ndarray[char] arr): arr[1] += 1
def inc1_ubyte(np.ndarray[unsigned char] arr): arr[1] += 1 def inc1_ubyte(np.ndarray[unsigned char] arr): arr[1] += 1
def inc1_short(np.ndarray[short] arr): arr[1] += 1 def inc1_short(np.ndarray[short] arr): arr[1] += 1
...@@ -170,6 +180,23 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1 ...@@ -170,6 +180,23 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1
def inc1_double(np.ndarray[double] arr): arr[1] += 1 def inc1_double(np.ndarray[double] arr): arr[1] += 1
def inc1_longdouble(np.ndarray[long double] arr): arr[1] += 1 def inc1_longdouble(np.ndarray[long double] arr): arr[1] += 1
def inc1_cfloat(np.ndarray[cfloat] arr):
arr[1].real += 1
arr[1].imag += 1
def inc1_cdouble(np.ndarray[cdouble] arr):
arr[1].real += 1
arr[1].imag += 1
def inc1_clongdouble(np.ndarray[clongdouble] arr):
print arr[1].real
print arr[1].imag
cdef long double x
x = arr[1].real + 1
arr[1].real = x
arr[1].imag = arr[1].imag + 1
print arr[1].real
print arr[1].imag
def inc1_object(np.ndarray[object] arr): def inc1_object(np.ndarray[object] arr):
o = arr[1] o = arr[1]
...@@ -189,6 +216,11 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1 ...@@ -189,6 +216,11 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1
def test_dtype(dtype, inc1): def test_dtype(dtype, inc1):
a = np.array([0, 10], dtype=dtype) if dtype in ('F', 'D', 'G'):
inc1(a) a = np.array([0, 10+10j], dtype=dtype)
if a[1] != 11: print "failed!" inc1(a)
if a[1] != (11 + 11j): print "failed!", a[1]
else:
a = np.array([0, 10], dtype=dtype)
inc1(a)
if a[1] != 11: print "failed!"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment