Commit 19bf0665 authored by Marius Wachtler's avatar Marius Wachtler

str: use tp_as_sequence instead of tp_as_number

string is special in that it is a c++ type which has tp_as_number and tp_as_sequence.
This causes problems because when we fixup the slot dispatcher we will set the tp_as_number fields but not the
tp_as_sequence because setting both can cause problems.
Some extensions (e.g. numpy) require that we use the sq_* functions instead of nb_*.
Therefore clear the tp_as_number fields (except nb_remainder which cpython has set too because it is not part of
tp_as_sequence).
parent a6fca70a
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include "Python.h" #include "Python.h"
#include <stddef.h>
#include "stringlib/stringdefs.h" #include "stringlib/stringdefs.h"
#include "stringlib/fastsearch.h" #include "stringlib/fastsearch.h"
...@@ -14,6 +16,8 @@ ...@@ -14,6 +16,8 @@
#include "stringlib/string_format.h" #include "stringlib/string_format.h"
#define PyStringObject_SIZE (offsetof(PyStringObject, ob_sval) + 1)
// do_string_format needs to be declared as a static function, since it's used by both stringobject.c // do_string_format needs to be declared as a static function, since it's used by both stringobject.c
// and unicodeobject.c. We want to access it from str.cpp, though, so just use this little forwarding // and unicodeobject.c. We want to access it from str.cpp, though, so just use this little forwarding
// function. // function.
...@@ -958,3 +962,55 @@ PyObject* string_replace(PyStringObject *self, PyObject *from, PyObject* to, PyO ...@@ -958,3 +962,55 @@ PyObject* string_replace(PyStringObject *self, PyObject *from, PyObject* to, PyO
from_s, from_len, from_s, from_len,
to_s, to_len, count); to_s, to_len, count);
} }
PyObject* string_repeat(register PyStringObject *a, register Py_ssize_t n)
{
register Py_ssize_t i;
register Py_ssize_t j;
register Py_ssize_t size;
register PyStringObject *op;
size_t nbytes;
if (n < 0)
n = 0;
/* watch out for overflows: the size can overflow int,
* and the # of bytes needed can overflow size_t
*/
size = Py_SIZE(a) * n;
if (n && size / n != Py_SIZE(a)) {
PyErr_SetString(PyExc_OverflowError,
"repeated string is too long");
return NULL;
}
if (size == Py_SIZE(a) && PyString_CheckExact(a)) {
Py_INCREF(a);
return (PyObject *)a;
}
nbytes = (size_t)size;
if (nbytes + PyStringObject_SIZE <= nbytes) {
PyErr_SetString(PyExc_OverflowError,
"repeated string is too long");
return NULL;
}
op = (PyStringObject *)PyObject_MALLOC(PyStringObject_SIZE + nbytes);
if (op == NULL)
return PyErr_NoMemory();
PyObject_INIT_VAR(op, &PyString_Type, size);
op->ob_shash = -1;
op->ob_sstate = SSTATE_NOT_INTERNED;
op->ob_sval[size] = '\0';
if (Py_SIZE(a) == 1 && n > 0) {
memset(op->ob_sval, a->ob_sval[0] , n);
return (PyObject *) op;
}
i = 0;
if (i < size) {
Py_MEMCPY(op->ob_sval, a->ob_sval, Py_SIZE(a));
i = Py_SIZE(a);
}
while (i < size) {
j = (i <= size-i) ? i : size-i;
Py_MEMCPY(op->ob_sval+i, op->ob_sval, j);
i += j;
}
return (PyObject *) op;
}
...@@ -42,6 +42,7 @@ extern "C" PyObject* string_find(PyStringObject* self, PyObject* args) noexcept; ...@@ -42,6 +42,7 @@ extern "C" PyObject* string_find(PyStringObject* self, PyObject* args) noexcept;
extern "C" PyObject* string_index(PyStringObject* self, PyObject* args) noexcept; extern "C" PyObject* string_index(PyStringObject* self, PyObject* args) noexcept;
extern "C" PyObject* string_rindex(PyStringObject* self, PyObject* args) noexcept; extern "C" PyObject* string_rindex(PyStringObject* self, PyObject* args) noexcept;
extern "C" PyObject* string_rfind(PyStringObject* self, PyObject* args) noexcept; extern "C" PyObject* string_rfind(PyStringObject* self, PyObject* args) noexcept;
extern "C" PyObject* string_repeat(PyStringObject* a, Py_ssize_t n) noexcept;
extern "C" PyObject* string_replace(PyStringObject* self, PyObject* args) noexcept; extern "C" PyObject* string_replace(PyStringObject* self, PyObject* args) noexcept;
extern "C" PyObject* string_splitlines(PyStringObject* self, PyObject* args) noexcept; extern "C" PyObject* string_splitlines(PyStringObject* self, PyObject* args) noexcept;
extern "C" PyObject* string__format__(PyObject* self, PyObject* args) noexcept; extern "C" PyObject* string__format__(PyObject* self, PyObject* args) noexcept;
...@@ -336,12 +337,12 @@ extern "C" PyObject* PyString_FromFormat(const char* format, ...) noexcept { ...@@ -336,12 +337,12 @@ extern "C" PyObject* PyString_FromFormat(const char* format, ...) noexcept {
return ret; return ret;
} }
extern "C" Box* strAdd(BoxedString* lhs, Box* _rhs) { template <ExceptionStyle S> Box* strAdd(BoxedString* lhs, Box* _rhs) noexcept(S == CAPI) {
assert(PyString_Check(lhs)); assert(PyString_Check(lhs));
if (isSubclass(_rhs->cls, unicode_cls)) { if (PyUnicode_Check(_rhs)) {
Box* rtn = PyUnicode_Concat(lhs, _rhs); Box* rtn = PyUnicode_Concat(lhs, _rhs);
if (!rtn) if (!rtn && S == CXX)
throwCAPIException(); throwCAPIException();
return rtn; return rtn;
} }
...@@ -349,14 +350,16 @@ extern "C" Box* strAdd(BoxedString* lhs, Box* _rhs) { ...@@ -349,14 +350,16 @@ extern "C" Box* strAdd(BoxedString* lhs, Box* _rhs) {
if (!PyString_Check(_rhs)) { if (!PyString_Check(_rhs)) {
if (PyByteArray_Check(_rhs)) { if (PyByteArray_Check(_rhs)) {
Box* rtn = PyByteArray_Concat(lhs, _rhs); Box* rtn = PyByteArray_Concat(lhs, _rhs);
if (!rtn) if (!rtn && S == CXX)
throwCAPIException(); throwCAPIException();
return rtn; return rtn;
} else { } else {
// This is a compatibility break with CPython, which has their sq_concat method if (S == CXX)
// directly throw a TypeError. Since we're not implementing this as a sq_concat, raiseExcHelper(TypeError, "cannot concatenate 'str' and '%.200s' objects", _rhs->cls->tp_name);
// Give NotImplemented for now. else {
return incref(NotImplemented); PyErr_Format(PyExc_TypeError, "cannot concatenate 'str' and '%.200s' objects", _rhs->cls->tp_name);
return NULL;
}
} }
} }
...@@ -1171,6 +1174,14 @@ Box* strRMod(BoxedString* lhs, Box* rhs) { ...@@ -1171,6 +1174,14 @@ Box* strRMod(BoxedString* lhs, Box* rhs) {
return rtn; return rtn;
} }
static PyObject* str_mod(PyObject* v, PyObject* w) noexcept {
if (!PyString_Check(v)) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
return PyString_Format(v, w);
}
extern "C" Box* strMul(BoxedString* lhs, Box* rhs) { extern "C" Box* strMul(BoxedString* lhs, Box* rhs) {
assert(PyString_Check(lhs)); assert(PyString_Check(lhs));
...@@ -1194,13 +1205,10 @@ extern "C" Box* strMul(BoxedString* lhs, Box* rhs) { ...@@ -1194,13 +1205,10 @@ extern "C" Box* strMul(BoxedString* lhs, Box* rhs) {
if (n <= 0) if (n <= 0)
return incref(EmptyString); return incref(EmptyString);
// TODO: use createUninitializedString and getWriteableStringContents Box* rtn = string_repeat((PyStringObject*)lhs, n);
int sz = lhs->size(); if (!rtn)
std::string buf(sz * n, '\0'); throwCAPIException();
for (int i = 0; i < n; i++) { return rtn;
memcpy(&buf[sz * i], lhs->data(), sz);
}
return boxString(buf);
} }
Box* str_richcompare(Box* lhs, Box* rhs, int op) { Box* str_richcompare(Box* lhs, Box* rhs, int op) {
...@@ -2574,22 +2582,17 @@ extern "C" int _PyString_Resize(PyObject** pv, Py_ssize_t newsize) noexcept { ...@@ -2574,22 +2582,17 @@ extern "C" int _PyString_Resize(PyObject** pv, Py_ssize_t newsize) noexcept {
} }
extern "C" void PyString_Concat(register PyObject** pv, register PyObject* w) noexcept { extern "C" void PyString_Concat(register PyObject** pv, register PyObject* w) noexcept {
try { if (*pv == NULL)
if (*pv == NULL) return;
return;
AUTO_DECREF(*pv); AUTO_DECREF(*pv);
if (w == NULL || !PyString_Check(*pv)) { if (w == NULL || !PyString_Check(*pv)) {
*pv = NULL;
return;
}
*pv = strAdd((BoxedString*)*pv, w);
} catch (ExcInfo e) {
setCAPIException(e);
*pv = NULL; *pv = NULL;
return;
} }
*pv = strAdd<CAPI>((BoxedString*)*pv, w);
} }
extern "C" void PyString_ConcatAndDel(register PyObject** pv, register PyObject* w) noexcept { extern "C" void PyString_ConcatAndDel(register PyObject** pv, register PyObject* w) noexcept {
...@@ -2931,7 +2934,7 @@ void setupStr() { ...@@ -2931,7 +2934,7 @@ void setupStr() {
str_cls->giveAttr("format", new BoxedFunction(FunctionMetadata::create((void*)strFormat, UNKNOWN, 1, true, true))); str_cls->giveAttr("format", new BoxedFunction(FunctionMetadata::create((void*)strFormat, UNKNOWN, 1, true, true)));
str_cls->giveAttr("__add__", new BoxedFunction(FunctionMetadata::create((void*)strAdd, UNKNOWN, 2))); str_cls->giveAttr("__add__", new BoxedFunction(FunctionMetadata::create((void*)strAdd<CXX>, UNKNOWN, 2)));
str_cls->giveAttr("__mod__", new BoxedFunction(FunctionMetadata::create((void*)strMod, UNKNOWN, 2))); str_cls->giveAttr("__mod__", new BoxedFunction(FunctionMetadata::create((void*)strMod, UNKNOWN, 2)));
str_cls->giveAttr("__rmod__", new BoxedFunction(FunctionMetadata::create((void*)strRMod, UNKNOWN, 2))); str_cls->giveAttr("__rmod__", new BoxedFunction(FunctionMetadata::create((void*)strRMod, UNKNOWN, 2)));
str_cls->giveAttr("__mul__", new BoxedFunction(FunctionMetadata::create((void*)strMul, UNKNOWN, 2))); str_cls->giveAttr("__mul__", new BoxedFunction(FunctionMetadata::create((void*)strMul, UNKNOWN, 2)));
...@@ -2970,15 +2973,27 @@ void setupStr() { ...@@ -2970,15 +2973,27 @@ void setupStr() {
add_operators(str_cls); add_operators(str_cls);
str_cls->freeze(); str_cls->freeze();
// string is special in that it is a c++ type which has tp_as_number and tp_as_sequence.
// This causes problems because when we fixup the slot dispatcher we will set the tp_as_number fields but not the
// tp_as_sequence because setting both can cause problems.
// Some extensions (e.g. numpy) require that we use the sq_* functions instead of nb_*.
// Therefore clear the tp_as_number fields (except nb_remainder which cpython has set too because it is not part of
// tp_as_sequence).
memset(&str_as_number, 0, sizeof(str_as_number));
str_cls->tp_as_number->nb_remainder = str_mod;
str_cls->tp_repr = str_repr; str_cls->tp_repr = str_repr;
str_cls->tp_str = str_str; str_cls->tp_str = str_str;
str_cls->tp_print = string_print; str_cls->tp_print = string_print;
str_cls->tp_iter = (decltype(str_cls->tp_iter))strIter; str_cls->tp_iter = (decltype(str_cls->tp_iter))strIter;
str_cls->tp_hash = (hashfunc)str_hash; str_cls->tp_hash = (hashfunc)str_hash;
str_cls->tp_as_sequence->sq_length = str_length; str_cls->tp_as_sequence->sq_concat = (binaryfunc)strAdd<CAPI>;
str_cls->tp_as_sequence->sq_contains = (objobjproc)string_contains;
str_cls->tp_as_sequence->sq_item = (ssizeargfunc)string_item; str_cls->tp_as_sequence->sq_item = (ssizeargfunc)string_item;
str_cls->tp_as_sequence->sq_length = str_length;
str_cls->tp_as_sequence->sq_repeat = (ssizeargfunc)string_repeat;
str_cls->tp_as_sequence->sq_slice = str_slice; str_cls->tp_as_sequence->sq_slice = str_slice;
str_cls->tp_as_sequence->sq_contains = (objobjproc)string_contains;
str_cls->tp_new = (newfunc)strNewPacked; str_cls->tp_new = (newfunc)strNewPacked;
basestring_cls->giveAttr("__doc__", basestring_cls->giveAttr("__doc__",
......
...@@ -89,7 +89,7 @@ except: ...@@ -89,7 +89,7 @@ except:
try: try:
test_helper.run_test(['sh', '-c', '. %s/bin/activate && python %s/numpy/tools/test-installed-numpy.py' % (ENV_DIR, ENV_DIR)], test_helper.run_test(['sh', '-c', '. %s/bin/activate && python %s/numpy/tools/test-installed-numpy.py' % (ENV_DIR, ENV_DIR)],
ENV_NAME, [dict(ran=5781, errors=2, failures=1)]) ENV_NAME, [dict(ran=5781, errors=1, failures=1)])
finally: finally:
if USE_CUSTOM_PATCHES: if USE_CUSTOM_PATCHES:
print_progress_header("Unpatching NumPy...") print_progress_header("Unpatching NumPy...")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment