Commit 19a32e2b authored by Rich Prohaska's avatar Rich Prohaska Committed by Yoni Fogel

refs #5727 changes to upsert code to address code review comments

git-svn-id: file:///svn/mysql/tokudb-engine/tokudb-engine@51004 c7de825b-a66e-492c-adef-691d508d4ae1
parent 970fb5fb
......@@ -568,7 +568,7 @@ static int save_in_field(Item *item, TABLE *table) {
}
// Generate an update message for an update operation and send it into the primary tree. Return 0 if successful.
int ha_tokudb::send_update_message(List<Item> &fields, List<Item> &values, Item *conds, DB_TXN *txn) {
int ha_tokudb::send_update_message(List<Item> &update_fields, List<Item> &update_values, Item *conds, DB_TXN *txn) {
int error;
// Save the primary key from the where conditions
......@@ -597,15 +597,16 @@ int ha_tokudb::send_update_message(List<Item> &fields, List<Item> &values, Item
uchar operation = UPDATE_OP_SIMPLE_UPDATE;
update_message.append(&operation, sizeof operation);
uint32_t update_mode = 0;
update_message.append(&update_mode, sizeof update_mode);
// append the descriptor
marshall_simple_descriptor(update_message, table, share->kc_info, primary_key);
// append the updates
List_iterator<Item> lhs_i(fields);
List_iterator<Item> rhs_i(values);
uint32_t num_updates = update_fields.elements;
update_message.append(&num_updates, sizeof num_updates);
List_iterator<Item> lhs_i(update_fields);
List_iterator<Item> rhs_i(update_values);
while (error == 0) {
Item *lhs_item = lhs_i++;
if (lhs_item == NULL)
......@@ -718,8 +719,6 @@ int ha_tokudb::send_upsert_message(THD *thd, List<Item> &update_fields, List<Ite
// append the operation
uchar operation = UPDATE_OP_SIMPLE_UPSERT;
update_message.append(&operation, sizeof operation);
uint32_t update_mode = 0;
update_message.append(&update_mode, sizeof update_mode);
// append the row
uint32_t row_length = row.size;
......@@ -730,6 +729,9 @@ int ha_tokudb::send_upsert_message(THD *thd, List<Item> &update_fields, List<Ite
marshall_simple_descriptor(update_message, table, share->kc_info, primary_key);
// append the update expressions
uint32_t num_updates = update_fields.elements;
update_message.append(&num_updates, sizeof num_updates);
List_iterator<Item> lhs_i(update_fields);
List_iterator<Item> rhs_i(update_values);
while (1) {
......
......@@ -94,17 +94,17 @@ enum {
// pad char 1
// Simple row descriptor:
// fixed field offset 4
// var field offset 4
// var_offset_bytes 1
// num_var_fields 4
// fixed field offset 4 offset of the beginning of the fixed fields
// var field offset 4 offset of the variable length offsets
// var_offset_bytes 1 size of each variable length offset
// num_var_fields 4 number of variable length offsets
// Field descriptor:
// field type 4
// field num 4
// field null num 4
// field offset 4
// field length 4
// field type 4 see field types above
// field num 4 unused for fixed length fields
// field null num 4 bit 31 is 1 if the field is nullible and the remaining bits contain the null bit number
// field offset 4 for fixed fields, this is the offset from begining of the row of the field
// field length 4 for fixed fields, this is the length of the field
// Simple update operation:
// update operation 4 == { '=', '+', '-' }
......@@ -118,18 +118,18 @@ enum {
// Simple update message:
// Operation 1 == UPDATE_OP_UPDATE_FIELD
// Update mode 4
// Simple row descriptor
// Simple update ops []
// Number of update ops 4 == N
// Uupdate ops [N]
// Simple upsert message:
// Operation 1 == UPDATE_OP_UPSERT
// Update mode 4
// Insert row:
// length 4 == N
// data N
// Simple row descriptor
// Simple update ops []
// Number of update ops 4 == N
// Update ops [N]
#include "tokudb_buffer.h"
#include "tokudb_math.h"
......@@ -952,12 +952,11 @@ static void set_fixed_field(uint32_t the_offset, uint32_t length, uint32_t field
}
// Update an int field: signed newval@offset = old_val@offset OP extra_val
static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num,
static void int_op(uint32_t operation, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num,
tokudb::buffer &new_val, tokudb::buffer &old_val, void *extra_val) {
assert(the_offset + length <= new_val.size());
assert(the_offset + length <= old_val.size());
assert(length == 1 || length == 2 || length == 3 || length == 4 || length == 8);
assert(update_mode == 0);
uchar *old_val_ptr = (uchar *) old_val.data();
bool field_is_null = false;
......@@ -973,31 +972,27 @@ static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset
case '+':
if (!field_is_null) {
bool over;
v = tokudb::int_add(v, extra_v, 8*length, over);
v = tokudb::int_add(v, extra_v, 8*length, &over);
if (over) {
if (extra_v > 0)
v = tokudb::int_high_endpoint(8*length);
else
v = tokudb::int_low_endpoint(8*length);
over = false;
}
if (!over)
new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
}
break;
case '-':
if (!field_is_null) {
bool over;
v = tokudb::int_sub(v, extra_v, 8*length, over);
v = tokudb::int_sub(v, extra_v, 8*length, &over);
if (over) {
if (extra_v > 0)
v = tokudb::int_low_endpoint(8*length);
else
v = tokudb::int_high_endpoint(8*length);
over = false;
}
if (!over)
new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
}
break;
default:
......@@ -1006,12 +1001,11 @@ static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset
}
// Update an unsigned field: unsigned newval@offset = old_val@offset OP extra_val
static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num,
static void uint_op(uint32_t operation, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num,
tokudb::buffer &new_val, tokudb::buffer &old_val, void *extra_val) {
assert(the_offset + length <= new_val.size());
assert(the_offset + length <= old_val.size());
assert(length == 1 || length == 2 || length == 3 || length == 4 || length == 8);
assert(update_mode == 0);
uchar *old_val_ptr = (uchar *) old_val.data();
bool field_is_null = false;
......@@ -1025,25 +1019,21 @@ static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offse
case '+':
if (!field_is_null) {
bool over;
v = tokudb::uint_add(v, extra_v, 8*length, over);
v = tokudb::uint_add(v, extra_v, 8*length, &over);
if (over) {
v = tokudb::uint_high_endpoint(8*length);
over = false;
}
if (!over)
new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
}
break;
case '-':
if (!field_is_null) {
bool over;
v = tokudb::uint_sub(v, extra_v, 8*length, over);
v = tokudb::uint_sub(v, extra_v, 8*length, &over);
if (over) {
v = tokudb::uint_low_endpoint(8*length);
over = false;
}
if (!over)
new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
}
break;
default:
......@@ -1053,9 +1043,10 @@ static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offse
// Decode and apply a sequence of update operations defined in the extra to the old value and put the result
// in the new value.
static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, tokudb::buffer &extra_val,
uint32_t update_mode, Simple_row_descriptor &sd) {
while (extra_val.size() < extra_val.limit()) {
static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, tokudb::buffer &extra_val, const Simple_row_descriptor &sd) {
uint32_t num_updates;
extra_val.consume(&num_updates, sizeof num_updates);
for ( ; num_updates > 0; num_updates--) {
// get the update operation
uint32_t update_operation;
extra_val.consume(&update_operation, sizeof update_operation);
......@@ -1077,13 +1068,13 @@ static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, toku
if (update_operation == '=')
set_fixed_field(the_offset, length, field_num, field_null_num, new_val, extra_val_ptr);
else
int_op(update_operation, update_mode, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr);
int_op(update_operation, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr);
break;
case UPDATE_TYPE_UINT:
if (update_operation == '=')
set_fixed_field(the_offset, length, field_num, field_null_num, new_val, extra_val_ptr);
else
uint_op(update_operation, update_mode, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr);
uint_op(update_operation, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr);
break;
case UPDATE_TYPE_CHAR:
case UPDATE_TYPE_BINARY:
......@@ -1097,6 +1088,7 @@ static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, toku
break;
}
}
assert(extra_val.size() == extra_val.limit());
}
// Simple update handler. Decode the update message, apply the update operations to the old value, and set
......@@ -1116,9 +1108,6 @@ static int tokudb_simple_update_fun(
extra_val.consume(&operation, sizeof operation);
assert(operation == UPDATE_OP_SIMPLE_UPDATE);
uint32_t update_mode;
extra_val.consume(&update_mode, sizeof update_mode);
if (old_val_dbt != NULL) {
// get the simple descriptor
Simple_row_descriptor sd;
......@@ -1131,7 +1120,7 @@ static int tokudb_simple_update_fun(
new_val.append(old_val_dbt->data, old_val_dbt->size);
// apply updates to new val
apply_updates(new_val, old_val, extra_val, update_mode, sd);
apply_updates(new_val, old_val, extra_val, sd);
// set the new val
DBT new_val_dbt; memset(&new_val_dbt, 0, sizeof new_val_dbt);
......@@ -1160,9 +1149,6 @@ static int tokudb_simple_upsert_fun(
extra_val.consume(&operation, sizeof operation);
assert(operation == UPDATE_OP_SIMPLE_UPSERT);
uint32_t update_mode;
extra_val.consume(&update_mode, sizeof update_mode);
uint32_t insert_length;
extra_val.consume(&insert_length, sizeof insert_length);
void *insert_row = extra_val.consume_ptr(insert_length);
......@@ -1185,7 +1171,7 @@ static int tokudb_simple_upsert_fun(
new_val.append(old_val_dbt->data, old_val_dbt->size);
// apply updates to new val
apply_updates(new_val, old_val, extra_val, update_mode, sd);
apply_updates(new_val, old_val, extra_val, sd);
// set the new val
DBT new_val_dbt; memset(&new_val_dbt, 0, sizeof new_val_dbt);
......
......@@ -11,7 +11,7 @@ static void test(int length_bits) {
for (int64_t x = -max-1; x <= max; x++) {
for (int64_t y = -max-1; y <= max; y++) {
bool over;
int64_t n = int_add(x, y, length_bits, over);
int64_t n = int_add(x, y, length_bits, &over);
printf("%lld %lld %lld %u\n", x, y, n, over);
}
}
......
......@@ -22,13 +22,13 @@ static void test_uint8() {
uint64_t m;
for (uint64_t x = 0; x <= (1ULL<<8)-1; x++) {
for (uint64_t y = 0; y <= (1ULL<<8)-1; y++) {
n = uint_add(x, y, 8, over);
n = uint_add(x, y, 8, &over);
m = x + y;
if (m > (1ULL<<8)-1)
assert(over);
else
assert(!over && n == (m % 256));
n = uint_sub(x, y, 8, over);
n = uint_sub(x, y, 8, &over);
m = x - y;
if (m > x)
assert(over);
......@@ -46,13 +46,13 @@ static void test_uint16() {
uint64_t m;
for (uint64_t x = 0; x <= (1ULL<<16)-1; x++) {
for (uint64_t y = 0; y <= (1ULL<<16)-1; y++) {
n = uint_add(x, y, 16, over);
n = uint_add(x, y, 16, &over);
m = x + y;
if (m > (1ULL<<16)-1)
assert(over);
else
assert(!over && n == (m % (1ULL<<16)));
n = uint_sub(x, y, 16, over);
n = uint_sub(x, y, 16, &over);
m = x - y;
if (m > x)
assert(over);
......@@ -68,15 +68,15 @@ static void test_uint24() {
bool over;
uint64_t s;
s = uint_add((1ULL<<24)-1, (1ULL<<24)-1, 24, over); assert(over);
s = uint_add((1ULL<<24)-1, 1, 24, over); assert(over);
s = uint_add((1ULL<<24)-1, 0, 24, over); assert(!over && s == (1ULL<<24)-1);
s = uint_add(0, 1, 24, over); assert(!over && s == 1);
s = uint_add(0, 0, 24, over); assert(!over && s == 0);
s = uint_sub(0, 0, 24, over); assert(!over && s == 0);
s = uint_sub(0, 1, 24, over); assert(over);
s = uint_sub(0, (1ULL<<24)-1, 24, over); assert(over);
s = uint_sub((1ULL<<24)-1, (1ULL<<24)-1, 24, over); assert(!over && s == 0);
s = uint_add((1ULL<<24)-1, (1ULL<<24)-1, 24, &over); assert(over);
s = uint_add((1ULL<<24)-1, 1, 24, &over); assert(over);
s = uint_add((1ULL<<24)-1, 0, 24, &over); assert(!over && s == (1ULL<<24)-1);
s = uint_add(0, 1, 24, &over); assert(!over && s == 1);
s = uint_add(0, 0, 24, &over); assert(!over && s == 0);
s = uint_sub(0, 0, 24, &over); assert(!over && s == 0);
s = uint_sub(0, 1, 24, &over); assert(over);
s = uint_sub(0, (1ULL<<24)-1, 24, &over); assert(over);
s = uint_sub((1ULL<<24)-1, (1ULL<<24)-1, 24, &over); assert(!over && s == 0);
}
static void test_uint32() {
......@@ -85,15 +85,15 @@ static void test_uint32() {
bool over;
uint64_t s;
s = uint_add((1ULL<<32)-1, (1ULL<<32)-1, 32, over); assert(over);
s = uint_add((1ULL<<32)-1, 1, 32, over); assert(over);
s = uint_add((1ULL<<32)-1, 0, 32, over); assert(!over && s == (1ULL<<32)-1);
s = uint_add(0, 1, 32, over); assert(!over && s == 1);
s = uint_add(0, 0, 32, over); assert(!over && s == 0);
s = uint_sub(0, 0, 32, over); assert(!over && s == 0);
s = uint_sub(0, 1, 32, over); assert(over);
s = uint_sub(0, (1ULL<<32)-1, 32, over); assert(over);
s = uint_sub((1ULL<<32)-1, (1ULL<<32)-1, 32, over); assert(!over && s == 0);
s = uint_add((1ULL<<32)-1, (1ULL<<32)-1, 32, &over); assert(over);
s = uint_add((1ULL<<32)-1, 1, 32, &over); assert(over);
s = uint_add((1ULL<<32)-1, 0, 32, &over); assert(!over && s == (1ULL<<32)-1);
s = uint_add(0, 1, 32, &over); assert(!over && s == 1);
s = uint_add(0, 0, 32, &over); assert(!over && s == 0);
s = uint_sub(0, 0, 32, &over); assert(!over && s == 0);
s = uint_sub(0, 1, 32, &over); assert(over);
s = uint_sub(0, (1ULL<<32)-1, 32, &over); assert(over);
s = uint_sub((1ULL<<32)-1, (1ULL<<32)-1, 32, &over); assert(!over && s == 0);
}
static void test_uint64() {
......@@ -102,15 +102,15 @@ static void test_uint64() {
bool over;
uint64_t s;
s = uint_add(~0ULL, ~0ULL, 64, over); assert(over);
s = uint_add(~0ULL, 1, 64, over); assert(over);
s = uint_add(~0ULL, 0, 64, over); assert(!over && s == ~0ULL);
s = uint_add(0, 1, 64, over); assert(!over && s == 1);
s = uint_add(0, 0, 64, over); assert(!over && s == 0);
s = uint_sub(0, 0, 64, over); assert(!over && s == 0);
s = uint_sub(0, 1, 64, over); assert(over);
s = uint_sub(0, ~0ULL, 64, over); assert(over);
s = uint_sub(~0ULL, ~0ULL, 64, over); assert(!over && s == 0);
s = uint_add(~0ULL, ~0ULL, 64, &over); assert(over);
s = uint_add(~0ULL, 1, 64, &over); assert(over);
s = uint_add(~0ULL, 0, 64, &over); assert(!over && s == ~0ULL);
s = uint_add(0, 1, 64, &over); assert(!over && s == 1);
s = uint_add(0, 0, 64, &over); assert(!over && s == 0);
s = uint_sub(0, 0, 64, &over); assert(!over && s == 0);
s = uint_sub(0, 1, 64, &over); assert(over);
s = uint_sub(0, ~0ULL, 64, &over); assert(over);
s = uint_sub(~0ULL, ~0ULL, 64, &over); assert(!over && s == 0);
}
static int64_t sign_extend(uint length_bits, int64_t n) {
......@@ -130,7 +130,7 @@ static void test_int8() {
for (int64_t y = -max; y <= max-1; y++) {
bool over;
int64_t n, m;
n = int_add(x, y, 8, over);
n = int_add(x, y, 8, &over);
m = x + y;
if (m > max-1)
assert(over);
......@@ -138,7 +138,7 @@ static void test_int8() {
assert(over);
else
assert(!over && n == m);
n = int_sub(x, y, 8, over);
n = int_sub(x, y, 8, &over);
m = x - y;
if (m > max-1)
assert(over);
......@@ -158,7 +158,7 @@ static void test_int16() {
for (int64_t y = -max; y <= max-1; y++) {
bool over;
int64_t n, m;
n = int_add(x, y, 16, over);
n = int_add(x, y, 16, &over);
m = x + y;
if (m > max-1)
assert(over);
......@@ -166,7 +166,7 @@ static void test_int16() {
assert(over);
else
assert(!over && n == m);
n = int_sub(x, y, 16, over);
n = int_sub(x, y, 16, &over);
m = x - y;
if (m > max-1)
assert(over);
......@@ -184,22 +184,22 @@ static void test_int24() {
int64_t s;
bool over;
s = int_add(1, (1ULL<<23)-1, 24, over); assert(over);
s = int_add((1ULL<<23)-1, 1, 24, over); assert(over);
s = int_sub(-1, (1ULL<<23), 24, over); assert(!over && s == (1ULL<<23)-1);
s = int_sub((1ULL<<23), 1, 24, over); assert(over);
s = int_add(1, (1ULL<<23)-1, 24, &over); assert(over);
s = int_add((1ULL<<23)-1, 1, 24, &over); assert(over);
s = int_sub(-1, (1ULL<<23), 24, &over); assert(!over && s == (1ULL<<23)-1);
s = int_sub((1ULL<<23), 1, 24, &over); assert(over);
s = int_add(0, 0, 24, over); assert(!over && s == 0);
s = int_sub(0, 0, 24, over); assert(!over && s == 0);
s = int_add(0, -1, 24, over); assert(!over && s == -1);
s = int_sub(0, 1, 24, over); assert(!over && s == -1);
s = int_add(0, (1ULL<<23), 24, over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23));
s = int_sub(0, (1ULL<<23)-1, 24, over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23)+1);
s = int_add(0, 0, 24, &over); assert(!over && s == 0);
s = int_sub(0, 0, 24, &over); assert(!over && s == 0);
s = int_add(0, -1, 24, &over); assert(!over && s == -1);
s = int_sub(0, 1, 24, &over); assert(!over && s == -1);
s = int_add(0, (1ULL<<23), 24, &over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23));
s = int_sub(0, (1ULL<<23)-1, 24, &over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23)+1);
s = int_add(-1, 0, 24, over); assert(!over && s == -1);
s = int_add(-1, 1, 24, over); assert(!over && s == 0);
s = int_sub(-1, -1, 24, over); assert(!over && s == 0);
s = int_sub(-1, (1ULL<<23)-1, 24, over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23));
s = int_add(-1, 0, 24, &over); assert(!over && s == -1);
s = int_add(-1, 1, 24, &over); assert(!over && s == 0);
s = int_sub(-1, -1, 24, &over); assert(!over && s == 0);
s = int_sub(-1, (1ULL<<23)-1, 24, &over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23));
}
static void test_int32() {
......@@ -208,22 +208,22 @@ static void test_int32() {
int64_t s;
bool over;
s = int_add(1, (1ULL<<31)-1, 32, over); assert(over);
s = int_add((1ULL<<31)-1, 1, 32, over); assert(over);
s = int_sub(-1, (1ULL<<31), 32, over); assert(s == (1ULL<<31)-1 && !over);
s = int_sub((1ULL<<31), 1, 32, over); assert(over);
s = int_add(1, (1ULL<<31)-1, 32, &over); assert(over);
s = int_add((1ULL<<31)-1, 1, 32, &over); assert(over);
s = int_sub(-1, (1ULL<<31), 32, &over); assert(s == (1ULL<<31)-1 && !over);
s = int_sub((1ULL<<31), 1, 32, &over); assert(over);
s = int_add(0, 0, 32, over); assert(s == 0 && !over);
s = int_sub(0, 0, 32, over); assert(s == 0 && !over);
s = int_add(0, -1, 32, over); assert(s == -1 && !over);
s = int_sub(0, 1, 32, over); assert(s == -1 && !over);
s = int_add(0, (1ULL<<31), 32, over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over);
s = int_sub(0, (1ULL<<31)-1, 32, over); assert((s & (1ULL<<32)-1) == (1ULL<<31)+1 && !over);
s = int_add(0, 0, 32, &over); assert(s == 0 && !over);
s = int_sub(0, 0, 32, &over); assert(s == 0 && !over);
s = int_add(0, -1, 32, &over); assert(s == -1 && !over);
s = int_sub(0, 1, 32, &over); assert(s == -1 && !over);
s = int_add(0, (1ULL<<31), 32, &over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over);
s = int_sub(0, (1ULL<<31)-1, 32, &over); assert((s & (1ULL<<32)-1) == (1ULL<<31)+1 && !over);
s = int_add(-1, 0, 32, over); assert(s == -1 && !over);
s = int_add(-1, 1, 32, over); assert(s == 0 && !over);
s = int_sub(-1, -1, 32, over); assert(s == 0 && !over);
s = int_sub(-1, (1ULL<<31)-1, 32, over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over);
s = int_add(-1, 0, 32, &over); assert(s == -1 && !over);
s = int_add(-1, 1, 32, &over); assert(s == 0 && !over);
s = int_sub(-1, -1, 32, &over); assert(s == 0 && !over);
s = int_sub(-1, (1ULL<<31)-1, 32, &over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over);
}
static void test_int64() {
......@@ -232,22 +232,22 @@ static void test_int64() {
int64_t s;
bool over;
s = int_add(1, (1ULL<<63)-1, 64, over); assert(over);
s = int_add((1ULL<<63)-1, 1, 64, over); assert(over);
s = int_sub(-1, (1ULL<<63), 64, over); assert(s == (1ULL<<63)-1 && !over);
s = int_sub((1ULL<<63), 1, 64, over); assert(over);
s = int_add(1, (1ULL<<63)-1, 64, &over); assert(over);
s = int_add((1ULL<<63)-1, 1, 64, &over); assert(over);
s = int_sub(-1, (1ULL<<63), 64, &over); assert(s == (1ULL<<63)-1 && !over);
s = int_sub((1ULL<<63), 1, 64, &over); assert(over);
s = int_add(0, 0, 64, over); assert(s == 0 && !over);
s = int_sub(0, 0, 64, over); assert(s == 0 && !over);
s = int_add(0, -1, 64, over); assert(s == -1 && !over);
s = int_sub(0, 1, 64, over); assert(s == -1 && !over);
s = int_add(0, (1ULL<<63), 64, over); assert(s == (1ULL<<63) && !over);
s = int_sub(0, (1ULL<<63)-1, 64, over); assert(s == (1ULL<<63)+1 && !over);
s = int_add(0, 0, 64, &over); assert(s == 0 && !over);
s = int_sub(0, 0, 64, &over); assert(s == 0 && !over);
s = int_add(0, -1, 64, &over); assert(s == -1 && !over);
s = int_sub(0, 1, 64, &over); assert(s == -1 && !over);
s = int_add(0, (1ULL<<63), 64, &over); assert(s == (1ULL<<63) && !over);
s = int_sub(0, (1ULL<<63)-1, 64, &over); assert(s == (1ULL<<63)+1 && !over);
s = int_add(-1, 0, 64, over); assert(s == -1 && !over);
s = int_add(-1, 1, 64, over); assert(s == 0 && !over);
s = int_sub(-1, -1, 64, over); assert(s == 0 && !over);
s = int_sub(-1, (1ULL<<63)-1, 64, over); assert(s == (1ULL<<63) && !over);
s = int_add(-1, 0, 64, &over); assert(s == -1 && !over);
s = int_add(-1, 1, 64, &over); assert(s == 0 && !over);
s = int_sub(-1, -1, 64, &over); assert(s == 0 && !over);
s = int_sub(-1, (1ULL<<63)-1, 64, &over); assert(s == (1ULL<<63) && !over);
}
static void test_int_sign(uint length_bits) {
......
......@@ -12,7 +12,7 @@ static void test(int length_bits) {
for (uint64_t x = 0; x <= max; x++) {
for (uint64_t y = 0; y <= max; y++) {
bool over;
uint64_t n = uint_add(x, y, max, over);
uint64_t n = uint_add(x, y, max, &over);
printf("%llu %llu %llu\n", x, y, n);
}
}
......
......@@ -27,24 +27,24 @@ static uint64_t uint_low_endpoint(uint length_bits) {
// Add two unsigned integers with max maximum value.
// If there is an overflow then set the sum to the max.
// Return the sum and the overflow.
static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool &over) __attribute__((unused));
static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool &over) {
static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool *over) __attribute__((unused));
static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool *over) {
uint64_t mask = uint_mask(length_bits);
assert((x & ~mask) == 0 && (y & ~mask) == 0);
uint64_t s = (x + y) & mask;
over = s < x; // check for overflow
*over = s < x; // check for overflow
return s;
}
// Subtract two unsigned ints with max maximum value.
// If there is an over then set the difference to 0.
// Return the difference and the overflow.
static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool &over) __attribute__((unused));
static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool &over) {
static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool *over) __attribute__((unused));
static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool *over) {
uint64_t mask = uint_mask(length_bits);
assert((x & ~mask) == 0 && (y & ~mask) == 0);
uint64_t s = (x - y) & mask;
over = s > x; // check for overflow
*over = s > x; // check for overflow
return s;
}
......@@ -74,11 +74,11 @@ static int64_t int_sign_extend(int64_t n, uint length_bits) {
// depending on the sign bit.
// Sign extend to 64 bits.
// Return the sum and the overflow.
static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) __attribute__((unused));
static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) {
static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool *over) __attribute__((unused));
static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool *over) {
int64_t mask = uint_mask(length_bits);
int64_t n = (x + y) & mask;
over = (((n ^ x) & (n ^ y)) >> (length_bits-1)) & 1; // check for overflow
*over = (((n ^ x) & (n ^ y)) >> (length_bits-1)) & 1; // check for overflow
if (n & (1LL<<(length_bits-1)))
n |= ~mask; // sign extend
return n;
......@@ -89,11 +89,11 @@ static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) {
// depending on the sign bit.
// Sign extend to 64 bits.
// Return the sum and the overflow.
static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool &over) __attribute__((unused));
static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool &over) {
static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool *over) __attribute__((unused));
static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool *over) {
int64_t mask = uint_mask(length_bits);
int64_t n = (x - y) & mask;
over = (((x ^ y) & (n ^ x)) >> (length_bits-1)) & 1; // check for overflow
*over = (((x ^ y) & (n ^ x)) >> (length_bits-1)) & 1; // check for overflow
if (n & (1LL<<(length_bits-1)))
n |= ~mask; // sign extend
return n;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment