MemoryView.py 34.7 KB
Newer Older
1
from Errors import CompileError, error
2
import ExprNodes
3
from ExprNodes import IntNode, NameNode, AttributeNode
4
import Options
5
from Code import UtilityCode, TempitaUtilityCode
6
from UtilityCode import CythonUtilityCode
7
import Buffer
8
import PyrexTypes
9

# Error messages raised while analysing the axis specifications of a
# memoryview type such as ``int[:, ::1]``.
START_ERR = "Start must not be given."
STOP_ERR = "Axis specification only allowed in the 'step' slot."
STEP_ERR = "Step must be omitted, 1, or a valid specifier."
BOTH_CF_ERR = "Cannot specify an array that is both C and Fortran contiguous."
INVALID_ERR = "Invalid axis specification."
NOT_CIMPORTED_ERR = "Variable was not cimported from cython.view"
EXPR_ERR = "no expressions allowed in axis spec, only names and literals."
CF_ERR = "Invalid axis specification for a C/Fortran contiguous array."

# Interpolated with the memoryview's name (the ``%s``).
ERR_UNINITIALIZED = (
    "Cannot check if memoryview %s is initialized without the GIL, "
    "consider using initializedcheck(False)"
)

def err_if_nogil_initialized_check(pos, env, name='variable'):
    """
    Report a compile error at ``pos`` when ``env`` is a nogil context with the
    ``initializedcheck`` directive enabled, since the initialisation check of
    memoryview ``name`` cannot be performed there.
    """
    needs_check_without_gil = env.nogil and env.directives['initializedcheck']
    if not needs_check_without_gil:
        return
    error(pos, ERR_UNINITIALIZED % name)

def concat_flags(*flags):
    """Join C flag expressions with ``|`` and wrap the whole in parentheses."""
    joined = "|".join(flags)
    return "(" + joined + ")"

format_flag = "PyBUF_FORMAT"

# PyBUF_* request flags (as C expressions) used when acquiring the buffer
# behind a memoryview slice.
memview_c_contiguous = "(PyBUF_C_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
memview_f_contiguous = "(PyBUF_F_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
memview_any_contiguous = "(PyBUF_ANY_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
memview_full_access = "PyBUF_FULL"
#memview_strided_access = "PyBUF_STRIDED"
memview_strided_access = "PyBUF_RECORDS"

# C macro names for the individual access / packing axis specifiers.
MEMVIEW_DIRECT = '__Pyx_MEMVIEW_DIRECT'
MEMVIEW_PTR    = '__Pyx_MEMVIEW_PTR'
MEMVIEW_FULL   = '__Pyx_MEMVIEW_FULL'
MEMVIEW_CONTIG = '__Pyx_MEMVIEW_CONTIG'
MEMVIEW_STRIDED= '__Pyx_MEMVIEW_STRIDED'
MEMVIEW_FOLLOW = '__Pyx_MEMVIEW_FOLLOW'

# Axis specifier name -> C macro name.
_spec_to_const = dict(
    direct=MEMVIEW_DIRECT,
    ptr=MEMVIEW_PTR,
    full=MEMVIEW_FULL,
    contig=MEMVIEW_CONTIG,
    strided=MEMVIEW_STRIDED,
    follow=MEMVIEW_FOLLOW,
)

# Axis specifier name -> one-character abbreviation.
_spec_to_abbrev = dict(
    direct='d',
    ptr='p',
    full='f',
    contig='c',
    strided='s',
    follow='_',
)

# C brace initializer with every field zeroed.
memslice_entry_init = "{ 0, 0, { 0 }, { 0 }, { 0 } }"

# Python-level and C-level names of the memoryview object / slice types.
memview_name = u'memoryview'
memview_typeptr_cname = '__pyx_memoryview_type'
memview_objstruct_cname = '__pyx_memoryview_obj'
memviewslice_cname = u'__Pyx_memviewslice'

def put_init_entry(mv_cname, code):
    """Emit C statements setting the ``data`` and ``memview`` fields of ``mv_cname`` to NULL."""
    for field in ("data", "memview"):
        code.putln("%s.%s = NULL;" % (mv_cname, field))

def mangle_dtype_name(dtype):
    """Return the mangled name of ``dtype`` as computed by ``Buffer.mangle_dtype_name``."""
    # Thin delegating wrapper; the implementation may move into this module later.
    import Buffer as buffer_module
    return buffer_module.mangle_dtype_name(dtype)

#def axes_to_str(axes):
#    return "".join([access[0].upper()+packing[0] for (access, packing) in axes])

def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
                                incref_rhs=False, have_gil=False):
    """
    Emit code assigning the memoryview-slice expression ``rhs`` to ``lhs_cname``.

    A NameNode or an rhs already held in a temp is referenced directly.
    Otherwise the rhs result is first copied into a temporary, which is
    released once the assignment has been emitted.
    """
    assert rhs.type.is_memoryviewslice

    needs_temp = not (isinstance(rhs, NameNode) or rhs.result_in_temp())
    if needs_temp:
        rhstmp = code.funcstate.allocate_temp(lhs_type, manage_ref=False)
        code.putln("%s = %s;" % (rhstmp, rhs.result_as(lhs_type)))
    else:
        rhstmp = rhs.result()

    # Allow uninitialized assignment
    #code.putln(code.put_error_if_unbound(lhs_pos, rhs.entry))
    put_assign_to_memviewslice(lhs_cname, rhstmp, lhs_type, code, incref_rhs,
                               have_gil=have_gil)

    if needs_temp:
        code.funcstate.release_temp(rhstmp)

def put_assign_to_memviewslice(lhs_cname, rhs_cname, memviewslicetype, code,
                               incref_rhs=False, have_gil=False):
    """
    Emit code replacing the slice ``lhs_cname`` with ``rhs_cname``.

    The old lhs reference is released (xdecref). The rhs is optionally
    incref'd, and then the whole slice struct is copied with a single C
    assignment. ``memviewslicetype`` is accepted but not used.
    """
    assignment = "%s = %s;" % (lhs_cname, rhs_cname)

    code.put_xdecref_memoryviewslice(lhs_cname, have_gil=have_gil)
    if incref_rhs:
        code.put_incref_memoryviewslice(rhs_cname, have_gil=have_gil)

    # One struct assignment copies every member (memview, data, shape,
    # strides, suboffsets) instead of copying them field by field.
    code.putln(assignment)

def get_buf_flags(specs):
    """
    Choose the PyBUF_* request flags for a memoryview with axis ``specs``.

    C- or Fortran-contiguous layouts get the matching contiguous flags. Any
    'full' or 'ptr' axis requires full (indirect) access. Everything else is
    a strided request.
    """
    is_c_contig, is_f_contig = is_cf_contig(specs)
    if is_c_contig:
        return memview_c_contiguous
    if is_f_contig:
        return memview_f_contiguous

    access, _unused_packing = zip(*specs)
    if 'full' not in access and 'ptr' not in access:
        return memview_strided_access
    return memview_full_access

def src_conforms_to_dst(src, dst):
    '''
    Return True if memoryview type ``src`` can be used where ``dst`` is expected.

    Both must have the same dtype and number of axes, and every axis spec
    of src must conform to the corresponding axis spec of dst:

    * an access or packing spec always conforms to itself;
    * 'direct' and 'ptr' access conform to 'full' access;
    * 'contig' and 'follow' packing conform to 'strided' packing.

    Any other combination does not conform.
    '''
    if src.dtype != dst.dtype or len(src.axes) != len(dst.axes):
        return False

    for (src_access, src_packing), (dst_access, dst_packing) in zip(src.axes, dst.axes):
        access_ok = src_access == dst_access or dst_access == 'full'
        packing_ok = src_packing == dst_packing or dst_packing == 'strided'
        if not (access_ok and packing_ok):
            return False

    return True

def valid_memslice_dtype(dtype):
    """
    Return whether ``dtype`` may be used as the base type of a memoryview slice.

    Accepted are: structs whose members are all themselves valid, numeric
    types (except complex types with an integer real part), Python objects,
    error types, fused types (which are later replaced by their
    specialisations) and typedefs of any accepted type.
    """
    if dtype.is_complex and dtype.real_type.is_int:
        return False

    if dtype.is_struct and dtype.kind == 'struct':
        return all(valid_memslice_dtype(member.type)
                   for member in dtype.scope.var_entries)

    # Pointers are not valid (yet):
    #     (dtype.is_ptr and valid_memslice_dtype(dtype.base_type))
    return (dtype.is_error or dtype.is_numeric or dtype.is_pyobject or
            dtype.is_fused or
            (dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type)))

def validate_memslice_dtype(pos, dtype):
    """Report a compile error at ``pos`` if ``dtype`` cannot back a memoryview slice."""
    if valid_memslice_dtype(dtype):
        return
    error(pos, "Invalid base type for memoryview slice: %s" % dtype)

class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
    """
    Buffer entry for a memoryview-slice variable.

    Refers to the fields of the slice's C struct (``<cname>.data``,
    ``.shape[i]``, ``.strides[i]``, ``.suboffsets[i]``) and generates the C
    expressions used to index and slice it.
    """
    def __init__(self, entry):
        self.entry = entry
        self.type = entry.type
        self.cname = entry.cname
        # C expression for the slice's data pointer.
        self.buf_ptr = "%s.data" % self.cname

        dtype = self.entry.type.dtype
        dtype = PyrexTypes.CPtrType(dtype)

        # Pointer-to-element type matching buf_ptr.
        self.buf_ptr_type = dtype

    # The get_buf_*vars methods delegate to the base class's _for_all_ndim,
    # presumably yielding one C expression per dimension -- not visible here.
    def get_buf_suboffsetvars(self):
        return self._for_all_ndim("%s.suboffsets[%d]")

    def get_buf_stridevars(self):
        return self._for_all_ndim("%s.strides[%d]")

    def get_buf_shapevars(self):
        return self._for_all_ndim("%s.shape[%d]")

    def generate_buffer_lookup_code(self, code, index_cnames):
        """
        Return a C expression for a pointer to the element addressed by
        ``index_cnames`` (one C index expression per dimension).
        """
        axes = [(dim, index_cnames[dim], access, packing)
                    for dim, (access, packing) in enumerate(self.type.axes)]
        return self._generate_buffer_lookup_code(code, axes)

    def _generate_buffer_lookup_code(self, code, axes, cast_result=True):
        """
        Build the element-address expression dimension by dimension.

        axes        - iterable of (dim, index_cname, access, packing)
        cast_result - if true, cast the final expression to a pointer to the
                      dtype; otherwise return the raw address expression.
        """
        bufp = self.buf_ptr
        type_decl = self.type.dtype.declaration_code("")

        for dim, index, access, packing in axes:
            # NOTE(review): ``shape`` is computed but not used below.
            shape = "%s.shape[%d]" % (self.cname, dim)
            stride = "%s.strides[%d]" % (self.cname, dim)
            suboffset = "%s.suboffsets[%d]" % (self.cname, dim)

            flag = get_memoryview_flag(access, packing)

            if flag in ("generic", "generic_contiguous"):
                # Note: we cannot do cast tricks to avoid stride multiplication
                #       for generic_contiguous, as we may have to do (dtype *)
                #       or (dtype **) arithmetic, we won't know which unless
                #       we check suboffsets
                # memviewslice_index_helpers is not defined in this chunk;
                # presumably a module-level utility-code object.
                code.globalstate.use_utility_code(memviewslice_index_helpers)
                bufp = ('__pyx_memviewslice_index_full(%s, %s, %s, %s)' %
                                            (bufp, index, stride, suboffset))

            elif flag == "indirect":
                # Step by stride, then dereference as char ** and add the suboffset.
                bufp = "(%s + %s * %s)" % (bufp, index, stride)
                bufp = ("(*((char **) %s) + %s)" % (bufp, suboffset))

            elif flag == "indirect_contiguous":
                # Note: we do char ** arithmetic
                bufp = "(*((char **) %s + %s) + %s)" % (bufp, index, suboffset)

            elif flag == "strided":
                bufp = "(%s + %s * %s)" % (bufp, index, stride)

            else:
                assert flag == 'contiguous', flag
                # Contiguous axis: index via dtype-pointer arithmetic, no stride.
                bufp = '((char *) (((%s *) %s) + %s))' % (type_decl, bufp, index)

            bufp = '( /* dim=%d */ %s )' % (dim, bufp)

        if cast_result:
            return "((%s *) %s)" % (type_decl, bufp)

        return bufp

    def generate_buffer_slice_code(self, code, indices, dst, have_gil):
        """
        Slice a memoryviewslice.

        indices     - list of index nodes. If not a SliceNode, then it must be
                      coercible to Py_ssize_t

        Simply call __pyx_memoryview_slice_memviewslice with the right
        arguments.
        """
        slicefunc = "__pyx_memoryview_slice_memviewslice"
        new_ndim = 0
        cname = self.cname

        suboffset_dim = code.funcstate.allocate_temp(PyrexTypes.c_int_type,
                                                     False)

        # Both templates are filled from ``locals()`` (or a dict updated with
        # it) inside the loop below, so the local variable names here are
        # part of the template contract.
        index_code = ("%(slicefunc)s(&%(cname)s, &%(dst)s, %(have_gil)d, "
                                    "%(dim)d, %(new_ndim)d, &%(suboffset_dim)s, "
                                    "%(idx)s, 0, 0, 0, 0, 0, 0)")

        slice_code = ("%(slicefunc)s(&%(cname)s, &%(dst)s, %(have_gil)d, "
                                    "/* dim */ %(dim)d, "
                                    "/* new_ndim */ %(new_ndim)d, "
                                    "/* suboffset_dim */ &%(suboffset_dim)s, "
                                    "/* start */ %(start)s, "
                                    "/* stop */ %(stop)s, "
                                    "/* step */ %(step)s, "
                                    "/* have_start */ %(have_start)d, "
                                    "/* have_stop */ %(have_stop)d, "
                                    "/* have_step */ %(have_step)d, "
                                    "/* is_slice */ 1)")

        def generate_slice_call(expr):
            # ``index`` and ``dim`` are read from the enclosing for-loop at
            # call time (closure), so this must only be called inside it.
            pos = index.pos

            if have_gil:
                code.putln(code.error_goto_if(expr, pos))
            else:
                # Without the GIL: a non-NULL result is used as the format
                # string of an IndexError (with the dimension number), raised
                # after temporarily acquiring the GIL.
                code.putln("{")
                code.putln(    "const char *__pyx_t_result = %s;" % expr)

                code.putln(    "if (unlikely(__pyx_t_result)) {")
                code.put_ensure_gil()
                code.putln(        "PyErr_Format(PyExc_IndexError, "
                                                "__pyx_t_result, %d);" % dim)
                code.put_release_ensured_gil()
                code.putln(code.error_goto(pos))
                code.putln(    "}")

                code.putln("}")

        code.putln("%s = -1;" % suboffset_dim)
        code.putln("%(dst)s.data = %(cname)s.data;" % locals())
        code.putln("%(dst)s.memview = %(cname)s.memview;" % locals())
        code.put_incref_memoryviewslice(dst)

        for dim, index in enumerate(indices):
            if not isinstance(index, ExprNodes.SliceNode):
                # Integer index: the dimension is dropped (new_ndim unchanged).
                idx = index.result()
                generate_slice_call(index_code % locals())
            else:
                d = {}
                for s in "start stop step".split():
                    idx = getattr(index, s)
                    have_idx = d['have_' + s] = not idx.is_none
                    if have_idx:
                        d[s] = idx.result()
                    else:
                        d[s] = "0"

                # Relies on no local being named start/stop/step/have_*, so
                # the values collected above are not overwritten here.
                d.update(locals())
                generate_slice_call(slice_code % d)
                new_ndim += 1

        code.funcstate.release_temp(suboffset_dim)

def empty_slice(pos):
    """Return a SliceNode at ``pos`` equivalent to ``:`` (no start, stop or step)."""
    # A single NoneNode is shared by all three slots.
    none_node = ExprNodes.NoneNode(pos)
    return ExprNodes.SliceNode(pos, start=none_node, stop=none_node, step=none_node)

def unellipsify(indices, ndim):
    """
    Expand the index nodes ``indices`` for an ``ndim``-dimensional view.

    The first Ellipsis is replaced by as many full slices as needed to reach
    ``ndim`` entries; any later Ellipsis becomes a single full slice. If the
    result is still shorter than ``ndim``, full slices are appended.

    Returns ``(have_slices, result)``; ``have_slices`` is true when the
    expanded index contains an Ellipsis, a SliceNode or appended slices.
    """
    result = []
    seen_ellipsis = False
    have_slices = False

    for index in indices:
        if not isinstance(index, ExprNodes.EllipsisNode):
            if isinstance(index, ExprNodes.SliceNode):
                have_slices = True
            result.append(index)
            continue

        have_slices = True
        full_slice = empty_slice(index.pos)
        if seen_ellipsis:
            result.append(full_slice)
        else:
            # The first ellipsis covers every dimension not indexed explicitly.
            result.extend([full_slice] * (ndim - len(indices) + 1))
            seen_ellipsis = True

    missing = ndim - len(result)
    if missing > 0:
        have_slices = True
        result.extend([empty_slice(indices[-1].pos)] * missing)

    return have_slices, result

def get_memoryview_flag(access, packing):
    """
    Map an (access, packing) axis spec to its indexing flag name.

    'strided' and 'follow' packing are treated alike. Any pair other than
    the known combinations trips an assertion; ('direct', 'contig') maps to
    'contiguous'.
    """
    non_contiguous = {'full': 'generic', 'ptr': 'indirect', 'direct': 'strided'}
    contiguous = {'full': 'generic_contiguous', 'ptr': 'indirect_contiguous'}

    if packing in ('strided', 'follow') and access in non_contiguous:
        return non_contiguous[access]
    if packing == 'contig' and access in contiguous:
        return contiguous[access]

    assert (access, packing) == ('direct', 'contig'), (access, packing)
    return 'contiguous'

def get_copy_func_name(to_memview):
    """Return the C name of the function copying a slice into ``to_memview``'s layout."""
    if to_memview.is_c_contig:
        order = 'C'
    else:
        order = 'F'
    suffix = to_memview.specialization_suffix()
    return "__Pyx_BufferNew_%s_From_%s" % (order, suffix)

def get_copy_contents_name(from_mvs, to_mvs):
    """Return the C name of the function copying the contents of ``from_mvs`` into ``to_mvs``."""
    assert from_mvs.dtype == to_mvs.dtype
    suffixes = (from_mvs.specialization_suffix(), to_mvs.specialization_suffix())
    return '__Pyx_BufferCopyContents_%s_to_%s' % suffixes

class IsContigFuncUtilCode(object):
    """
    Utility code emitting a C function that tests whether a memoryview slice
    is contiguous in the order given by ``c_or_f`` ('c' or 'fortran').

    Instances compare equal and hash by the generated function's name.
    """

    requires = None

    def __init__(self, c_or_f):
        self.c_or_f = c_or_f

        self.is_contig_func_name = get_is_contig_func_name(self.c_or_f)

    def __eq__(self, other):
        if not isinstance(other, IsContigFuncUtilCode):
            return False
        return self.is_contig_func_name == other.is_contig_func_name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent.
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.is_contig_func_name)

    def get_tree(self): pass

    def put_code(self, output):
        """Write the prototype and definition into ``output``'s proto/def sections."""
        code = output['utility_code_def']
        proto = output['utility_code_proto']

        func_decl, func_impl = get_is_contiguous_func(self.c_or_f)

        proto.put(func_decl)
        code.put(func_impl)

def get_is_contig_func_name(c_or_f):
    """Return the C name of the contiguity check for order ``c_or_f`` ('c' or 'fortran')."""
    name_template = "__Pyx_Buffer_is_%s_contiguous"
    return name_template % c_or_f

def get_is_contiguous_func(c_or_f):
    """
    Return ``(declaration, implementation)`` C source for a function that
    checks whether a memoryview slice is contiguous in the given order.

    c_or_f  - 'c' (axes checked from last to first) or 'fortran' (axes
              checked from first to last).

    The generated function returns 0 if any axis has a non-negative
    suboffset, or if an axis's stride differs from itemsize times the
    product of the shapes of the axes already checked; otherwise it
    returns 1.

    Raises ValueError for any other ``c_or_f``.
    """
    # Validate with an explicit exception rather than ``assert`` so the check
    # survives ``python -O`` (which would otherwise leave for_loop unbound).
    if c_or_f == 'fortran':
        for_loop = "for(i=0; i<ndim; i++)"
    elif c_or_f == 'c':
        for_loop = "for(i=ndim-1; i>-1; i--)"
    else:
        raise ValueError("c_or_f must be 'c' or 'fortran', got %r" % (c_or_f,))

    func_name = get_is_contig_func_name(c_or_f)
    decl = "static int %s(const __Pyx_memviewslice); /* proto */\n" % func_name

    impl = """
static int %s(const __Pyx_memviewslice mvs) {
    /* returns 1 if mvs is the right contiguity, 0 otherwise */

    int i, ndim = mvs.memview->view.ndim;
    Py_ssize_t itemsize = mvs.memview->view.itemsize;
    long size = 0;
""" % func_name

    # NOTE(review): the ``#undef DEBUG`` below removes any user-defined DEBUG
    # macro for the rest of the translation unit -- confirm it is intended.
    impl += """
    size = 1;
    %(for_loop)s {

#ifdef DEBUG
        printf("mvs.suboffsets[i] %%d\\n", mvs.suboffsets[i]);
        printf("mvs.strides[i] %%d\\n", mvs.strides[i]);
        printf("mvs.shape[i] %%d\\n", mvs.shape[i]);
        printf("size %%d\\n", size);
        printf("ndim %%d\\n", ndim);
#endif
#undef DEBUG

        if(mvs.suboffsets[i] >= 0) {
            return 0;
        }
        if(size * itemsize != mvs.strides[i]) {
            return 0;
        }
        size *= mvs.shape[i];
    }
    return 1;

}""" % {'for_loop' : for_loop}

    return decl, impl

copy_to_template = '''
static int %(copy_to_name)s(const __Pyx_memviewslice from_mvs, __Pyx_memviewslice to_mvs) {

    /* ensure from_mvs & to_mvs have the same shape & dtype */

}
'''

485
class CopyContentsFuncUtilCode(object):
    """
    Utility code emitting the C function (see get_copy_contents_func) that
    copies the contents of ``from_memview`` into ``to_memview``.

    Instances compare equal and hash by the generated function's name.
    """

    requires = None

    def __init__(self, from_memview, to_memview):
        self.from_memview = from_memview
        self.to_memview = to_memview
        self.copy_contents_name = get_copy_contents_name(from_memview, to_memview)

    def __eq__(self, other):
        if not isinstance(other, CopyContentsFuncUtilCode):
            return False
        return other.copy_contents_name == self.copy_contents_name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent.
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.copy_contents_name)

    def get_tree(self): pass

    def put_code(self, output):
        """Write the prototype and definition into ``output``'s proto/def sections."""
        code = output['utility_code_def']
        proto = output['utility_code_proto']

        func_decl, func_impl = \
                get_copy_contents_func(self.from_memview, self.to_memview, self.copy_contents_name)

        proto.put(func_decl)
        code.put(func_impl)

class CopyFuncUtilCode(object):
    """
    Utility code emitting a C function (named by get_copy_func_name) that
    takes a slice and returns a new ``__Pyx_memviewslice`` in
    ``to_memview``'s C- or Fortran-contiguous layout.

    Requires the matching CopyContentsFuncUtilCode. Instances compare equal
    and hash by the generated function's name.
    """

    requires = None

    def __init__(self, from_memview, to_memview):
        """
        Raises ValueError if the dtypes or numbers of dimensions differ, or if
        ``to_memview`` is neither C nor Fortran contiguous. Raises
        NotImplementedError if ``from_memview`` has a non-'direct' axis.
        """
        if from_memview.dtype != to_memview.dtype:
            raise ValueError("dtypes must be the same!")
        if len(from_memview.axes) != len(to_memview.axes):
            raise ValueError("number of dimensions must be same")
        if not (to_memview.is_c_contig or to_memview.is_f_contig):
            raise ValueError("to_memview must be c or f contiguous.")
        for (access, packing) in from_memview.axes:
            if access != 'direct':
                raise NotImplementedError("cannot handle 'full' or 'ptr' access at this time.")

        self.from_memview = from_memview
        self.to_memview = to_memview
        self.copy_func_name = get_copy_func_name(to_memview)

        self.requires = [CopyContentsFuncUtilCode(from_memview, to_memview)]

    def __eq__(self, other):
        if not isinstance(other, CopyFuncUtilCode):
            return False
        return other.copy_func_name == self.copy_func_name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent.
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.copy_func_name)

    def get_tree(self): pass

    def put_code(self, output):
        """Write the prototype, then the Tempita-rendered copy function, into ``output``."""
        code = output['utility_code_def']
        proto = output['utility_code_proto']

        proto.put(Buffer.dedent("""\
                static __Pyx_memviewslice %s(const __Pyx_memviewslice from_mvs); /* proto */
        """ % self.copy_func_name))

        copy_contents_name = get_copy_contents_name(self.from_memview, self.to_memview)

        # __init__ guarantees one of these branches is taken.
        if self.to_memview.is_c_contig:
            mode = 'c'
            contig_flag = memview_c_contiguous
        elif self.to_memview.is_f_contig:
            mode = 'fortran'
            contig_flag = memview_f_contiguous

        # NOTE(review): ``context`` is not defined in this block; presumably a
        # module-level template-context dict defined elsewhere in the file.
        C = dict(
            context,
            copy_name=self.copy_func_name,
            mode=mode,
            sizeof_dtype="sizeof(%s)" % self.from_memview.dtype.declaration_code(''),
            contig_flag=contig_flag,
            copy_contents_name=copy_contents_name
        )

        _, copy_code = TempitaUtilityCode.load_as_string(
                    "MemviewSliceCopyTemplate",
                    from_file="MemoryView_C.c",
                    context=C)
        code.put(copy_code)

def get_copy_contents_func(from_mvs, to_mvs, cfunc_name):
    """
    Generate C code copying the contents of one memoryview slice into another.

    Returns ``(declaration, implementation)`` for a C function ``cfunc_name``
    with signature ``(const __Pyx_memviewslice *from_mvs,
    __Pyx_memviewslice *to_mvs)``. The generated function:

    * returns -1 with a ValueError set if the shapes differ in any dimension;
    * copies every element;
    * re-initialises ``to_mvs`` from its memview via
      ``__Pyx_init_memviewslice`` (returning -1 if that fails);
    * returns 0 on success.

    When ``to_mvs`` is C or Fortran contiguous, elements are written
    sequentially with memcpy. Otherwise both sides are walked by stride.

    Raises NotImplementedError if any axis of ``from_mvs`` is not 'direct'.
    """
    assert from_mvs.dtype == to_mvs.dtype
    assert len(from_mvs.axes) == len(to_mvs.axes)

    ndim = len(from_mvs.axes)

    # XXX: we only support direct access for now.
    for (access, packing) in from_mvs.axes:
        if access != 'direct':
            raise NotImplementedError("currently only direct access is supported.")

    code_decl = ("static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs,"
                "__Pyx_memviewslice *to_mvs); /* proto */" % {'cfunc_name' : cfunc_name})

    INDENT = "    "
    dtype_decl = from_mvs.dtype.declaration_code("")
    last_idx = ndim-1
    cf_contig = to_mvs.is_c_contig or to_mvs.is_f_contig
    fmt = {'dtype_decl': dtype_decl, 'last_idx': last_idx}

    # C89 forbids declarations after statements within a block, so every
    # local -- including the per-dimension loop variables -- is declared
    # before the shape-check loop emitted further down.
    code_impl = '''

static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice *to_mvs) {

    char *to_buf = (char *)to_mvs->data;
    char *from_buf = (char *)from_mvs->data;
    struct __pyx_memoryview_obj *temp_memview = 0;
    int ndim_idx = 0;
''' % {'cfunc_name' : cfunc_name}

    if cf_contig:
        if to_mvs.is_c_contig:
            start, stop, step = 0, ndim, 1
        else:
            start, stop, step = ndim-1, -1, -1

        for i, idx in enumerate(range(start, stop, step)):
            # the crazy indexing is to account for the fortran indexing.
            # 'i' always goes up from zero to ndim-1.
            # 'idx' is the same as 'i' for c_contig, and goes from ndim-1 to 0 for f_contig.
            # this makes the loop code below identical in both cases.
            code_impl += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i)
            code_impl += INDENT+"Py_ssize_t stride%(i)d = from_mvs->strides[%(idx)d];\n" % {'i':i, 'idx':idx}
            code_impl += INDENT+"Py_ssize_t shape%(i)d = from_mvs->shape[%(idx)d];\n" % {'i':i, 'idx':idx}
    else:
        code_impl += INDENT+"/* 'f' prefix is for the 'from' memview, 't' prefix is for the 'to' memview */\n"
        for i in range(ndim):
            code_impl += INDENT+"char *fi%d = 0, *ti%d = 0, *end%d = 0;\n" % (i,i,i)
            code_impl += INDENT+"Py_ssize_t fstride%(i)d = from_mvs->strides[%(i)d];\n" % {'i':i}
            code_impl += INDENT+"Py_ssize_t fshape%(i)d = from_mvs->shape[%(i)d];\n" % {'i':i}
            code_impl += INDENT+"Py_ssize_t tstride%(i)d = to_mvs->strides[%(i)d];\n" % {'i':i}

    # First statement of the function body: the shape check.
    code_impl += '''
    for(ndim_idx=0; ndim_idx<%(ndim)d; ndim_idx++) {
        if(from_mvs->shape[ndim_idx] != to_mvs->shape[ndim_idx]) {
            PyErr_Format(PyExc_ValueError,
                "memoryview shapes not the same in dimension %%d", ndim_idx);
            return -1;
        }
    }

''' % {'ndim' : ndim}

    if cf_contig:
        # put down the nested for-loop; idx<k> accumulates the byte offset
        # into the source buffer for the current element.
        for k in range(ndim):
            code_impl += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k}
            if k >= 1:
                code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1}
            else:
                code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k}

        # the inner part of the loop: the destination is written sequentially.
        code_impl += INDENT*(ndim+1)+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % fmt
        code_impl += INDENT*(ndim+1)+"to_buf += sizeof(%(dtype_decl)s);\n" % fmt
    else:
        code_impl += INDENT+"end0 = fshape0 * fstride0 + from_mvs->data;\n"
        code_impl += INDENT+"for(fi0=from_buf, ti0=to_buf; fi0 < end0; fi0 += fstride0, ti0 += tstride0) {\n"
        for i in range(1, ndim):
            code_impl += INDENT*(i+1)+"end%(i)d = fshape%(i)d * fstride%(i)d + fi%(im1)d;\n" % {'i' : i, 'im1' : i-1}
            code_impl += INDENT*(i+1)+"for(fi%(i)d=fi%(im1)d, ti%(i)d=ti%(im1)d; fi%(i)d < end%(i)d; fi%(i)d += fstride%(i)d, ti%(i)d += tstride%(i)d) {\n" % {'i':i, 'im1':i-1}

        code_impl += INDENT*(ndim+1)+"*(%(dtype_decl)s*)(ti%(last_idx)d) = *(%(dtype_decl)s*)(fi%(last_idx)d);\n" % fmt

    # for-loop closing braces (both branches open exactly ndim loops)
    for k in range(ndim-1, -1, -1):
        code_impl += INDENT*(k+1)+"}\n"

    # init to_mvs->data and to_mvs shape/strides/suboffsets arrays.
    code_impl += INDENT+"temp_memview = to_mvs->memview;\n"
    code_impl += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n"
    code_impl += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,)
    code_impl += INDENT*2+"return -1;\n"
    code_impl +=   INDENT+"}\n"

    code_impl += INDENT + "return 0;\n"

    code_impl += '}\n'

    return code_decl, code_impl

def get_axes_specs(env, axes):
    '''
    get_axes_specs(env, axes) -> list of (access, packing) specs for each axis.

    access is one of 'full', 'ptr' or 'direct'
    packing is one of 'contig', 'strided' or 'follow'

    Raises CompileError when an axis gives a start or stop value, has an
    invalid step, when more than one axis is marked ``::1``, or when the
    contiguous / indirect axes are combined invalidly.
    '''
    cythonscope = env.global_scope().context.cython_scope
    cythonscope.load_cythonscope()
    viewscope = cythonscope.viewscope

    access_specs = tuple([viewscope.lookup(name)
                    for name in ('full', 'direct', 'ptr')])
    packing_specs = tuple([viewscope.lookup(name)
                    for name in ('contig', 'strided', 'follow')])

    is_f_contig, is_c_contig = False, False
    default_access, default_packing = 'direct', 'strided'
    cf_access, cf_packing = default_access, 'follow'

    axes_specs = []
    # analyse all axes; each axis appends exactly one spec (or raises), so
    # axes_specs stays index-aligned with axes.
    for idx, axis in enumerate(axes):
        if not axis.start.is_none:
            raise CompileError(axis.start.pos,  START_ERR)

        if not axis.stop.is_none:
            raise CompileError(axis.stop.pos, STOP_ERR)

        if axis.step.is_none:
            axes_specs.append((default_access, default_packing))

        elif isinstance(axis.step, IntNode):
            # the packing for the ::1 axis is contiguous,
            # all others are cf_packing.
            if axis.step.compile_time_value(env) != 1:
                raise CompileError(axis.step.pos, STEP_ERR)

            axes_specs.append((cf_access, 'cfcontig'))

        elif isinstance(axis.step, (NameNode, AttributeNode)):
            entry = _get_resolved_spec(env, axis.step)
            if entry.name in view_constant_to_access_packing:
                axes_specs.append(view_constant_to_access_packing[entry.name])
            else:
                raise CompileError(axis.step.pos, INVALID_ERR)

        else:
            raise CompileError(axis.step.pos, INVALID_ERR)

    # First, find out if we have a ::1 somewhere
    contig_dim = 0
    is_contig = False
    for idx, (access, packing) in enumerate(axes_specs):
        if packing == 'cfcontig':
            if is_contig:
                # Report at the offending (second) ::1 axis.
                raise CompileError(axes[idx].step.pos, BOTH_CF_ERR)

            contig_dim = idx
            axes_specs[idx] = (access, 'contig')
            is_contig = True

    if is_contig:
        # We have a ::1 somewhere, see if we're C or Fortran contiguous
        if contig_dim == len(axes) - 1:
            is_c_contig = True
        else:
            is_f_contig = True

            if contig_dim and axes_specs[contig_dim - 1][0] not in ('full', 'ptr'):
                raise CompileError(axes[contig_dim].pos,
                                   "Fortran contiguous specifier must follow an indirect dimension")

        if is_c_contig:
            # Contiguous in the last dimension, find the last indirect dimension
            # NOTE(review): walking the axes in reverse without breaking, this
            # loop ends on the *first* (lowest-index) indirect dimension, not
            # the last one -- confirm which is intended.
            contig_dim = -1
            for idx, (access, packing) in enumerate(reversed(axes_specs)):
                if access in ('ptr', 'full'):
                    contig_dim = len(axes) - idx - 1

        # Replace 'strided' with 'follow' for any dimension following the last
        # indirect dimension, the first dimension or the dimension following
        # the ::1.
        #               int[::indirect, ::1, :, :]
        #                                    ^  ^
        #               int[::indirect, :, :, ::1]
        #                               ^  ^
        start = contig_dim + 1
        stop = len(axes) - is_c_contig
        for idx, (access, packing) in enumerate(axes_specs[start:stop]):
            idx = contig_dim + 1 + idx
            if access != 'direct':
                raise CompileError(axes[idx].pos,
                                   "Indirect dimension may not follow "
                                   "Fortran contiguous dimension")
            if packing == 'contig':
                raise CompileError(axes[idx].pos,
                                   "Dimension may not be contiguous")
            axes_specs[idx] = (access, cf_packing)

        if is_c_contig:
            # For C contiguity, we need to fix the 'contig' dimension
            # after the loop
            a, p = axes_specs[-1]
            axes_specs[-1] = a, 'contig'

    validate_axes_specs([axis.start.pos for axis in axes],
                        axes_specs,
                        is_c_contig,
                        is_f_contig)

    return axes_specs

def all(it):
    """Return True if no element of *it* is falsy (True for an empty *it*).

    Module-level stand-in for the builtin of the same name -- presumably kept
    for Python versions that predate it (TODO confirm before removing).
    Within this module it shadows the builtin.
    """
    for element in it:
        if not element:
            break
    else:
        return True
    return False

807
def is_cf_contig(specs):
    """Classify an axes specification as C- and/or Fortran-contiguous.

    ``specs`` is a sequence of ``(access, packing)`` pairs, one per axis.

    Returns a pair ``(is_c_contig, is_f_contig)``:

    * C contiguous: the last axis is ``('direct', 'contig')`` and every
      other axis is ``('direct', 'follow')``.  A single
      ``('direct', 'contig')`` axis therefore counts as C contiguous.
    * Fortran contiguous: more than one axis, the first axis is
      ``('direct', 'contig')`` and every other axis is
      ``('direct', 'follow')``.

    An empty ``specs`` is classified as neither.
    """
    is_c_contig = is_f_contig = False

    if not specs:
        # Nothing to classify; avoids indexing an empty sequence below.
        return is_c_contig, is_f_contig

    contig = ('direct', 'contig')
    follow = ('direct', 'follow')

    if specs[-1] == contig and all(axis == follow for axis in specs[:-1]):
        # c_contiguous: 'follow', 'follow', ..., 'follow', 'contig'
        # (with no leading axes this is the single-axis case)
        is_c_contig = True
    elif (len(specs) > 1 and specs[0] == contig and
          all(axis == follow for axis in specs[1:])):
        # f_contiguous: 'contig', 'follow', 'follow', ..., 'follow'
        is_f_contig = True

    return is_c_contig, is_f_contig

826 827 828 829 830 831 832 833 834 835 836 837 838 839
def get_mode(specs):
    """Return the memoryview mode name for the axes specification *specs*.

    'c' or 'fortran' when is_cf_contig() reports C or Fortran contiguity
    (C is checked first); otherwise 'full' if any axis has 'ptr' or 'full'
    access; otherwise 'strided'.
    """
    c_contig, f_contig = is_cf_contig(specs)
    if c_contig:
        return 'c'
    if f_contig:
        return 'fortran'

    indirect_accesses = ('ptr', 'full')
    for access, _packing in specs:
        if access in indirect_accesses:
            return 'full'
    return 'strided'

840 841 842 843 844 845 846 847 848
# Maps each cython.view layout constant name to its (access, packing) pair.
view_constant_to_access_packing = dict(
    generic=('full', 'strided'),
    strided=('direct', 'strided'),
    indirect=('ptr', 'strided'),
    generic_contiguous=('full', 'contig'),
    contiguous=('direct', 'contig'),
    indirect_contiguous=('ptr', 'contig'),
)

849
def validate_axes_specs(positions, specs, is_c_contig, is_f_contig):
    """Check per-axis memoryview ``(access, packing)`` specs for consistency.

    positions   -- source positions, one per axis, used for error reporting
    specs       -- sequence of ``(access, packing)`` pairs, one per axis
    is_c_contig -- whether the view is C contiguous
    is_f_contig -- whether the view is Fortran contiguous

    Axes are walked pairwise with ``positions`` (as ``zip`` does, so specs
    beyond the last position are not checked).  Returns None; raises
    CompileError at the offending axis position when:

    * an access is not 'direct'/'ptr'/'full' or a packing is not
      'contig'/'strided'/'follow';
    * a 'contig' axis appears after a 'contig' axis whose access is not
      'ptr';
    * a non-'ptr' 'contig' axis is neither the axis right after the last
      'ptr' axis nor the last axis;
    * a 'follow' axis comes after a 'strided' axis with no 'ptr'/'full'
      axis in between, or is used while neither contiguity flag is set.
    """
    packing_specs = ('contig', 'strided', 'follow')
    access_specs = ('direct', 'ptr', 'full')
    indirect_access_specs = ('ptr', 'full')

    has_contig = has_strided = False

    # Only 'ptr' axes count here.  NOTE(review): 'full' axes are not treated
    # as indirect for this purpose, unlike the has_strided reset below.
    last_indirect_dimension = -1
    for idx, (access, _) in enumerate(specs):
        if access == 'ptr':
            last_indirect_dimension = idx

    # A direct contiguous axis may only be the first axis after the last
    # 'ptr' axis, or the very last axis.
    valid_contig_dims = last_indirect_dimension + 1, len(specs) - 1

    for idx, (pos, (access, packing)) in enumerate(zip(positions, specs)):
        if not (access in access_specs and packing in packing_specs):
            raise CompileError(pos, "Invalid axes specification.")

        if packing == 'strided':
            has_strided = True
        elif packing == 'contig':
            if has_contig:
                raise CompileError(pos, "Only one direct contiguous "
                                        "axis may be specified.")

            if idx not in valid_contig_dims and access != 'ptr':
                if last_indirect_dimension + 1 != len(specs) - 1:
                    dims = "dimensions %d and %d" % valid_contig_dims
                else:
                    dims = "dimension %d" % valid_contig_dims[0]
                raise CompileError(pos, "Only %s may be contiguous and direct" % dims)

            # A contiguous 'ptr' axis does not use up the single
            # direct-contiguous slot.
            has_contig = access != 'ptr'
        elif packing == 'follow':
            if has_strided:
                raise CompileError(pos, "A memoryview cannot have both follow and strided axis specifiers.")
            if not (is_c_contig or is_f_contig):
                raise CompileError(pos, "Invalid use of the follow specifier.")

        # An indirect axis clears the strided flag, so 'follow' axes are only
        # checked against 'strided' axes seen since the last indirect one.
        if access in indirect_access_specs:
            has_strided = False

895 896 897 898 899 900 901 902 903 904 905 906 907 908
def _get_resolved_spec(env, spec):
    """Resolve an axis spec node to its symbol-table entry.

    *spec* must be a NameNode (handled by _resolve_NameNode) or an
    AttributeNode (handled by _resolve_AttributeNode); any other node is
    rejected with CompileError(INVALID_ERR).
    """
    if isinstance(spec, NameNode):
        resolver = _resolve_NameNode
    elif isinstance(spec, AttributeNode):
        resolver = _resolve_AttributeNode
    else:
        raise CompileError(spec.pos, INVALID_ERR)
    return resolver(env, spec)

def _resolve_NameNode(env, node):
    try:
        resolved_name = env.lookup(node.name).name
    except AttributeError:
        raise CompileError(node.pos, INVALID_ERR)
909

910
    viewscope = env.global_scope().context.cython_scope.viewscope
911 912 913 914 915
    entry = viewscope.lookup(resolved_name)
    if entry is None:
        raise CompileError(node.pos, NOT_CIMPORTED_ERR)

    return entry
916 917 918 919 920 921 922 923 924 925 926 927 928

def _resolve_AttributeNode(env, node):
    """Resolve a dotted axis spec (e.g. ``view.contiguous``) to an entry.

    The chain of AttributeNodes is flattened into a name path.  Starting
    from *env*, every name except the last must look up to an entry with an
    ``as_module`` scope, which becomes the scope for the next name.  The
    last name is then looked up in the final module scope.

    Raises CompileError with EXPR_ERR if the chain does not bottom out in a
    NameNode.  Raises CompileError with an "undeclared name" or "No such
    attribute" message if a module or the final attribute cannot be found.
    All errors are reported at the position of the innermost node of the
    chain.
    """
    # Walk from the outermost attribute inwards, collecting names in reverse
    # source order, then flip them (append + reverse instead of insert(0)).
    path = []
    while isinstance(node, AttributeNode):
        path.append(node.attribute)
        node = node.obj
    if not isinstance(node, NameNode):
        raise CompileError(node.pos, EXPR_ERR)
    path.append(node.name)
    path.reverse()

    modnames, attr_name = path[:-1], path[-1]
    # There must be at least one module name, otherwise the spec would have
    # been a plain NameNode rather than an AttributeNode.
    assert modnames

    scope = env
    for modname in modnames:
        mod = scope.lookup(modname)
        if not mod or not mod.as_module:
            raise CompileError(
                node.pos, "undeclared name not builtin: %s" % modname)
        scope = mod.as_module

    entry = scope.lookup(attr_name)
    if not entry:
        raise CompileError(node.pos, "No such attribute '%s'" % attr_name)

    return entry
943

944 945 946
def load_memview_cy_utility(util_code_name, context=None, **kwargs):
    """Load the Cython utility code *util_code_name* from MemoryView.pyx.

    *context* and any extra keyword arguments are forwarded unchanged to
    CythonUtilityCode.load().
    """
    # 'context' is a named parameter, so it can never already be in kwargs.
    kwargs['context'] = context
    return CythonUtilityCode.load(util_code_name, "MemoryView.pyx", **kwargs)
947

948
def load_memview_c_utility(util_code_name, context=None, **kwargs):
    """Load the C utility code *util_code_name* from MemoryView_C.c.

    Without a *context*, the code is loaded via UtilityCode.load().  With
    one, it is loaded via TempitaUtilityCode.load() with that context
    (presumably used to render it as a Tempita template).  Any extra
    keyword arguments are forwarded to the loader.
    """
    if context is None:
        return UtilityCode.load(util_code_name, "MemoryView_C.c", **kwargs)
    else:
        return TempitaUtilityCode.load(util_code_name, "MemoryView_C.c",
                                       context=context, **kwargs)
954

955 956 957 958
def use_cython_array_utility_code(env):
    """Enable the cython array support code for *env*.

    Marks the cython scope's 'array_cwrapper' entry as used and registers
    the module-level cython_array_utility_code with *env*.
    """
    cython_scope = env.global_scope().context.cython_scope
    array_cwrapper_entry = cython_scope.lookup('array_cwrapper')
    array_cwrapper_entry.used = True
    env.use_utility_code(cython_array_utility_code)

Mark Florisson's avatar
Mark Florisson committed
959
# Template context shared by the memoryview utility code loaded below.
context = {
    'memview_struct_name': memview_objstruct_cname,
    'max_dims': Options.buffer_max_dims,
    'memviewslice_name': memviewslice_cname,
    'memslice_init': memslice_entry_init,
}

# Memoryview slice struct declaration, placed in the
# 'utility_code_proto_before_types' proto block.
memviewslice_declare_code = load_memview_c_utility(
        "MemviewSliceStruct",
        proto_block='utility_code_proto_before_types',
        context=context)

# Slice initialisation code; its context adds BUF_MAX_NDIMS on top of the
# shared context.
memviewslice_init_code = load_memview_c_utility(
    "MemviewSliceInit",
    context=dict(context, BUF_MAX_NDIMS=Options.buffer_max_dims),
    requires=[memviewslice_declare_code,
              Buffer.acquire_utility_code],
)

memviewslice_index_helpers = load_memview_c_utility("MemviewSliceIndex")

typeinfo_to_format_code = load_memview_cy_utility(
        "BufferFormatFromTypeInfo", requires=[Buffer._typeinfo_to_format_code])

# The View.MemoryView Cython utility module; depends on the slice
# initialisation code above.
view_utility_code = load_memview_cy_utility(
        "View.MemoryView",
        context=context,
        requires=[Buffer.GetAndReleaseBufferUtilityCode(),
                  Buffer.buffer_struct_declare_code,
                  Buffer.empty_bufstruct_utility,
                  memviewslice_init_code],
)

cython_array_utility_code = load_memview_cy_utility(
        "CythonArray",
        context=context,
        requires=[view_utility_code])

# Currently disabled:
# memview_fromslice_utility_code = load_memview_cy_utility(
#         "MemviewFromSlice",
#         context=context,
#         requires=[view_utility_code],
# )