Commit 1f978500 authored by Michael Droettboom, committed by GitHub

Merge pull request #287 from mdboom/numpy-arrays-to-typed-arrays

Convert Numpy arrays to TypedArrays when possible
parents 75ed8c05 dabc8ecc
......@@ -63,7 +63,7 @@ all: build/pyodide.asm.js \
build/pyodide.asm.js: src/main.bc src/jsimport.bc src/jsproxy.bc src/js2python.bc \
src/pyimport.bc src/pyproxy.bc src/python2js.bc \
src/pyimport.bc src/pyproxy.bc src/python2js.bc src/python2js_buffer.bc \
src/runpython.bc src/hiwire.bc
[ -d build ] || mkdir build
$(CXX) -s EXPORT_NAME="'pyodide'" -o build/pyodide.asm.html $(filter %.bc,$^) \
......
......@@ -96,6 +96,46 @@ EM_JS(int, hiwire_bytes, (int ptr, int len), {
return Module.hiwire_new_value(bytes);
});
// Wrap the heap memory starting at byte offset `ptr` in an Int8Array view of
// `len` elements (no copy) and register the view as a new hiwire value.
EM_JS(int, hiwire_int8array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Int8Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in a Uint8Array view of
// `len` elements (no copy) and register the view as a new hiwire value.
EM_JS(int, hiwire_uint8array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Uint8Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in an Int16Array view
// (no copy) and register the view as a new hiwire value. `len` is handed to
// the constructor as its length argument, i.e. a count of 16-bit elements.
EM_JS(int, hiwire_int16array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Int16Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in a Uint16Array view
// (no copy) and register the view as a new hiwire value. `len` is handed to
// the constructor as its length argument, i.e. a count of 16-bit elements.
EM_JS(int, hiwire_uint16array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Uint16Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in an Int32Array view
// (no copy) and register the view as a new hiwire value. `len` is handed to
// the constructor as its length argument, i.e. a count of 32-bit elements.
EM_JS(int, hiwire_int32array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Int32Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in a Uint32Array view
// (no copy) and register the view as a new hiwire value. `len` is handed to
// the constructor as its length argument, i.e. a count of 32-bit elements.
EM_JS(int, hiwire_uint32array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Uint32Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in a Float32Array view
// (no copy) and register the view as a new hiwire value. `len` is handed to
// the constructor as its length argument, i.e. a count of 32-bit elements.
EM_JS(int, hiwire_float32array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Float32Array(Module.HEAPU8.buffer, ptr, len));
})
// Wrap the heap memory starting at byte offset `ptr` in a Float64Array view
// (no copy) and register the view as a new hiwire value. `len` is handed to
// the constructor as its length argument, i.e. a count of 64-bit elements.
EM_JS(int, hiwire_float64array, (int ptr, int len), {
  return Module.hiwire_new_value(
    new Float64Array(Module.HEAPU8.buffer, ptr, len));
})
int
hiwire_undefined()
{
......@@ -345,3 +385,9 @@ EM_JS(int, hiwire_get_dtype, (int idobj), {
}
return dtype;
});
// Look up the Javascript object referenced by `idarr`, call its
// `subarray(start, end)` method, and register the result as a new hiwire
// value.
EM_JS(int, hiwire_subarray, (int idarr, int start, int end), {
  var view = Module.hiwire_get_value(idarr).subarray(start, end);
  return Module.hiwire_new_value(view);
});
......@@ -109,6 +109,94 @@ hiwire_string_ascii(int ptr);
int
hiwire_bytes(int ptr, int len);
/**
 * Create a new Javascript Int8Array that views the Emscripten heap, given a
 * pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap and `len` is passed to the
 * Int8Array constructor as its length (a count of elements). Since each
 * element is one byte, `len` is also the length in bytes.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_int8array(int ptr, int len);
/**
 * Create a new Javascript Uint8Array that views the Emscripten heap, given a
 * pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap and `len` is passed to the
 * Uint8Array constructor as its length (a count of elements). Since each
 * element is one byte, `len` is also the length in bytes.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_uint8array(int ptr, int len);
/**
 * Create a new Javascript Int16Array that views the Emscripten heap, given a
 * pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap; the Int16Array constructor
 * requires it to be a multiple of 2. `len` is passed to the constructor as
 * its length, which counts 2-byte elements, NOT bytes.
 *
 * NOTE(review): this comment previously described `len` as a length in
 * bytes. A caller passing a byte count would get a view twice as long as the
 * buffer -- verify the call sites.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_int16array(int ptr, int len);
/**
 * Create a new Javascript Uint16Array that views the Emscripten heap, given a
 * pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap; the Uint16Array constructor
 * requires it to be a multiple of 2. `len` is passed to the constructor as
 * its length, which counts 2-byte elements, NOT bytes.
 *
 * NOTE(review): this comment previously described `len` as a length in
 * bytes. A caller passing a byte count would get a view twice as long as the
 * buffer -- verify the call sites.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_uint16array(int ptr, int len);
/**
 * Create a new Javascript Int32Array that views the Emscripten heap, given a
 * pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap; the Int32Array constructor
 * requires it to be a multiple of 4. `len` is passed to the constructor as
 * its length, which counts 4-byte elements, NOT bytes.
 *
 * NOTE(review): this comment previously described `len` as a length in
 * bytes. A caller passing a byte count would get a view four times as long
 * as the buffer -- verify the call sites.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_int32array(int ptr, int len);
/**
 * Create a new Javascript Uint32Array that views the Emscripten heap, given a
 * pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap; the Uint32Array constructor
 * requires it to be a multiple of 4. `len` is passed to the constructor as
 * its length, which counts 4-byte elements, NOT bytes.
 *
 * NOTE(review): this comment previously described `len` as a length in
 * bytes. A caller passing a byte count would get a view four times as long
 * as the buffer -- verify the call sites.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_uint32array(int ptr, int len);
/**
 * Create a new Javascript Float32Array that views the Emscripten heap, given
 * a pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap; the Float32Array
 * constructor requires it to be a multiple of 4. `len` is passed to the
 * constructor as its length, which counts 4-byte elements, NOT bytes.
 *
 * NOTE(review): this comment previously described `len` as a length in
 * bytes. A caller passing a byte count would get a view four times as long
 * as the buffer -- verify the call sites.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_float32array(int ptr, int len);
/**
 * Create a new Javascript Float64Array that views the Emscripten heap, given
 * a pointer to a buffer and a length.
 *
 * `ptr` is used as the byte offset into the heap; the Float64Array
 * constructor requires it to be a multiple of 8. `len` is passed to the
 * constructor as its length, which counts 8-byte elements, NOT bytes.
 *
 * NOTE(review): this comment previously described `len` as a length in
 * bytes. A caller passing a byte count would get a view eight times as long
 * as the buffer -- verify the call sites.
 *
 * The array's data is not copied.
 *
 * Returns: New reference
 */
int
hiwire_float64array(int ptr, int len);
/**
* Create a new Javascript undefined value.
*
......@@ -435,4 +523,10 @@ hiwire_copy_to_ptr(int idobj, int ptr);
int
hiwire_get_dtype(int idobj);
/**
 * Get a subarray from a TypedArray.
 *
 * Calls the Javascript `subarray(start, end)` method on the object referenced
 * by `idarr`. For a TypedArray, `start` and `end` are element indices (`end`
 * exclusive) and the result is a view onto the same underlying buffer, not a
 * copy.
 *
 * Returns: New reference
 */
int
hiwire_subarray(int idarr, int start, int end);
#endif /* HIWIRE_H */
......@@ -2,13 +2,12 @@
#include <emscripten.h>
#include <endian.h>
#include <stdint.h>
#include "hiwire.h"
#include "jsproxy.h"
#include "pyproxy.h"
#include "python2js_buffer.h"
static PyObject* tbmod = NULL;
static int
......@@ -234,315 +233,6 @@ _python2js_dict(PyObject* x, PyObject* map)
return jsdict;
}
// Function type for the per-element converters below: takes a pointer to the
// raw bytes of one buffer element and returns the id of the resulting
// Javascript value.
typedef int(scalar_converter)(char*);
// Convert the byte at `data` to a Javascript boolean: any nonzero byte gives
// true, zero gives false.
static int
_convert_bool(char* data)
{
  return *data ? hiwire_true() : hiwire_false();
}
// Read a signed 8-bit integer at `data` and convert it to a Javascript number.
static int
_convert_int8(char* data)
{
  const int8_t* src = (const int8_t*)data;
  return hiwire_int(*src);
}
// Read an unsigned 8-bit integer at `data` and convert it to a Javascript
// number.
static int
_convert_uint8(char* data)
{
  const uint8_t* src = (const uint8_t*)data;
  return hiwire_int(*src);
}
// Read a host-order signed 16-bit integer at `data` and convert it to a
// Javascript number.
static int
_convert_int16(char* data)
{
  const int16_t* src = (const int16_t*)data;
  return hiwire_int(*src);
}
// Read a big-endian signed 16-bit integer at `data` and convert it to a
// Javascript number.
static int
_convert_int16_swap(char* data)
{
  int16_t v = *((int16_t*)data);
  // be16toh() yields an unsigned 16-bit value. Passing that straight to
  // hiwire_int() would promote it to a non-negative int, so e.g. -1 would come
  // out as 65535. Convert back to int16_t first so the sign is preserved.
  v = (int16_t)be16toh((uint16_t)v);
  return hiwire_int(v);
}
// Read a host-order unsigned 16-bit integer at `data` and convert it to a
// Javascript number.
static int
_convert_uint16(char* data)
{
  const uint16_t* src = (const uint16_t*)data;
  return hiwire_int(*src);
}
// Read a big-endian unsigned 16-bit integer at `data`, convert it to host
// order, and convert it to a Javascript number.
static int
_convert_uint16_swap(char* data)
{
  uint16_t host = be16toh(*((uint16_t*)data));
  return hiwire_int(host);
}
// Read a host-order signed 32-bit integer at `data` and convert it to a
// Javascript number.
static int
_convert_int32(char* data)
{
  const int32_t* src = (const int32_t*)data;
  return hiwire_int(*src);
}
// Read a big-endian signed 32-bit integer at `data`, convert it to host
// order, and convert it to a Javascript number.
static int
_convert_int32_swap(char* data)
{
  const int32_t wire = *((const int32_t*)data);
  // be32toh() yields an unsigned 32-bit value; it is passed on as-is.
  uint32_t host = be32toh(wire);
  return hiwire_int(host);
}
// Read a host-order unsigned 32-bit integer at `data` and convert it to a
// Javascript number.
// NOTE(review): hiwire_int's parameter type is not visible here; if it is a
// signed 32-bit int, values above INT32_MAX will not round-trip -- confirm.
static int
_convert_uint32(char* data)
{
  const uint32_t* src = (const uint32_t*)data;
  return hiwire_int(*src);
}
// Read a big-endian unsigned 32-bit integer at `data`, convert it to host
// order, and convert it to a Javascript number.
// NOTE(review): hiwire_int's parameter type is not visible here; if it is a
// signed 32-bit int, values above INT32_MAX will not round-trip -- confirm.
static int
_convert_uint32_swap(char* data)
{
  uint32_t host = be32toh(*((uint32_t*)data));
  return hiwire_int(host);
}
// Read a host-order signed 64-bit integer at `data` and convert it to a
// Javascript number.
// NOTE(review): hiwire_int's parameter type is not visible here; if it is
// narrower than 64 bits, large values are truncated -- confirm.
static int
_convert_int64(char* data)
{
  const int64_t* src = (const int64_t*)data;
  return hiwire_int(*src);
}
// Read a big-endian signed 64-bit integer at `data`, convert it to host
// order, and convert it to a Javascript number.
// NOTE(review): hiwire_int's parameter type is not visible here; if it is
// narrower than 64 bits, large values are truncated -- confirm.
static int
_convert_int64_swap(char* data)
{
  const int64_t wire = *((const int64_t*)data);
  // be64toh() yields an unsigned 64-bit value; it is passed on as-is.
  uint64_t host = be64toh(wire);
  return hiwire_int(host);
}
// Read a host-order unsigned 64-bit integer at `data` and convert it to a
// Javascript number.
// NOTE(review): hiwire_int's parameter type is not visible here; if it is
// narrower than 64 bits, large values are truncated -- confirm.
static int
_convert_uint64(char* data)
{
  const uint64_t* src = (const uint64_t*)data;
  return hiwire_int(*src);
}
// Read a big-endian unsigned 64-bit integer at `data`, convert it to host
// order, and convert it to a Javascript number.
// NOTE(review): hiwire_int's parameter type is not visible here; if it is
// narrower than 64 bits, large values are truncated -- confirm.
static int
_convert_uint64_swap(char* data)
{
  uint64_t host = be64toh(*((uint64_t*)data));
  return hiwire_int(host);
}
// Read a host-order 32-bit float at `data` and convert it to a Javascript
// number.
static int
_convert_float32(char* data)
{
  const float* src = (const float*)data;
  float value = *src;
  return hiwire_double(value);
}
// Read a big-endian 32-bit float at `data`, convert its bytes to host order,
// and convert it to a Javascript number.
static int
_convert_float32_swap(char* data)
{
  // The union lets the float's bit pattern be byte-swapped as an integer.
  union
  {
    uint32_t bits;
    float value;
  } pun;

  pun.value = *((float*)data);
  pun.bits = be32toh(pun.bits);
  return hiwire_double(pun.value);
}
// Read a host-order 64-bit float at `data` and convert it to a Javascript
// number.
static int
_convert_float64(char* data)
{
  const double* src = (const double*)data;
  return hiwire_double(*src);
}
// Read a big-endian 64-bit float at `data`, convert its bytes to host order,
// and convert it to a Javascript number.
static int
_convert_float64_swap(char* data)
{
  // The union lets the double's bit pattern be byte-swapped as an integer.
  union
  {
    uint64_t bits;
    double value;
  } pun;

  pun.value = *((double*)data);
  pun.bits = be64toh(pun.bits);
  return hiwire_double(pun.value);
}
// Pick the per-element converter for a buffer based on its format string.
//
// Uses Python's struct typecodes as defined here:
// https://docs.python.org/3.7/library/array.html
//
// Only the optional byte-order prefix and the single character after it are
// examined. Returns NULL when the typecode is not one of those handled below.
static scalar_converter*
_python2js_buffer_get_converter(Py_buffer* buff)
{
  const char* spec = buff->format;
  char code = 'B'; // a NULL format is treated as unsigned bytes
  int big_endian = 0;

  if (spec != NULL) {
    const char prefix = spec[0];
    if (prefix == '>' || prefix == '!') {
      // Explicit big-endian / network order: elements need byte swapping.
      big_endian = 1;
      code = spec[1];
    } else if (prefix == '=' || prefix == '<' || prefix == '@') {
      code = spec[1];
    } else {
      // No byte-order prefix: the first character is the typecode itself.
      code = prefix;
    }
  }

  switch (code) {
    case 'c':
    case 'b':
      return _convert_int8;
    case 'B':
      return _convert_uint8;
    case '?':
      return _convert_bool;
    case 'h':
      return big_endian ? _convert_int16_swap : _convert_int16;
    case 'H':
      return big_endian ? _convert_uint16_swap : _convert_uint16;
    case 'i':
    case 'l':
    case 'n':
      return big_endian ? _convert_int32_swap : _convert_int32;
    case 'I':
    case 'L':
    case 'N':
      return big_endian ? _convert_uint32_swap : _convert_uint32;
    case 'q':
      return big_endian ? _convert_int64_swap : _convert_int64;
    case 'Q':
      return big_endian ? _convert_uint64_swap : _convert_uint64;
    case 'f':
      return big_endian ? _convert_float32_swap : _convert_float32;
    case 'd':
      return big_endian ? _convert_float64_swap : _convert_float64;
    default:
      return NULL;
  }
}
// Convert dimension `dim` (and everything nested below it) of a buffer into
// nested Javascript arrays, starting at `ptr`.
//
// This function is basically a manual conversion of `recursive_tolist` in
// Numpy to use the Python buffer interface and output Javascript.
//
// Once `dim` reaches the buffer's number of dimensions, `ptr` addresses a
// single element and `convert` turns it into a Javascript value. Returns the
// id of the resulting value, or HW_ERROR if any nested conversion failed.
static int
_python2js_buffer_recursive(Py_buffer* buff,
                            char* ptr,
                            int dim,
                            scalar_converter* convert)
{
  if (dim >= buff->ndim) {
    return convert(ptr);
  }

  const Py_ssize_t count = buff->shape[dim];
  const Py_ssize_t step = buff->strides[dim];
  int out = hiwire_array();

  char* cursor = ptr;
  for (Py_ssize_t k = 0; k < count; ++k, cursor += step) {
    int child = _python2js_buffer_recursive(buff, cursor, dim + 1, convert);
    if (child == HW_ERROR) {
      // Drop the partially built array before reporting the failure.
      hiwire_decref(out);
      return HW_ERROR;
    }
    hiwire_push_array(out, child);
    // Release our own reference to the child after pushing it.
    hiwire_decref(child);
  }
  return out;
}
// Convert a Python object supporting the buffer protocol into nested
// Javascript arrays.
//
// Returns HW_ERROR (with the Python error cleared) if no memoryview can be
// created from `x`, and HW_ERROR if the buffer's format has no converter.
static int
_python2js_buffer(PyObject* x)
{
  PyObject* view = PyMemoryView_FromObject(x);
  if (view == NULL) {
    PyErr_Clear();
    return HW_ERROR;
  }

  int result = HW_ERROR;
  Py_buffer* info = PyMemoryView_GET_BUFFER(view);
  scalar_converter* convert = _python2js_buffer_get_converter(info);
  if (convert != NULL) {
    result = _python2js_buffer_recursive(info, info->buf, 0, convert);
  }

  Py_DECREF(view);
  return result;
}
static int
_python2js(PyObject* x, PyObject* map)
{
......
This diff is collapsed.
#ifndef PYTHON2JS_BUFFER_H
#define PYTHON2JS_BUFFER_H
/** Utilities to convert Python buffer objects to Javascript.
*/
#include <Python.h>
/** Convert a Python buffer object to a Javascript object.
 *
 * \param x The Python object
 * \return The Javascript object -- might be an Error object in the case of an
 * exception.
 */
int
_python2js_buffer(PyObject* x);
#endif /* PYTHON2JS_BUFFER_H */
def test_numpy(selenium):
selenium.load_package("numpy")
selenium.run("import numpy")
x = selenium.run("numpy.zeros((32, 64))")
assert len(x) == 32
assert all(len(y) == 64 for y in x)
for y in x:
assert all(z == 0 for z in y)
selenium.run("x = numpy.ones((32, 64))")
assert selenium.run_js("return pyodide.pyimport('x').length == 32")
for i in range(32):
assert selenium.run_js(
f"return pyodide.pyimport('x')[{i}].length == 64"
)
for j in range(64):
assert selenium.run_js(
f"return pyodide.pyimport('x')[{i}][{j}] == 1"
)
def test_typed_arrays(selenium):
......
......@@ -76,6 +76,16 @@ def test_python2js_numpy_dtype(selenium_standalone):
expected_result = [[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]
def assert_equal():
# We have to do this an element at a time, since the Selenium driver
# for Firefox does not convert TypedArrays to Python correctly
for i in range(2):
for j in range(2):
for k in range(2):
assert selenium.run_js(
f"return pyodide.pyimport('x')[{i}][{j}][{k}]"
) == expected_result[i][j][k]
for order in ('C', 'F'):
for dtype in (
'int8',
......@@ -89,19 +99,41 @@ def test_python2js_numpy_dtype(selenium_standalone):
'float32',
'float64'
):
assert selenium.run(
selenium.run(
f"""
x = np.arange(8, dtype=np.{dtype})
x = x.reshape((2, 2, 2))
x = x.copy({order!r})
x
"""
) == expected_result
assert selenium.run(
)
assert_equal()
classname = selenium.run_js(
"return pyodide.pyimport('x')[0][0].constructor.name"
)
if order == 'C' and dtype not in ('uint64', 'int64'):
# Here we expect a TypedArray subclass, such as Uint8Array, but
# not a plain-old Array
assert classname.endswith('Array')
assert classname != 'Array'
else:
assert classname == 'Array'
selenium.run(
"""
x.byteswap().newbyteorder()
x = x.byteswap().newbyteorder()
"""
) == expected_result
)
assert_equal()
classname = selenium.run_js(
"return pyodide.pyimport('x')[0][0].constructor.name"
)
if order == 'C' and dtype in ('int8', 'uint8'):
# Here we expect a TypedArray subclass, such as Uint8Array, but
# not a plain-old Array -- but only for single byte types where
# endianness doesn't matter
assert classname.endswith('Array')
assert classname != 'Array'
else:
assert classname == 'Array'
assert selenium.run("np.array([True, False])") == [True, False]
......@@ -532,13 +564,16 @@ def test_recursive_dict(selenium_standalone):
def test_runpythonasync(selenium_standalone):
output = selenium_standalone.run_async(
selenium_standalone.run_async(
"""
import numpy as np
np.zeros(5)
x = np.zeros(5)
"""
)
assert list(output) == [0, 0, 0, 0, 0]
for i in range(5):
assert selenium_standalone.run_js(
f"return pyodide.pyimport('x')[{i}] == 0"
)
def test_runpythonasync_different_package_name(selenium_standalone):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment