Commit 51a0a261 authored by Boxiang Sun's avatar Boxiang Sun

rewrite int pow, add error check and calculate pow with mod in single

parent ea949d3d
...@@ -334,33 +334,17 @@ extern "C" i64 mod_i64_i64(i64 lhs, i64 rhs) { ...@@ -334,33 +334,17 @@ extern "C" i64 mod_i64_i64(i64 lhs, i64 rhs) {
return lhs % rhs; return lhs % rhs;
} }
extern "C" Box* pow_i64_i64(i64 lhs, i64 rhs) { extern "C" Box* pow_i64_i64(i64 lhs, i64 rhs, Box* mod) {
i64 orig_rhs = rhs; i64 orig_rhs = rhs;
i64 rtn = 1, curpow = lhs; i64 rtn = 1, curpow = lhs;
if (rhs < 0) if (rhs < 0)
// already checked, rhs is a integer,
// and mod will be None in this case.
return boxFloat(pow_float_float(lhs, rhs)); return boxFloat(pow_float_float(lhs, rhs));
if (rhs == 0) // let longPow do the checks.
return boxInt(1); return longPow(boxLong(lhs), boxLong(rhs), mod);
assert(rhs > 0);
while (true) {
if (rhs & 1) {
// TODO: could potentially avoid restarting the entire computation on overflow?
if (__builtin_smull_overflow(rtn, curpow, &rtn))
return longPow(boxLong(lhs), boxLong(orig_rhs));
}
rhs >>= 1;
if (!rhs)
break;
if (__builtin_smull_overflow(curpow, curpow, &curpow))
return longPow(boxLong(lhs), boxLong(orig_rhs));
}
return boxInt(rtn);
} }
extern "C" Box* mul_i64_i64(i64 lhs, i64 rhs) { extern "C" Box* mul_i64_i64(i64 lhs, i64 rhs) {
...@@ -663,36 +647,62 @@ extern "C" Box* intMul(BoxedInt* lhs, Box* rhs) { ...@@ -663,36 +647,62 @@ extern "C" Box* intMul(BoxedInt* lhs, Box* rhs) {
} }
} }
extern "C" Box* intPowInt(BoxedInt* lhs, BoxedInt* rhs) { static void _addFuncPow(const char* name, ConcreteCompilerType* rtn_type, void* float_func, void* int_func) {
std::vector<ConcreteCompilerType*> v_ifu{ BOXED_INT, BOXED_FLOAT, UNKNOWN };
std::vector<ConcreteCompilerType*> v_uuu{ UNKNOWN, UNKNOWN, UNKNOWN };
CLFunction* cl = createRTFunction(3, 1, false, false);
addRTFunction(cl, float_func, UNKNOWN, v_ifu);
addRTFunction(cl, int_func, UNKNOWN, v_uuu);
int_cls->giveAttr(name, new BoxedFunction(cl, { None }));
}
extern "C" Box* intPowLong(BoxedInt* lhs, BoxedLong* rhs, Box* mod) {
assert(isSubclass(lhs->cls, int_cls)); assert(isSubclass(lhs->cls, int_cls));
assert(isSubclass(rhs->cls, int_cls)); assert(isSubclass(rhs->cls, long_cls));
BoxedInt* rhs_int = static_cast<BoxedInt*>(rhs); BoxedLong* lhs_long = boxLong(lhs->n);
return pow_i64_i64(lhs->n, rhs_int->n); return longPow(lhs_long, rhs, mod);
} }
extern "C" Box* intPowFloat(BoxedInt* lhs, BoxedFloat* rhs) { extern "C" Box* intPowFloat(BoxedInt* lhs, BoxedFloat* rhs, Box* mod) {
assert(isSubclass(lhs->cls, int_cls)); assert(isSubclass(lhs->cls, int_cls));
assert(rhs->cls == float_cls); assert(rhs->cls == float_cls);
return boxFloat(pow(lhs->n, rhs->d));
if (mod != None) {
raiseExcHelper(TypeError, "pow() 3rd argument not allowed unless all arguments are integers");
}
return boxFloat(pow_float_float(lhs->n, rhs->d));
} }
extern "C" Box* intPow(BoxedInt* lhs, Box* rhs, Box* mod) { extern "C" Box* intPow(BoxedInt* lhs, Box* rhs, Box* mod) {
if (!isSubclass(lhs->cls, int_cls)) if (!isSubclass(lhs->cls, int_cls))
raiseExcHelper(TypeError, "descriptor '__pow__' requires a 'int' object but received a '%s'", getTypeName(lhs)); raiseExcHelper(TypeError, "descriptor '__pow__' requires a 'int' object but received a '%s'", getTypeName(lhs));
if (isSubclass(rhs->cls, int_cls)) { if (isSubclass(rhs->cls, long_cls))
return intPowLong(lhs, static_cast<BoxedLong*>(rhs), mod);
else if (isSubclass(rhs->cls, float_cls))
return intPowFloat(lhs, static_cast<BoxedFloat*>(rhs), mod);
else if (!isSubclass(rhs->cls, int_cls))
return NotImplemented;
BoxedInt* rhs_int = static_cast<BoxedInt*>(rhs); BoxedInt* rhs_int = static_cast<BoxedInt*>(rhs);
Box* rtn = intPowInt(lhs, rhs_int); BoxedInt* mod_int = static_cast<BoxedInt*>(mod);
if (mod == None)
return rtn; if (mod != None) {
return binop(rtn, mod, AST_TYPE::Mod); if (rhs_int->n < 0)
} else if (rhs->cls == float_cls) { raiseExcHelper(TypeError, "pow() 2nd argument "
RELEASE_ASSERT(mod == None, ""); "cannot be negative when 3rd argument specified");
BoxedFloat* rhs_float = static_cast<BoxedFloat*>(rhs); if (!isSubclass(mod->cls, int_cls)) {
return intPowFloat(lhs, rhs_float);
} else {
return NotImplemented; return NotImplemented;
} else if (mod_int->n == 0) {
raiseExcHelper(ValueError, "pow() 3rd argument cannot be 0");
}
} }
Box* rtn = pow_i64_i64(lhs->n, rhs_int->n, mod);
if (isSubclass(rtn->cls, long_cls))
return longInt(rtn);
return rtn;
} }
extern "C" Box* intRShiftInt(BoxedInt* lhs, BoxedInt* rhs) { extern "C" Box* intRShiftInt(BoxedInt* lhs, BoxedInt* rhs) {
...@@ -1096,9 +1106,7 @@ void setupInt() { ...@@ -1096,9 +1106,7 @@ void setupInt() {
_addFuncIntFloatUnknown("__truediv__", (void*)intTruedivInt, (void*)intTruedivFloat, (void*)intTruediv); _addFuncIntFloatUnknown("__truediv__", (void*)intTruedivInt, (void*)intTruedivFloat, (void*)intTruediv);
_addFuncIntFloatUnknown("__mul__", (void*)intMulInt, (void*)intMulFloat, (void*)intMul); _addFuncIntFloatUnknown("__mul__", (void*)intMulInt, (void*)intMulFloat, (void*)intMul);
_addFuncIntUnknown("__mod__", BOXED_INT, (void*)intModInt, (void*)intMod); _addFuncIntUnknown("__mod__", BOXED_INT, (void*)intModInt, (void*)intMod);
int_cls->giveAttr("__pow__", _addFuncPow("__pow__", BOXED_INT, (void*)intPowFloat, (void*)intPow);
new BoxedFunction(boxRTFunction((void*)intPow, UNKNOWN, 3, 1, false, false), { None }));
// Note: CPython implements int comparisons using tp_compare // Note: CPython implements int comparisons using tp_compare
int_cls->tp_richcompare = int_richcompare; int_cls->tp_richcompare = int_richcompare;
......
...@@ -32,7 +32,7 @@ extern "C" i64 mod_i64_i64(i64 lhs, i64 rhs); ...@@ -32,7 +32,7 @@ extern "C" i64 mod_i64_i64(i64 lhs, i64 rhs);
extern "C" Box* add_i64_i64(i64 lhs, i64 rhs); extern "C" Box* add_i64_i64(i64 lhs, i64 rhs);
extern "C" Box* sub_i64_i64(i64 lhs, i64 rhs); extern "C" Box* sub_i64_i64(i64 lhs, i64 rhs);
extern "C" Box* pow_i64_i64(i64 lhs, i64 rhs); extern "C" Box* pow_i64_i64(i64 lhs, i64 rhs, Box* mod = None);
extern "C" Box* mul_i64_i64(i64 lhs, i64 rhs); extern "C" Box* mul_i64_i64(i64 lhs, i64 rhs);
extern "C" i1 eq_i64_i64(i64 lhs, i64 rhs); extern "C" i1 eq_i64_i64(i64 lhs, i64 rhs);
extern "C" i1 ne_i64_i64(i64 lhs, i64 rhs); extern "C" i1 ne_i64_i64(i64 lhs, i64 rhs);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment