Commit 3d7b79ed authored by Michael Droettboom's avatar Michael Droettboom

Properly handle Numpy scalar types in python2js

parent 8f38cc27
......@@ -2,6 +2,9 @@
#include <emscripten.h>
#include <endian.h>
#include <stdint.h>
#include "hiwire.h"
#include "jsproxy.h"
#include "pyproxy.h"
......@@ -90,25 +93,6 @@ exit:
return HW_ERROR;
}
static int
is_type_name(PyObject* x, const char* name)
{
PyObject* x_type = PyObject_Type(x);
if (x_type == NULL) {
// If we can't get a type, that's probably ok in this case...
PyErr_Clear();
return 0;
}
PyObject* x_type_name = PyObject_Repr(x_type);
Py_DECREF(x_type);
int result = (PyUnicode_CompareWithASCIIString(x_type_name, name) == 0);
Py_DECREF(x_type_name);
return result;
}
int
_python2js_add_to_cache(PyObject* map, PyObject* pyparent, int jsparent);
......@@ -250,6 +234,315 @@ _python2js_dict(PyObject* x, PyObject* map)
return jsdict;
}
typedef int(scalar_converter)(char*);
static int
_convert_bool(char* data)
{
char v = *((char*)data);
if (v) {
return hiwire_true();
} else {
return hiwire_false();
}
}
static int
_convert_int8(char* data)
{
int8_t v = *((int8_t*)data);
return hiwire_int(v);
}
static int
_convert_uint8(char* data)
{
uint8_t v = *((uint8_t*)data);
return hiwire_int(v);
}
static int
_convert_int16(char* data)
{
int16_t v = *((int16_t*)data);
return hiwire_int(v);
}
static int
_convert_int16_swap(char* data)
{
int16_t v = *((int16_t*)data);
return hiwire_int(be16toh(v));
}
static int
_convert_uint16(char* data)
{
uint16_t v = *((uint16_t*)data);
return hiwire_int(v);
}
static int
_convert_uint16_swap(char* data)
{
uint16_t v = *((uint16_t*)data);
return hiwire_int(be16toh(v));
}
static int
_convert_int32(char* data)
{
int32_t v = *((int32_t*)data);
return hiwire_int(v);
}
static int
_convert_int32_swap(char* data)
{
int32_t v = *((int32_t*)data);
return hiwire_int(be32toh(v));
}
static int
_convert_uint32(char* data)
{
uint32_t v = *((uint32_t*)data);
return hiwire_int(v);
}
static int
_convert_uint32_swap(char* data)
{
uint32_t v = *((uint32_t*)data);
return hiwire_int(be32toh(v));
}
static int
_convert_int64(char* data)
{
int64_t v = *((int64_t*)data);
return hiwire_int(v);
}
static int
_convert_int64_swap(char* data)
{
int64_t v = *((int64_t*)data);
return hiwire_int(be64toh(v));
}
static int
_convert_uint64(char* data)
{
uint64_t v = *((uint64_t*)data);
return hiwire_int(v);
}
static int
_convert_uint64_swap(char* data)
{
uint64_t v = *((uint64_t*)data);
return hiwire_int(be64toh(v));
}
static int
_convert_float32(char* data)
{
float v = *((float*)data);
return hiwire_double(v);
}
static int
_convert_float32_swap(char* data)
{
union float32_t
{
uint32_t i;
float f;
} v;
v.f = *((float*)data);
v.i = be32toh(v.i);
return hiwire_double(v.f);
}
static int
_convert_float64(char* data)
{
double v = *((double*)data);
return hiwire_double(v);
}
static int
_convert_float64_swap(char* data)
{
union float64_t
{
uint64_t i;
double f;
} v;
v.f = *((double*)data);
v.i = be64toh(v.i);
return hiwire_double(v.f);
}
static scalar_converter*
_python2js_buffer_get_converter(Py_buffer* buff)
{
// Uses Python's struct typecodes as defined here:
// https://docs.python.org/3.7/library/array.html
char format;
char swap;
if (buff->format == NULL) {
swap = 0;
format = 'B';
} else {
switch (buff->format[0]) {
case '>':
case '!':
swap = 1;
format = buff->format[1];
break;
case '=':
case '<':
case '@':
swap = 0;
format = buff->format[1];
break;
default:
swap = 0;
format = buff->format[0];
}
}
switch (format) {
case 'c':
case 'b':
return _convert_int8;
case 'B':
return _convert_uint8;
case '?':
return _convert_bool;
case 'h':
if (swap) {
return _convert_int16_swap;
} else {
return _convert_int16;
}
case 'H':
if (swap) {
return _convert_uint16_swap;
} else {
return _convert_uint16;
}
case 'i':
case 'l':
case 'n':
if (swap) {
return _convert_int32_swap;
} else {
return _convert_int32;
}
case 'I':
case 'L':
case 'N':
if (swap) {
return _convert_uint32_swap;
} else {
return _convert_uint32;
}
case 'q':
if (swap) {
return _convert_int64_swap;
} else {
return _convert_int64;
}
case 'Q':
if (swap) {
return _convert_uint64_swap;
} else {
return _convert_uint64;
}
case 'f':
if (swap) {
return _convert_float32_swap;
} else {
return _convert_float32;
}
case 'd':
if (swap) {
return _convert_float64_swap;
} else {
return _convert_float64;
}
default:
return NULL;
}
}
static int
_python2js_buffer_recursive(Py_buffer* buff,
char* ptr,
int dim,
scalar_converter* convert)
{
// This function is basically a manual conversion of `recursive_tolist` in
// Numpy to use the Python buffer interface and output Javascript.
Py_ssize_t i, n, stride;
int jsarray, jsitem;
if (dim >= buff->ndim) {
return convert(ptr);
}
n = buff->shape[dim];
stride = buff->strides[dim];
jsarray = hiwire_array();
for (i = 0; i < n; ++i) {
jsitem = _python2js_buffer_recursive(buff, ptr, dim + 1, convert);
if (jsitem == HW_ERROR) {
hiwire_decref(jsarray);
return HW_ERROR;
}
hiwire_push_array(jsarray, jsitem);
hiwire_decref(jsitem);
ptr += stride;
}
return jsarray;
}
static int
_python2js_buffer(PyObject* x)
{
PyObject* memoryview = PyMemoryView_FromObject(x);
if (memoryview == NULL) {
PyErr_Clear();
return HW_ERROR;
}
Py_buffer* buff;
buff = PyMemoryView_GET_BUFFER(memoryview);
scalar_converter* convert = _python2js_buffer_get_converter(buff);
if (convert == NULL) {
Py_DECREF(memoryview);
return HW_ERROR;
}
int result = _python2js_buffer_recursive(buff, buff->buf, 0, convert);
Py_DECREF(memoryview);
return result;
}
static int
_python2js(PyObject* x, PyObject* map)
{
......@@ -269,12 +562,18 @@ _python2js(PyObject* x, PyObject* map)
return _python2js_bytes(x);
} else if (JsProxy_Check(x)) {
return JsProxy_AsJs(x);
} else if (PyList_Check(x) || PyTuple_Check(x) ||
is_type_name(x, "<class 'numpy.ndarray'>")) {
} else if (PyList_Check(x) || PyTuple_Check(x)) {
return _python2js_sequence(x, map);
} else if (PyDict_Check(x)) {
return _python2js_dict(x, map);
} else {
int jsbuff = _python2js_buffer(x);
if (jsbuff != HW_ERROR) {
return jsbuff;
}
if (PySequence_Check(x)) {
return _python2js_sequence(x, map);
}
Py_INCREF(x);
return pyproxy_new((int)x);
}
......
......@@ -67,6 +67,52 @@ def test_python2js_long_ints(selenium):
assert selenium.run('-2**31') == -2**31
def test_python2js_numpy_dtype(selenium_standalone):
selenium = selenium_standalone
selenium.load_package('numpy')
selenium.run("import numpy as np")
expected_result = [[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]
for order in ('C', 'F'):
for dtype in (
'int8',
'uint8',
'int16',
'uint16',
'int32',
'uint32',
'int64',
'uint64',
'float32',
'float64'
):
assert selenium.run(
f"""
x = np.arange(8, dtype=np.{dtype})
x = x.reshape((2, 2, 2))
x = x.copy({order!r})
x
"""
) == expected_result
assert selenium.run(
"""
x.byteswap().newbyteorder()
"""
) == expected_result
assert selenium.run("np.array([True, False])") == [True, False]
selenium.run(
"x = np.array([['string1', 'string2'], ['string3', 'string4']])"
)
assert selenium.run_js("return pyodide.pyimport('x').length") == 2
assert selenium.run_js("return pyodide.pyimport('x')[0][0]") == 'string1'
assert selenium.run_js("return pyodide.pyimport('x')[1][1]") == 'string4'
def test_pythonexc2js(selenium):
try:
selenium.run_js('return pyodide.runPython("5 / 0")')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment