Commit 62ee7639 authored by Marius Wachtler's avatar Marius Wachtler

Merge pull request #1203 from undingen/use_pynumber

binop: use PyNumber_* for user defined classes
parents 9ffdeadd b22043a4
......@@ -1846,77 +1846,73 @@ extern "C" int PyNumber_Check(PyObject* obj) noexcept {
return obj->cls->tp_as_number && (obj->cls->tp_as_number->nb_int || obj->cls->tp_as_number->nb_float);
}
extern "C" PyObject* PyNumber_Add(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::Add);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
extern "C" PyObject* PyNumber_Add(PyObject* v, PyObject* w) noexcept {
PyObject* result = binary_op1(v, w, NB_SLOT(nb_add));
if (result == Py_NotImplemented) {
PySequenceMethods* m = v->cls->tp_as_sequence;
Py_DECREF(result);
if (m && m->sq_concat) {
return (*m->sq_concat)(v, w);
}
result = binop_type_error(v, w, "+");
}
return result;
}
extern "C" PyObject* PyNumber_Subtract(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::Sub);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
#define BINARY_FUNC(func, op, op_name) \
extern "C" PyObject* func(PyObject* v, PyObject* w) noexcept { return binary_op(v, w, NB_SLOT(op), op_name); }
extern "C" PyObject* PyNumber_Multiply(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::Mult);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
BINARY_FUNC(PyNumber_Or, nb_or, "|")
BINARY_FUNC(PyNumber_Xor, nb_xor, "^")
BINARY_FUNC(PyNumber_And, nb_and, "&")
BINARY_FUNC(PyNumber_Lshift, nb_lshift, "<<")
BINARY_FUNC(PyNumber_Rshift, nb_rshift, ">>")
BINARY_FUNC(PyNumber_Subtract, nb_subtract, "-")
BINARY_FUNC(PyNumber_Divide, nb_divide, "/")
BINARY_FUNC(PyNumber_Divmod, nb_divmod, "divmod()")
extern "C" PyObject* PyNumber_Divide(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::Div);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
static PyObject* sequence_repeat(ssizeargfunc repeatfunc, PyObject* seq, PyObject* n) noexcept {
Py_ssize_t count;
if (PyIndex_Check(n)) {
count = PyNumber_AsSsize_t(n, PyExc_OverflowError);
if (count == -1 && PyErr_Occurred())
return NULL;
} else {
return type_error("can't multiply sequence by "
"non-int of type '%.200s'",
n);
}
return (*repeatfunc)(seq, count);
}
extern "C" PyObject* PyNumber_FloorDivide(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::FloorDiv);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
extern "C" PyObject* PyNumber_Multiply(PyObject* v, PyObject* w) noexcept {
PyObject* result = binary_op1(v, w, NB_SLOT(nb_multiply));
if (result == Py_NotImplemented) {
PySequenceMethods* mv = v->cls->tp_as_sequence;
PySequenceMethods* mw = w->cls->tp_as_sequence;
Py_DECREF(result);
if (mv && mv->sq_repeat) {
return sequence_repeat(mv->sq_repeat, v, w);
} else if (mw && mw->sq_repeat) {
return sequence_repeat(mw->sq_repeat, w, v);
}
result = binop_type_error(v, w, "*");
}
return result;
}
extern "C" PyObject* PyNumber_TrueDivide(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::TrueDiv);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
extern "C" PyObject* PyNumber_FloorDivide(PyObject* v, PyObject* w) noexcept {
/* XXX tp_flags test */
return binary_op(v, w, NB_SLOT(nb_floor_divide), "//");
}
extern "C" PyObject* PyNumber_Remainder(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::Mod);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
extern "C" PyObject* PyNumber_TrueDivide(PyObject* v, PyObject* w) noexcept {
/* XXX tp_flags test */
return binary_op(v, w, NB_SLOT(nb_true_divide), "/");
}
extern "C" PyObject* PyNumber_Divmod(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::DivMod);
} catch (ExcInfo e) {
e.clear();
fatalOrError(PyExc_NotImplementedError, "unimplemented");
return nullptr;
}
extern "C" PyObject* PyNumber_Remainder(PyObject* v, PyObject* w) noexcept {
return binary_op(v, w, NB_SLOT(nb_remainder), "%");
}
extern "C" PyObject* PyNumber_Power(PyObject* v, PyObject* w, PyObject* z) noexcept {
......@@ -1960,57 +1956,15 @@ extern "C" PyObject* PyNumber_Absolute(PyObject* o) noexcept {
}
extern "C" PyObject* PyNumber_Invert(PyObject* o) noexcept {
try {
return unaryop(o, AST_TYPE::Invert);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Lshift(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::LShift);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Rshift(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::RShift);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_And(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::BitAnd);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
PyNumberMethods* m;
extern "C" PyObject* PyNumber_Xor(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::BitXor);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
if (o == NULL)
return null_error();
m = o->cls->tp_as_number;
if (m && m->nb_invert)
return (*m->nb_invert)(o);
extern "C" PyObject* PyNumber_Or(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::BitOr);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
return type_error("bad operand type for unary ~: '%.200s'", o);
}
extern "C" PyObject* PyNumber_InPlaceAdd(PyObject* v, PyObject* w) noexcept {
......@@ -2036,20 +1990,6 @@ extern "C" PyObject* PyNumber_InPlaceSubtract(PyObject* v, PyObject* w) noexcept
return binary_iop(v, w, NB_SLOT(nb_inplace_subtract), NB_SLOT(nb_subtract), "-=");
}
static PyObject* sequence_repeat(ssizeargfunc repeatfunc, PyObject* seq, PyObject* n) {
Py_ssize_t count;
if (PyIndex_Check(n)) {
count = PyNumber_AsSsize_t(n, PyExc_OverflowError);
if (count == -1 && PyErr_Occurred())
return NULL;
} else {
return type_error("can't multiply sequence by "
"non-int of type '%.200s'",
n);
}
return (*repeatfunc)(seq, count);
}
extern "C" PyObject* PyNumber_InPlaceMultiply(PyObject* v, PyObject* w) noexcept {
PyObject* result = binary_iop1(v, w, NB_SLOT(nb_inplace_multiply), NB_SLOT(nb_multiply));
if (result == Py_NotImplemented) {
......
......@@ -1404,7 +1404,7 @@ Box* ellipsisRepr(Box* self) {
return boxString("Ellipsis");
}
Box* divmod(Box* lhs, Box* rhs) {
return binopInternal<NOT_REWRITABLE>(lhs, rhs, AST_TYPE::DivMod, false, NULL);
return binopInternal<NOT_REWRITABLE, false>(lhs, rhs, AST_TYPE::DivMod, NULL);
}
Box* powFunc(Box* x, Box* y, Box* z) {
......
......@@ -5464,13 +5464,84 @@ static Box* binopInternalHelper(BinopRewriteArgs*& rewrite_args, BoxedString* op
return rtn;
}
template <Rewritable rewritable>
Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args) {
template <Rewritable rewritable, bool inplace>
Box* binopInternal(Box* lhs, Box* rhs, int op_type, BinopRewriteArgs* rewrite_args) {
if (rewritable == NOT_REWRITABLE) {
assert(!rewrite_args);
rewrite_args = NULL;
}
// Currently can't patchpoint user-defined binops since we can't assume that just because
// resolving it one way right now (ex, using the value from lhs.__add__) means that later
// we'll resolve it the same way, even for the same argument types.
// TODO implement full resolving semantics inside the rewrite?
bool can_patchpoint = !lhs->cls->is_user_defined && !rhs->cls->is_user_defined;
if (!can_patchpoint) {
PyObject* (*func)(PyObject*, PyObject*) = NULL;
switch (op_type) {
case AST_TYPE::Add:
func = inplace ? PyNumber_InPlaceAdd : PyNumber_Add;
break;
case AST_TYPE::BitOr:
func = inplace ? PyNumber_InPlaceOr : PyNumber_Or;
break;
case AST_TYPE::BitXor:
func = inplace ? PyNumber_InPlaceXor : PyNumber_Xor;
break;
case AST_TYPE::BitAnd:
func = inplace ? PyNumber_InPlaceAnd : PyNumber_And;
break;
case AST_TYPE::LShift:
func = inplace ? PyNumber_InPlaceLshift : PyNumber_Lshift;
break;
case AST_TYPE::RShift:
func = inplace ? PyNumber_InPlaceRshift : PyNumber_Rshift;
break;
case AST_TYPE::Sub:
func = inplace ? PyNumber_InPlaceSubtract : PyNumber_Subtract;
break;
case AST_TYPE::Div:
func = inplace ? PyNumber_InPlaceDivide : PyNumber_Divide;
break;
case AST_TYPE::Mod:
func = inplace ? PyNumber_InPlaceRemainder : PyNumber_Remainder;
break;
case AST_TYPE::Mult:
func = inplace ? PyNumber_InPlaceMultiply : PyNumber_Multiply;
break;
case AST_TYPE::FloorDiv:
func = inplace ? PyNumber_InPlaceFloorDivide : PyNumber_FloorDivide;
break;
case AST_TYPE::TrueDiv:
func = inplace ? PyNumber_InPlaceTrueDivide : PyNumber_TrueDivide;
break;
case AST_TYPE::DivMod:
func = inplace ? NULL : PyNumber_Divmod;
break;
};
if (func) {
if (rewrite_args) {
rewrite_args->lhs->addAttrGuard(offsetof(Box, cls), (intptr_t)lhs->cls);
rewrite_args->rhs->addAttrGuard(offsetof(Box, cls), (intptr_t)rhs->cls);
RewriterVar* r_ret = rewrite_args->rewriter->call(true, (void*)func, rewrite_args->lhs,
rewrite_args->rhs)->setType(RefType::OWNED);
rewrite_args->rewriter->checkAndThrowCAPIException(r_ret);
rewrite_args->out_rtn = r_ret;
rewrite_args->out_success = true;
}
Box* rtn = func(lhs, rhs);
if (!rtn)
throwCAPIException();
return rtn;
}
}
if (!can_patchpoint)
rewrite_args = NULL;
RewriterVar* r_lhs = NULL;
RewriterVar* r_rhs = NULL;
if (rewrite_args) {
......@@ -5536,12 +5607,13 @@ Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteAr
raiseExcHelper(TypeError, "unsupported operand type(s) for %s%s: '%s' and '%s'", op_sym.data(), op_sym_suffix,
getTypeName(lhs), getTypeName(rhs));
}
template Box* binopInternal<REWRITABLE>(Box*, Box*, int, bool, BinopRewriteArgs*);
template Box* binopInternal<NOT_REWRITABLE>(Box*, Box*, int, bool, BinopRewriteArgs*);
template Box* binopInternal<REWRITABLE, true>(Box*, Box*, int, BinopRewriteArgs*);
template Box* binopInternal<REWRITABLE, false>(Box*, Box*, int, BinopRewriteArgs*);
template Box* binopInternal<NOT_REWRITABLE, true>(Box*, Box*, int, BinopRewriteArgs*);
template Box* binopInternal<NOT_REWRITABLE, false>(Box*, Box*, int, BinopRewriteArgs*);
extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) {
STAT_TIMER(t0, "us_timer_slowpath_binop", 10);
bool can_patchpoint = !lhs->cls->is_user_defined && !rhs->cls->is_user_defined;
#if 0
static uint64_t* st_id = Stats::getStatCounter("us_timer_slowpath_binop_patchable");
static uint64_t* st_id_nopatch = Stats::getStatCounter("us_timer_slowpath_binop_nopatch");
......@@ -5556,14 +5628,8 @@ extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) {
// int id = Stats::getStatId("slowpath_binop_" + *getTypeName(lhs) + op_name + *getTypeName(rhs));
// Stats::log(id);
std::unique_ptr<Rewriter> rewriter((Rewriter*)NULL);
// Currently can't patchpoint user-defined binops since we can't assume that just because
// resolving it one way right now (ex, using the value from lhs.__add__) means that later
// we'll resolve it the same way, even for the same argument types.
// TODO implement full resolving semantics inside the rewrite?
if (can_patchpoint)
rewriter.reset(
Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
std::unique_ptr<Rewriter> rewriter(
Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
Box* rtn;
if (rewriter.get()) {
......@@ -5571,7 +5637,7 @@ extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) {
BinopRewriteArgs rewrite_args(rewriter.get(), rewriter->getArg(0)->setType(RefType::BORROWED),
rewriter->getArg(1)->setType(RefType::BORROWED),
rewriter->getReturnDestination());
rtn = binopInternal<REWRITABLE>(lhs, rhs, op_type, false, &rewrite_args);
rtn = binopInternal<REWRITABLE, false /*not inplace*/>(lhs, rhs, op_type, &rewrite_args);
assert(rtn);
if (!rewrite_args.out_success) {
rewriter.reset(NULL);
......@@ -5579,7 +5645,7 @@ extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) {
rewriter->commitReturning(rewrite_args.out_rtn);
}
} else {
rtn = binopInternal<NOT_REWRITABLE>(lhs, rhs, op_type, false, NULL);
rtn = binopInternal<NOT_REWRITABLE, false /*not inplace*/>(lhs, rhs, op_type, NULL);
}
return rtn;
......@@ -5595,28 +5661,21 @@ extern "C" Box* augbinop(Box* lhs, Box* rhs, int op_type) {
// int id = Stats::getStatId("slowpath_augbinop_" + *getTypeName(lhs) + op_name + *getTypeName(rhs));
// Stats::log(id);
std::unique_ptr<Rewriter> rewriter((Rewriter*)NULL);
// Currently can't patchpoint user-defined binops since we can't assume that just because
// resolving it one way right now (ex, using the value from lhs.__add__) means that later
// we'll resolve it the same way, even for the same argument types.
// TODO implement full resolving semantics inside the rewrite?
bool can_patchpoint = !lhs->cls->is_user_defined && !rhs->cls->is_user_defined;
if (can_patchpoint)
rewriter.reset(
Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
std::unique_ptr<Rewriter> rewriter(
Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
Box* rtn;
if (rewriter.get()) {
BinopRewriteArgs rewrite_args(rewriter.get(), rewriter->getArg(0), rewriter->getArg(1),
rewriter->getReturnDestination());
rtn = binopInternal<REWRITABLE>(lhs, rhs, op_type, true, &rewrite_args);
rtn = binopInternal<REWRITABLE, true /*inplace*/>(lhs, rhs, op_type, &rewrite_args);
if (!rewrite_args.out_success) {
rewriter.reset(NULL);
} else {
rewriter->commitReturning(rewrite_args.out_rtn);
}
} else {
rtn = binopInternal<NOT_REWRITABLE>(lhs, rhs, op_type, true, NULL);
rtn = binopInternal<NOT_REWRITABLE, true /*inplace*/>(lhs, rhs, op_type, NULL);
}
return rtn;
......
......@@ -115,8 +115,8 @@ template <Rewritable rewritable>
void setattrGeneric(Box* obj, BoxedString* attr, STOLEN(Box*) val, SetattrRewriteArgs* rewrite_args);
struct BinopRewriteArgs;
template <Rewritable rewritable>
Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args);
template <Rewritable rewritable, bool inplace>
Box* binopInternal(Box* lhs, Box* rhs, int op_type, BinopRewriteArgs* rewrite_args);
struct CallRewriteArgs;
template <ExceptionStyle S, Rewritable rewritable>
......
......@@ -33,7 +33,7 @@ def install_and_test_lxml():
subprocess.check_call([PYTHON_EXE, "setup.py", "build_ext", "-i", "--with-cython"], cwd=LXML_DIR)
expected = [{'ran': 1381, 'failures': 3, 'errors': 1}]
expected = [{'ran': 1381, 'failures': 3}]
run_test([PYTHON_EXE, "test.py"], cwd=LXML_DIR, expected=expected)
create_virtenv(ENV_NAME, None, force_create = True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment