Commit 4fce6db0 authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

cleanup in MemoryViewSliceType

parent b297ae99
...@@ -302,7 +302,42 @@ no_fail: ...@@ -302,7 +302,42 @@ no_fail:
} }
''' '''
def get_copy_contents_code(from_mvs, to_mvs, cfunc_name): def memoryviewslice_get_copy_func(from_memview, to_memview, mode, scope):
from PyrexTypes import CFuncType, CFuncTypeArg
if mode == 'c':
cython_name = "copy"
copy_name = '__Pyx_BufferNew_C_From_'+from_memview.specialization_suffix()
contig_flag = 'PyBUF_C_CONTIGUOUS'
elif mode == 'fortran':
cython_name = "copy_fortran"
copy_name = "__Pyx_BufferNew_F_From_"+from_memview.specialization_suffix()
contig_flag = 'PyBUF_F_CONTIGUOUS'
else:
assert False
copy_contents_name = get_copy_contents_name(from_memview, to_memview)
scope.declare_cfunction(cython_name,
CFuncType(from_memview,
[CFuncTypeArg("memviewslice", from_memview, None)]),
pos = None,
defining = 1,
cname = copy_name)
copy_impl = copy_template % dict(
copy_name=copy_name,
mode=mode,
sizeof_dtype="sizeof(%s)" % from_memview.dtype.declaration_code(''),
contig_flag=contig_flag,
copy_contents_name=copy_contents_name)
copy_decl = ("static __Pyx_memviewslice "
"%s(const __Pyx_memviewslice); /* proto */\n" % (copy_name,))
return (copy_decl, copy_impl)
def get_copy_contents_func(from_mvs, to_mvs, cfunc_name):
assert from_mvs.dtype == to_mvs.dtype assert from_mvs.dtype == to_mvs.dtype
assert len(from_mvs.axes) == len(to_mvs.axes) assert len(from_mvs.axes) == len(to_mvs.axes)
...@@ -313,7 +348,10 @@ def get_copy_contents_code(from_mvs, to_mvs, cfunc_name): ...@@ -313,7 +348,10 @@ def get_copy_contents_code(from_mvs, to_mvs, cfunc_name):
if access != 'direct': if access != 'direct':
raise NotImplementedError("only direct access supported currently.") raise NotImplementedError("only direct access supported currently.")
code = ''' code_decl = ("static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs,"
"__Pyx_memviewslice *to_mvs); /* proto */" % {'cfunc_name' : cfunc_name})
code_impl = '''
static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice *to_mvs) { static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice *to_mvs) {
...@@ -338,44 +376,44 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice ...@@ -338,44 +376,44 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice
# 'i' always goes up from zero to ndim-1. # 'i' always goes up from zero to ndim-1.
# 'idx' is the same as 'i' for c_contig, and goes from ndim-1 to 0 for f_contig. # 'idx' is the same as 'i' for c_contig, and goes from ndim-1 to 0 for f_contig.
# this makes the loop code below identical in both cases. # this makes the loop code below identical in both cases.
code += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i) code_impl += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i)
code += INDENT+"Py_ssize_t stride%(i)d = from_mvs->diminfo[%(idx)d].strides;\n" % {'i':i, 'idx':idx} code_impl += INDENT+"Py_ssize_t stride%(i)d = from_mvs->diminfo[%(idx)d].strides;\n" % {'i':i, 'idx':idx}
code += INDENT+"Py_ssize_t shape%(i)d = from_mvs->diminfo[%(idx)d].shape;\n" % {'i':i, 'idx':idx} code_impl += INDENT+"Py_ssize_t shape%(i)d = from_mvs->diminfo[%(idx)d].shape;\n" % {'i':i, 'idx':idx}
code += "\n" code_impl += "\n"
# put down the nested for-loop. # put down the nested for-loop.
for k in range(ndim): for k in range(ndim):
code += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k} code_impl += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k}
if k >= 1: if k >= 1:
code += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1} code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1}
else: else:
code += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k} code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k}
# the inner part of the loop. # the inner part of the loop.
dtype_decl = from_mvs.dtype.declaration_code("") dtype_decl = from_mvs.dtype.declaration_code("")
last_idx = ndim-1 last_idx = ndim-1
code += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals() code_impl += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals()
code += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals() code_impl += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals()
# for-loop closing braces # for-loop closing braces
for k in range(ndim-1, -1, -1): for k in range(ndim-1, -1, -1):
code += INDENT*(k+1)+"}\n" code_impl += INDENT*(k+1)+"}\n"
# init to_mvs->data and to_mvs->diminfo. # init to_mvs->data and to_mvs->diminfo.
code += INDENT+"temp_memview = to_mvs->memview;\n" code_impl += INDENT+"temp_memview = to_mvs->memview;\n"
code += INDENT+"temp_data = to_mvs->data;\n" code_impl += INDENT+"temp_data = to_mvs->data;\n"
code += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n" code_impl += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n"
code += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,) code_impl += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,)
code += INDENT*2+"return -1;\n" code_impl += INDENT*2+"return -1;\n"
code += INDENT+"}\n" code_impl += INDENT+"}\n"
code += INDENT + "return 0;\n" code_impl += INDENT + "return 0;\n"
code += '}\n' code_impl += '}\n'
return code return code_decl, code_impl
def get_axes_specs(env, axes): def get_axes_specs(env, axes):
''' '''
......
...@@ -406,71 +406,42 @@ class MemoryViewSliceType(PyrexType): ...@@ -406,71 +406,42 @@ class MemoryViewSliceType(PyrexType):
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c, self.env) to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c, self.env)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f, self.env) to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f, self.env)
cython_name_c = 'copy' copy_contents_name_c =\
cython_name_f = 'copy_fortran' MemoryView.get_copy_contents_name(self, to_memview_c)
copy_contents_name_f =\
copy_name_c = '__Pyx_BufferNew_C_From_'+self.specialization_suffix() MemoryView.get_copy_contents_name(self, to_memview_f)
copy_name_f = '__Pyx_BufferNew_F_From_'+self.specialization_suffix()
c_copy_decl, c_copy_impl = \
c_copy_util_code = UtilityCode() MemoryView.memoryviewslice_get_copy_func(self, to_memview_c, 'c', self.scope)
f_copy_util_code = UtilityCode() f_copy_decl, f_copy_impl = \
MemoryView.memoryviewslice_get_copy_func(self, to_memview_f, 'fortran', self.scope)
for (to_memview, copy_name, cython_name, mode, contig_flag, util_code) in (
(to_memview_c, copy_name_c, cython_name_c, 'c', 'PyBUF_C_CONTIGUOUS', c_copy_util_code), c_copy_contents_decl, c_copy_contents_impl = \
(to_memview_f, copy_name_f, cython_name_f, 'fortran', 'PyBUF_F_CONTIGUOUS', f_copy_util_code)): MemoryView.get_copy_contents_func(
self, to_memview_c, copy_contents_name_c)
copy_contents_name = MemoryView.get_copy_contents_name(self, to_memview) f_copy_contents_decl, f_copy_contents_impl = \
MemoryView.get_copy_contents_func(
scope.declare_cfunction(cython_name, self, to_memview_f, copy_contents_name_f)
CFuncType(self,
[CFuncTypeArg("memviewslice", self, None)]), c_util_code = UtilityCode(
pos = None, proto = "%s%s" % (c_copy_decl, c_copy_contents_decl),
defining = 1, impl = "%s%s" % (c_copy_impl, c_copy_contents_impl))
cname = copy_name) f_util_code = UtilityCode(
proto = f_copy_decl,
copy_impl = MemoryView.copy_template %\ impl = f_copy_impl)
dict(copy_name=copy_name,
mode=mode,
sizeof_dtype="sizeof(%s)" % self.dtype.declaration_code(''),
contig_flag=contig_flag,
copy_contents_name=copy_contents_name)
copy_decl = '''\
static __Pyx_memviewslice %s(const __Pyx_memviewslice); /* proto */
''' % (copy_name,)
util_code.proto = copy_decl
util_code.impl = copy_impl
copy_contents_name_c = MemoryView.get_copy_contents_name(self, to_memview_c)
copy_contents_name_f = MemoryView.get_copy_contents_name(self, to_memview_f)
c_copy_util_code.proto += ('static int %s'
'(const __Pyx_memviewslice *,'
' __Pyx_memviewslice *); /* proto */\n' %
(copy_contents_name_c,))
c_copy_util_code.impl += \
MemoryView.get_copy_contents_code(self, to_memview_c, copy_contents_name_c)
if copy_contents_name_c != copy_contents_name_f: if copy_contents_name_c != copy_contents_name_f:
f_util_code.proto += f_copy_contents_decl
f_copy_util_code.proto += ('static int %s' f_util_code.impl += f_copy_contents_impl
'(const __Pyx_memviewslice *,'
' __Pyx_memviewslice *); /* proto */\n' %
(copy_contents_name_f,))
f_copy_util_code.impl += \
MemoryView.get_copy_contents_code(self, to_memview_f, copy_contents_name_f)
c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_copy_util_code.proto == util_code.proto] c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_util_code.proto == util_code.proto]
f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_copy_util_code.proto == util_code.proto] f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_util_code.proto == util_code.proto]
if not c_copy_used: if not c_copy_used:
self.env.use_utility_code(c_copy_util_code) self.env.use_utility_code(c_util_code)
if not f_copy_used: if not f_copy_used:
self.env.use_utility_code(f_copy_util_code) self.env.use_utility_code(f_util_code)
# is_c_contiguous and is_f_contiguous functions # is_c_contiguous and is_f_contiguous functions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment