Commit 72d54fb4 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

PS: non-working state. Buffer access able to run fully in some very restricted cases

parent 6f0bc35a
...@@ -1275,36 +1275,59 @@ class IndexNode(ExprNode): ...@@ -1275,36 +1275,59 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0): def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
self.index.analyse_types(env)
if self.base.type.is_pyobject: if self.base.type.buffer_options is not None:
if self.index.type.is_int: if isinstance(self.index, TupleNode):
self.original_index_type = self.index.type indices = self.index.args
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) # is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
if getting:
env.use_utility_code(getitem_int_utility_code)
if setting:
env.use_utility_code(setitem_int_utility_code)
else: else:
self.index = self.index.coerce_to_pyobject(env) # is_int_indices = self.index.type.is_int
self.type = py_object_type indices = [self.index]
self.gil_check(env) all_ints = True
self.is_temp = 1 for index in indices:
else: index.analyse_types(env)
if self.base.type.is_ptr or self.base.type.is_array: if not index.type.is_int:
self.type = self.base.type.base_type all_ints = False
if all_ints:
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.is_buffer_access = True
if not self.is_buffer_access:
self.index.analyse_types(env) # ok to analyse as tuple
if self.base.type.is_pyobject:
if self.index.type.is_int:
self.original_index_type = self.index.type
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
if getting:
env.use_utility_code(getitem_int_utility_code)
if setting:
env.use_utility_code(setitem_int_utility_code)
else:
self.index = self.index.coerce_to_pyobject(env)
self.type = py_object_type
self.gil_check(env)
self.is_temp = 1
else: else:
error(self.pos, if self.base.type.is_ptr or self.base.type.is_array:
"Attempting to index non-array type '%s'" % self.type = self.base.type.base_type
self.base.type) else:
self.type = PyrexTypes.error_type error(self.pos,
if self.index.type.is_pyobject: "Attempting to index non-array type '%s'" %
self.index = self.index.coerce_to( self.base.type)
PyrexTypes.c_py_ssize_t_type, env) self.type = PyrexTypes.error_type
if not self.index.type.is_int: if self.index.type.is_pyobject:
error(self.pos, self.index = self.index.coerce_to(
"Invalid index type '%s'" % PyrexTypes.c_py_ssize_t_type, env)
self.index.type) if not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
...@@ -1330,11 +1353,17 @@ class IndexNode(ExprNode): ...@@ -1330,11 +1353,17 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code) self.base.generate_evaluation_code(code)
self.index.generate_evaluation_code(code) if self.index is not None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
self.index.generate_disposal_code(code) if self.index is not None:
self.index.generate_disposal_code(code)
else:
for i in self.indices: i.generate_disposal_code(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_pyobject: if self.type.is_pyobject:
......
...@@ -354,7 +354,7 @@ def create_generate_code(context, options, result): ...@@ -354,7 +354,7 @@ def create_generate_code(context, options, result):
return generate_code return generate_code
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
...@@ -367,6 +367,7 @@ def create_default_pipeline(context, options, result): ...@@ -367,6 +367,7 @@ def create_default_pipeline(context, options, result):
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
BufferTransform(context),
# CreateClosureClasses(context), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
...@@ -259,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -259,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_module_cleanup_func(env, code) self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code) self.generate_filename_table(code)
self.generate_utility_functions(env, code) self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h) self.generate_declarations_for_modules(env, modules, code.h)
...@@ -438,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -438,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)") code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif") code.putln("#endif")
code.put(builtin_module_name_utility_code[0]) code.put(builtin_module_name_utility_code[0])
...@@ -1945,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1945,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.h.put(utility_code[0]) code.h.put(utility_code[0])
code.put(utility_code[1]) code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
view->buf = arr->data;
view->readonly = 0; /*fixme*/
view->format = "B"; /*fixme*/
view->ndim = arr->nd;
view->strides = arr->strides;
view->shape = arr->dimensions;
view->suboffsets = 0;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
}
""")
# For now, hard-code numpy imported as "numpy"
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types = [
(ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
]
# typeptr_cname = ndarrtype.typeptr_cname
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
...@@ -137,12 +138,177 @@ class PostParse(CythonTransform): ...@@ -137,12 +138,177 @@ class PostParse(CythonTransform):
if ndim_value < 0: if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim') raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value) node.ndim = int(ndimnode.value)
else:
node.ndim = 1
# We're done with the parse tree args # We're done with the parse tree args
node.positional_args = None node.positional_args = None
node.keyword_args = None node.keyword_args = None
return node return node
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
"""
scope = None
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
# Variable has buffer opts, declare auxiliary vars
bufopts = entry.type.buffer_options
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.buffer_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
# ass = SingleAssignmentNode(pos=node.pos,
# lhs=NameNode(node.pos, name=entry.name),
# rhs=IndexNode(node.pos,
# base=AttributeNode(node.pos,
# obj=NameNode(node.pos, name=bufaux.buffer_info_var.name),
# attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump()
def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
bufaux = node.lhs.entry.buffer_aux
if bufaux is not None:
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
# stats += [node] # Do assignment after successful buffer acquisition
# print acq.dump()
return stats
else:
return node
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
print to_sum
indices = node.indices
# reduce * on indices
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
})
tmp.analyse_expressions(self.scope)
return tmp.stats[0].expr
else:
return node
def visit_CallNode(self, node):
### print node.dump()
return node
# def visit_FuncDefNode(self, node):
# print node.dump()
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
...@@ -6,6 +6,22 @@ from Cython import Utils ...@@ -6,6 +6,22 @@ from Cython import Utils
import Naming import Naming
import copy import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType: class BaseType:
# #
# Base class for all Pyrex types including pseudo-types. # Base class for all Pyrex types including pseudo-types.
...@@ -93,6 +109,7 @@ class PyrexType(BaseType): ...@@ -93,6 +109,7 @@ class PyrexType(BaseType):
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -184,21 +201,6 @@ class CTypedefType(BaseType): ...@@ -184,21 +201,6 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
...@@ -208,7 +210,6 @@ class PyObjectType(PyrexType): ...@@ -208,7 +210,6 @@ class PyObjectType(PyrexType):
default_value = "0" default_value = "0"
parsetuple_format = "O" parsetuple_format = "O"
pymemberdef_typecode = "T_OBJECT" pymemberdef_typecode = "T_OBJECT"
buffer_options = None # can contain a BufferOptions instance
def __str__(self): def __str__(self):
return "Python object" return "Python object"
......
...@@ -19,6 +19,14 @@ import __builtin__ ...@@ -19,6 +19,14 @@ import __builtin__
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
def __init__(self, buffer_info_var, stridevars, shapevars):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
self.shapevars = shapevars
def __repr__(self):
return "<BufferAux %r>" % self.__dict__
class Entry: class Entry:
# A symbol table entry in a Scope or ModuleNamespace. # A symbol table entry in a Scope or ModuleNamespace.
# #
...@@ -76,6 +84,8 @@ class Entry: ...@@ -76,6 +84,8 @@ class Entry:
# defined_in_pxd boolean Is defined in a .pxd file (not just declared) # defined_in_pxd boolean Is defined in a .pxd file (not just declared)
# api boolean Generate C API for C class or function # api boolean Generate C API for C class or function
# utility_code string Utility code needed when this entry is used # utility_code string Utility code needed when this entry is used
#
# buffer_aux BufferAux or None Extra information needed for buffer variables
borrowed = 0 borrowed = 0
init = "" init = ""
...@@ -117,6 +127,7 @@ class Entry: ...@@ -117,6 +127,7 @@ class Entry:
api = 0 api = 0
utility_code = None utility_code = None
is_overridable = 0 is_overridable = 0
buffer_aux = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
......
cdef extern from "Python.h":
ctypedef struct PyObject
ctypedef struct Py_buffer:
void *buf
Py_ssize_t len
int readonly
char *format
int ndim
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
Py_ssize_t itemsize
void *internal
int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
# int flags)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment