Commit 19a32e2b authored by Rich Prohaska's avatar Rich Prohaska Committed by Yoni Fogel

refs #5727 changes to upsert code to address code review comments

git-svn-id: file:///svn/mysql/tokudb-engine/tokudb-engine@51004 c7de825b-a66e-492c-adef-691d508d4ae1
parent 970fb5fb
...@@ -568,7 +568,7 @@ static int save_in_field(Item *item, TABLE *table) { ...@@ -568,7 +568,7 @@ static int save_in_field(Item *item, TABLE *table) {
} }
// Generate an update message for an update operation and send it into the primary tree. Return 0 if successful. // Generate an update message for an update operation and send it into the primary tree. Return 0 if successful.
int ha_tokudb::send_update_message(List<Item> &fields, List<Item> &values, Item *conds, DB_TXN *txn) { int ha_tokudb::send_update_message(List<Item> &update_fields, List<Item> &update_values, Item *conds, DB_TXN *txn) {
int error; int error;
// Save the primary key from the where conditions // Save the primary key from the where conditions
...@@ -597,15 +597,16 @@ int ha_tokudb::send_update_message(List<Item> &fields, List<Item> &values, Item ...@@ -597,15 +597,16 @@ int ha_tokudb::send_update_message(List<Item> &fields, List<Item> &values, Item
uchar operation = UPDATE_OP_SIMPLE_UPDATE; uchar operation = UPDATE_OP_SIMPLE_UPDATE;
update_message.append(&operation, sizeof operation); update_message.append(&operation, sizeof operation);
uint32_t update_mode = 0;
update_message.append(&update_mode, sizeof update_mode);
// append the descriptor // append the descriptor
marshall_simple_descriptor(update_message, table, share->kc_info, primary_key); marshall_simple_descriptor(update_message, table, share->kc_info, primary_key);
// append the updates // append the updates
List_iterator<Item> lhs_i(fields); uint32_t num_updates = update_fields.elements;
List_iterator<Item> rhs_i(values); update_message.append(&num_updates, sizeof num_updates);
List_iterator<Item> lhs_i(update_fields);
List_iterator<Item> rhs_i(update_values);
while (error == 0) { while (error == 0) {
Item *lhs_item = lhs_i++; Item *lhs_item = lhs_i++;
if (lhs_item == NULL) if (lhs_item == NULL)
...@@ -718,8 +719,6 @@ int ha_tokudb::send_upsert_message(THD *thd, List<Item> &update_fields, List<Ite ...@@ -718,8 +719,6 @@ int ha_tokudb::send_upsert_message(THD *thd, List<Item> &update_fields, List<Ite
// append the operation // append the operation
uchar operation = UPDATE_OP_SIMPLE_UPSERT; uchar operation = UPDATE_OP_SIMPLE_UPSERT;
update_message.append(&operation, sizeof operation); update_message.append(&operation, sizeof operation);
uint32_t update_mode = 0;
update_message.append(&update_mode, sizeof update_mode);
// append the row // append the row
uint32_t row_length = row.size; uint32_t row_length = row.size;
...@@ -730,6 +729,9 @@ int ha_tokudb::send_upsert_message(THD *thd, List<Item> &update_fields, List<Ite ...@@ -730,6 +729,9 @@ int ha_tokudb::send_upsert_message(THD *thd, List<Item> &update_fields, List<Ite
marshall_simple_descriptor(update_message, table, share->kc_info, primary_key); marshall_simple_descriptor(update_message, table, share->kc_info, primary_key);
// append the update expressions // append the update expressions
uint32_t num_updates = update_fields.elements;
update_message.append(&num_updates, sizeof num_updates);
List_iterator<Item> lhs_i(update_fields); List_iterator<Item> lhs_i(update_fields);
List_iterator<Item> rhs_i(update_values); List_iterator<Item> rhs_i(update_values);
while (1) { while (1) {
......
...@@ -94,17 +94,17 @@ enum { ...@@ -94,17 +94,17 @@ enum {
// pad char 1 // pad char 1
// Simple row descriptor: // Simple row descriptor:
// fixed field offset 4 // fixed field offset 4 offset of the beginning of the fixed fields
// var field offset 4 // var field offset 4 offset of the variable length offsets
// var_offset_bytes 1 // var_offset_bytes 1 size of each variable length offset
// num_var_fields 4 // num_var_fields 4 number of variable length offsets
// Field descriptor: // Field descriptor:
// field type 4 // field type 4 see field types above
// field num 4 // field num 4 unused for fixed length fields
// field null num 4 // field null num 4 bit 31 is 1 if the field is nullible and the remaining bits contain the null bit number
// field offset 4 // field offset 4 for fixed fields, this is the offset from begining of the row of the field
// field length 4 // field length 4 for fixed fields, this is the length of the field
// Simple update operation: // Simple update operation:
// update operation 4 == { '=', '+', '-' } // update operation 4 == { '=', '+', '-' }
...@@ -118,18 +118,18 @@ enum { ...@@ -118,18 +118,18 @@ enum {
// Simple update message: // Simple update message:
// Operation 1 == UPDATE_OP_UPDATE_FIELD // Operation 1 == UPDATE_OP_UPDATE_FIELD
// Update mode 4
// Simple row descriptor // Simple row descriptor
// Simple update ops [] // Number of update ops 4 == N
// Uupdate ops [N]
// Simple upsert message: // Simple upsert message:
// Operation 1 == UPDATE_OP_UPSERT // Operation 1 == UPDATE_OP_UPSERT
// Update mode 4
// Insert row: // Insert row:
// length 4 == N // length 4 == N
// data N // data N
// Simple row descriptor // Simple row descriptor
// Simple update ops [] // Number of update ops 4 == N
// Update ops [N]
#include "tokudb_buffer.h" #include "tokudb_buffer.h"
#include "tokudb_math.h" #include "tokudb_math.h"
...@@ -952,12 +952,11 @@ static void set_fixed_field(uint32_t the_offset, uint32_t length, uint32_t field ...@@ -952,12 +952,11 @@ static void set_fixed_field(uint32_t the_offset, uint32_t length, uint32_t field
} }
// Update an int field: signed newval@offset = old_val@offset OP extra_val // Update an int field: signed newval@offset = old_val@offset OP extra_val
static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num, static void int_op(uint32_t operation, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num,
tokudb::buffer &new_val, tokudb::buffer &old_val, void *extra_val) { tokudb::buffer &new_val, tokudb::buffer &old_val, void *extra_val) {
assert(the_offset + length <= new_val.size()); assert(the_offset + length <= new_val.size());
assert(the_offset + length <= old_val.size()); assert(the_offset + length <= old_val.size());
assert(length == 1 || length == 2 || length == 3 || length == 4 || length == 8); assert(length == 1 || length == 2 || length == 3 || length == 4 || length == 8);
assert(update_mode == 0);
uchar *old_val_ptr = (uchar *) old_val.data(); uchar *old_val_ptr = (uchar *) old_val.data();
bool field_is_null = false; bool field_is_null = false;
...@@ -973,31 +972,27 @@ static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset ...@@ -973,31 +972,27 @@ static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset
case '+': case '+':
if (!field_is_null) { if (!field_is_null) {
bool over; bool over;
v = tokudb::int_add(v, extra_v, 8*length, over); v = tokudb::int_add(v, extra_v, 8*length, &over);
if (over) { if (over) {
if (extra_v > 0) if (extra_v > 0)
v = tokudb::int_high_endpoint(8*length); v = tokudb::int_high_endpoint(8*length);
else else
v = tokudb::int_low_endpoint(8*length); v = tokudb::int_low_endpoint(8*length);
over = false;
} }
if (!over) new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
} }
break; break;
case '-': case '-':
if (!field_is_null) { if (!field_is_null) {
bool over; bool over;
v = tokudb::int_sub(v, extra_v, 8*length, over); v = tokudb::int_sub(v, extra_v, 8*length, &over);
if (over) { if (over) {
if (extra_v > 0) if (extra_v > 0)
v = tokudb::int_low_endpoint(8*length); v = tokudb::int_low_endpoint(8*length);
else else
v = tokudb::int_high_endpoint(8*length); v = tokudb::int_high_endpoint(8*length);
over = false;
} }
if (!over) new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
} }
break; break;
default: default:
...@@ -1006,12 +1001,11 @@ static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset ...@@ -1006,12 +1001,11 @@ static void int_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset
} }
// Update an unsigned field: unsigned newval@offset = old_val@offset OP extra_val // Update an unsigned field: unsigned newval@offset = old_val@offset OP extra_val
static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num, static void uint_op(uint32_t operation, uint32_t the_offset, uint32_t length, uint32_t field_num, uint32_t field_null_num,
tokudb::buffer &new_val, tokudb::buffer &old_val, void *extra_val) { tokudb::buffer &new_val, tokudb::buffer &old_val, void *extra_val) {
assert(the_offset + length <= new_val.size()); assert(the_offset + length <= new_val.size());
assert(the_offset + length <= old_val.size()); assert(the_offset + length <= old_val.size());
assert(length == 1 || length == 2 || length == 3 || length == 4 || length == 8); assert(length == 1 || length == 2 || length == 3 || length == 4 || length == 8);
assert(update_mode == 0);
uchar *old_val_ptr = (uchar *) old_val.data(); uchar *old_val_ptr = (uchar *) old_val.data();
bool field_is_null = false; bool field_is_null = false;
...@@ -1025,25 +1019,21 @@ static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offse ...@@ -1025,25 +1019,21 @@ static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offse
case '+': case '+':
if (!field_is_null) { if (!field_is_null) {
bool over; bool over;
v = tokudb::uint_add(v, extra_v, 8*length, over); v = tokudb::uint_add(v, extra_v, 8*length, &over);
if (over) { if (over) {
v = tokudb::uint_high_endpoint(8*length); v = tokudb::uint_high_endpoint(8*length);
over = false;
} }
if (!over) new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
} }
break; break;
case '-': case '-':
if (!field_is_null) { if (!field_is_null) {
bool over; bool over;
v = tokudb::uint_sub(v, extra_v, 8*length, over); v = tokudb::uint_sub(v, extra_v, 8*length, &over);
if (over) { if (over) {
v = tokudb::uint_low_endpoint(8*length); v = tokudb::uint_low_endpoint(8*length);
over = false;
} }
if (!over) new_val.replace(the_offset, length, &v, length);
new_val.replace(the_offset, length, &v, length);
} }
break; break;
default: default:
...@@ -1053,9 +1043,10 @@ static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offse ...@@ -1053,9 +1043,10 @@ static void uint_op(uint32_t operation, uint32_t update_mode, uint32_t the_offse
// Decode and apply a sequence of update operations defined in the extra to the old value and put the result // Decode and apply a sequence of update operations defined in the extra to the old value and put the result
// in the new value. // in the new value.
static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, tokudb::buffer &extra_val, static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, tokudb::buffer &extra_val, const Simple_row_descriptor &sd) {
uint32_t update_mode, Simple_row_descriptor &sd) { uint32_t num_updates;
while (extra_val.size() < extra_val.limit()) { extra_val.consume(&num_updates, sizeof num_updates);
for ( ; num_updates > 0; num_updates--) {
// get the update operation // get the update operation
uint32_t update_operation; uint32_t update_operation;
extra_val.consume(&update_operation, sizeof update_operation); extra_val.consume(&update_operation, sizeof update_operation);
...@@ -1077,13 +1068,13 @@ static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, toku ...@@ -1077,13 +1068,13 @@ static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, toku
if (update_operation == '=') if (update_operation == '=')
set_fixed_field(the_offset, length, field_num, field_null_num, new_val, extra_val_ptr); set_fixed_field(the_offset, length, field_num, field_null_num, new_val, extra_val_ptr);
else else
int_op(update_operation, update_mode, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr); int_op(update_operation, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr);
break; break;
case UPDATE_TYPE_UINT: case UPDATE_TYPE_UINT:
if (update_operation == '=') if (update_operation == '=')
set_fixed_field(the_offset, length, field_num, field_null_num, new_val, extra_val_ptr); set_fixed_field(the_offset, length, field_num, field_null_num, new_val, extra_val_ptr);
else else
uint_op(update_operation, update_mode, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr); uint_op(update_operation, the_offset, length, field_num, field_null_num, new_val, old_val, extra_val_ptr);
break; break;
case UPDATE_TYPE_CHAR: case UPDATE_TYPE_CHAR:
case UPDATE_TYPE_BINARY: case UPDATE_TYPE_BINARY:
...@@ -1097,6 +1088,7 @@ static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, toku ...@@ -1097,6 +1088,7 @@ static void apply_updates(tokudb::buffer &new_val, tokudb::buffer &old_val, toku
break; break;
} }
} }
assert(extra_val.size() == extra_val.limit());
} }
// Simple update handler. Decode the update message, apply the update operations to the old value, and set // Simple update handler. Decode the update message, apply the update operations to the old value, and set
...@@ -1116,9 +1108,6 @@ static int tokudb_simple_update_fun( ...@@ -1116,9 +1108,6 @@ static int tokudb_simple_update_fun(
extra_val.consume(&operation, sizeof operation); extra_val.consume(&operation, sizeof operation);
assert(operation == UPDATE_OP_SIMPLE_UPDATE); assert(operation == UPDATE_OP_SIMPLE_UPDATE);
uint32_t update_mode;
extra_val.consume(&update_mode, sizeof update_mode);
if (old_val_dbt != NULL) { if (old_val_dbt != NULL) {
// get the simple descriptor // get the simple descriptor
Simple_row_descriptor sd; Simple_row_descriptor sd;
...@@ -1131,7 +1120,7 @@ static int tokudb_simple_update_fun( ...@@ -1131,7 +1120,7 @@ static int tokudb_simple_update_fun(
new_val.append(old_val_dbt->data, old_val_dbt->size); new_val.append(old_val_dbt->data, old_val_dbt->size);
// apply updates to new val // apply updates to new val
apply_updates(new_val, old_val, extra_val, update_mode, sd); apply_updates(new_val, old_val, extra_val, sd);
// set the new val // set the new val
DBT new_val_dbt; memset(&new_val_dbt, 0, sizeof new_val_dbt); DBT new_val_dbt; memset(&new_val_dbt, 0, sizeof new_val_dbt);
...@@ -1160,9 +1149,6 @@ static int tokudb_simple_upsert_fun( ...@@ -1160,9 +1149,6 @@ static int tokudb_simple_upsert_fun(
extra_val.consume(&operation, sizeof operation); extra_val.consume(&operation, sizeof operation);
assert(operation == UPDATE_OP_SIMPLE_UPSERT); assert(operation == UPDATE_OP_SIMPLE_UPSERT);
uint32_t update_mode;
extra_val.consume(&update_mode, sizeof update_mode);
uint32_t insert_length; uint32_t insert_length;
extra_val.consume(&insert_length, sizeof insert_length); extra_val.consume(&insert_length, sizeof insert_length);
void *insert_row = extra_val.consume_ptr(insert_length); void *insert_row = extra_val.consume_ptr(insert_length);
...@@ -1185,7 +1171,7 @@ static int tokudb_simple_upsert_fun( ...@@ -1185,7 +1171,7 @@ static int tokudb_simple_upsert_fun(
new_val.append(old_val_dbt->data, old_val_dbt->size); new_val.append(old_val_dbt->data, old_val_dbt->size);
// apply updates to new val // apply updates to new val
apply_updates(new_val, old_val, extra_val, update_mode, sd); apply_updates(new_val, old_val, extra_val, sd);
// set the new val // set the new val
DBT new_val_dbt; memset(&new_val_dbt, 0, sizeof new_val_dbt); DBT new_val_dbt; memset(&new_val_dbt, 0, sizeof new_val_dbt);
......
...@@ -11,7 +11,7 @@ static void test(int length_bits) { ...@@ -11,7 +11,7 @@ static void test(int length_bits) {
for (int64_t x = -max-1; x <= max; x++) { for (int64_t x = -max-1; x <= max; x++) {
for (int64_t y = -max-1; y <= max; y++) { for (int64_t y = -max-1; y <= max; y++) {
bool over; bool over;
int64_t n = int_add(x, y, length_bits, over); int64_t n = int_add(x, y, length_bits, &over);
printf("%lld %lld %lld %u\n", x, y, n, over); printf("%lld %lld %lld %u\n", x, y, n, over);
} }
} }
......
...@@ -22,13 +22,13 @@ static void test_uint8() { ...@@ -22,13 +22,13 @@ static void test_uint8() {
uint64_t m; uint64_t m;
for (uint64_t x = 0; x <= (1ULL<<8)-1; x++) { for (uint64_t x = 0; x <= (1ULL<<8)-1; x++) {
for (uint64_t y = 0; y <= (1ULL<<8)-1; y++) { for (uint64_t y = 0; y <= (1ULL<<8)-1; y++) {
n = uint_add(x, y, 8, over); n = uint_add(x, y, 8, &over);
m = x + y; m = x + y;
if (m > (1ULL<<8)-1) if (m > (1ULL<<8)-1)
assert(over); assert(over);
else else
assert(!over && n == (m % 256)); assert(!over && n == (m % 256));
n = uint_sub(x, y, 8, over); n = uint_sub(x, y, 8, &over);
m = x - y; m = x - y;
if (m > x) if (m > x)
assert(over); assert(over);
...@@ -46,13 +46,13 @@ static void test_uint16() { ...@@ -46,13 +46,13 @@ static void test_uint16() {
uint64_t m; uint64_t m;
for (uint64_t x = 0; x <= (1ULL<<16)-1; x++) { for (uint64_t x = 0; x <= (1ULL<<16)-1; x++) {
for (uint64_t y = 0; y <= (1ULL<<16)-1; y++) { for (uint64_t y = 0; y <= (1ULL<<16)-1; y++) {
n = uint_add(x, y, 16, over); n = uint_add(x, y, 16, &over);
m = x + y; m = x + y;
if (m > (1ULL<<16)-1) if (m > (1ULL<<16)-1)
assert(over); assert(over);
else else
assert(!over && n == (m % (1ULL<<16))); assert(!over && n == (m % (1ULL<<16)));
n = uint_sub(x, y, 16, over); n = uint_sub(x, y, 16, &over);
m = x - y; m = x - y;
if (m > x) if (m > x)
assert(over); assert(over);
...@@ -68,15 +68,15 @@ static void test_uint24() { ...@@ -68,15 +68,15 @@ static void test_uint24() {
bool over; bool over;
uint64_t s; uint64_t s;
s = uint_add((1ULL<<24)-1, (1ULL<<24)-1, 24, over); assert(over); s = uint_add((1ULL<<24)-1, (1ULL<<24)-1, 24, &over); assert(over);
s = uint_add((1ULL<<24)-1, 1, 24, over); assert(over); s = uint_add((1ULL<<24)-1, 1, 24, &over); assert(over);
s = uint_add((1ULL<<24)-1, 0, 24, over); assert(!over && s == (1ULL<<24)-1); s = uint_add((1ULL<<24)-1, 0, 24, &over); assert(!over && s == (1ULL<<24)-1);
s = uint_add(0, 1, 24, over); assert(!over && s == 1); s = uint_add(0, 1, 24, &over); assert(!over && s == 1);
s = uint_add(0, 0, 24, over); assert(!over && s == 0); s = uint_add(0, 0, 24, &over); assert(!over && s == 0);
s = uint_sub(0, 0, 24, over); assert(!over && s == 0); s = uint_sub(0, 0, 24, &over); assert(!over && s == 0);
s = uint_sub(0, 1, 24, over); assert(over); s = uint_sub(0, 1, 24, &over); assert(over);
s = uint_sub(0, (1ULL<<24)-1, 24, over); assert(over); s = uint_sub(0, (1ULL<<24)-1, 24, &over); assert(over);
s = uint_sub((1ULL<<24)-1, (1ULL<<24)-1, 24, over); assert(!over && s == 0); s = uint_sub((1ULL<<24)-1, (1ULL<<24)-1, 24, &over); assert(!over && s == 0);
} }
static void test_uint32() { static void test_uint32() {
...@@ -85,15 +85,15 @@ static void test_uint32() { ...@@ -85,15 +85,15 @@ static void test_uint32() {
bool over; bool over;
uint64_t s; uint64_t s;
s = uint_add((1ULL<<32)-1, (1ULL<<32)-1, 32, over); assert(over); s = uint_add((1ULL<<32)-1, (1ULL<<32)-1, 32, &over); assert(over);
s = uint_add((1ULL<<32)-1, 1, 32, over); assert(over); s = uint_add((1ULL<<32)-1, 1, 32, &over); assert(over);
s = uint_add((1ULL<<32)-1, 0, 32, over); assert(!over && s == (1ULL<<32)-1); s = uint_add((1ULL<<32)-1, 0, 32, &over); assert(!over && s == (1ULL<<32)-1);
s = uint_add(0, 1, 32, over); assert(!over && s == 1); s = uint_add(0, 1, 32, &over); assert(!over && s == 1);
s = uint_add(0, 0, 32, over); assert(!over && s == 0); s = uint_add(0, 0, 32, &over); assert(!over && s == 0);
s = uint_sub(0, 0, 32, over); assert(!over && s == 0); s = uint_sub(0, 0, 32, &over); assert(!over && s == 0);
s = uint_sub(0, 1, 32, over); assert(over); s = uint_sub(0, 1, 32, &over); assert(over);
s = uint_sub(0, (1ULL<<32)-1, 32, over); assert(over); s = uint_sub(0, (1ULL<<32)-1, 32, &over); assert(over);
s = uint_sub((1ULL<<32)-1, (1ULL<<32)-1, 32, over); assert(!over && s == 0); s = uint_sub((1ULL<<32)-1, (1ULL<<32)-1, 32, &over); assert(!over && s == 0);
} }
static void test_uint64() { static void test_uint64() {
...@@ -102,15 +102,15 @@ static void test_uint64() { ...@@ -102,15 +102,15 @@ static void test_uint64() {
bool over; bool over;
uint64_t s; uint64_t s;
s = uint_add(~0ULL, ~0ULL, 64, over); assert(over); s = uint_add(~0ULL, ~0ULL, 64, &over); assert(over);
s = uint_add(~0ULL, 1, 64, over); assert(over); s = uint_add(~0ULL, 1, 64, &over); assert(over);
s = uint_add(~0ULL, 0, 64, over); assert(!over && s == ~0ULL); s = uint_add(~0ULL, 0, 64, &over); assert(!over && s == ~0ULL);
s = uint_add(0, 1, 64, over); assert(!over && s == 1); s = uint_add(0, 1, 64, &over); assert(!over && s == 1);
s = uint_add(0, 0, 64, over); assert(!over && s == 0); s = uint_add(0, 0, 64, &over); assert(!over && s == 0);
s = uint_sub(0, 0, 64, over); assert(!over && s == 0); s = uint_sub(0, 0, 64, &over); assert(!over && s == 0);
s = uint_sub(0, 1, 64, over); assert(over); s = uint_sub(0, 1, 64, &over); assert(over);
s = uint_sub(0, ~0ULL, 64, over); assert(over); s = uint_sub(0, ~0ULL, 64, &over); assert(over);
s = uint_sub(~0ULL, ~0ULL, 64, over); assert(!over && s == 0); s = uint_sub(~0ULL, ~0ULL, 64, &over); assert(!over && s == 0);
} }
static int64_t sign_extend(uint length_bits, int64_t n) { static int64_t sign_extend(uint length_bits, int64_t n) {
...@@ -130,7 +130,7 @@ static void test_int8() { ...@@ -130,7 +130,7 @@ static void test_int8() {
for (int64_t y = -max; y <= max-1; y++) { for (int64_t y = -max; y <= max-1; y++) {
bool over; bool over;
int64_t n, m; int64_t n, m;
n = int_add(x, y, 8, over); n = int_add(x, y, 8, &over);
m = x + y; m = x + y;
if (m > max-1) if (m > max-1)
assert(over); assert(over);
...@@ -138,7 +138,7 @@ static void test_int8() { ...@@ -138,7 +138,7 @@ static void test_int8() {
assert(over); assert(over);
else else
assert(!over && n == m); assert(!over && n == m);
n = int_sub(x, y, 8, over); n = int_sub(x, y, 8, &over);
m = x - y; m = x - y;
if (m > max-1) if (m > max-1)
assert(over); assert(over);
...@@ -158,7 +158,7 @@ static void test_int16() { ...@@ -158,7 +158,7 @@ static void test_int16() {
for (int64_t y = -max; y <= max-1; y++) { for (int64_t y = -max; y <= max-1; y++) {
bool over; bool over;
int64_t n, m; int64_t n, m;
n = int_add(x, y, 16, over); n = int_add(x, y, 16, &over);
m = x + y; m = x + y;
if (m > max-1) if (m > max-1)
assert(over); assert(over);
...@@ -166,7 +166,7 @@ static void test_int16() { ...@@ -166,7 +166,7 @@ static void test_int16() {
assert(over); assert(over);
else else
assert(!over && n == m); assert(!over && n == m);
n = int_sub(x, y, 16, over); n = int_sub(x, y, 16, &over);
m = x - y; m = x - y;
if (m > max-1) if (m > max-1)
assert(over); assert(over);
...@@ -184,22 +184,22 @@ static void test_int24() { ...@@ -184,22 +184,22 @@ static void test_int24() {
int64_t s; int64_t s;
bool over; bool over;
s = int_add(1, (1ULL<<23)-1, 24, over); assert(over); s = int_add(1, (1ULL<<23)-1, 24, &over); assert(over);
s = int_add((1ULL<<23)-1, 1, 24, over); assert(over); s = int_add((1ULL<<23)-1, 1, 24, &over); assert(over);
s = int_sub(-1, (1ULL<<23), 24, over); assert(!over && s == (1ULL<<23)-1); s = int_sub(-1, (1ULL<<23), 24, &over); assert(!over && s == (1ULL<<23)-1);
s = int_sub((1ULL<<23), 1, 24, over); assert(over); s = int_sub((1ULL<<23), 1, 24, &over); assert(over);
s = int_add(0, 0, 24, over); assert(!over && s == 0); s = int_add(0, 0, 24, &over); assert(!over && s == 0);
s = int_sub(0, 0, 24, over); assert(!over && s == 0); s = int_sub(0, 0, 24, &over); assert(!over && s == 0);
s = int_add(0, -1, 24, over); assert(!over && s == -1); s = int_add(0, -1, 24, &over); assert(!over && s == -1);
s = int_sub(0, 1, 24, over); assert(!over && s == -1); s = int_sub(0, 1, 24, &over); assert(!over && s == -1);
s = int_add(0, (1ULL<<23), 24, over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23)); s = int_add(0, (1ULL<<23), 24, &over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23));
s = int_sub(0, (1ULL<<23)-1, 24, over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23)+1); s = int_sub(0, (1ULL<<23)-1, 24, &over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23)+1);
s = int_add(-1, 0, 24, over); assert(!over && s == -1); s = int_add(-1, 0, 24, &over); assert(!over && s == -1);
s = int_add(-1, 1, 24, over); assert(!over && s == 0); s = int_add(-1, 1, 24, &over); assert(!over && s == 0);
s = int_sub(-1, -1, 24, over); assert(!over && s == 0); s = int_sub(-1, -1, 24, &over); assert(!over && s == 0);
s = int_sub(-1, (1ULL<<23)-1, 24, over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23)); s = int_sub(-1, (1ULL<<23)-1, 24, &over); assert(!over && (s & (1ULL<<24)-1) == (1ULL<<23));
} }
static void test_int32() { static void test_int32() {
...@@ -208,22 +208,22 @@ static void test_int32() { ...@@ -208,22 +208,22 @@ static void test_int32() {
int64_t s; int64_t s;
bool over; bool over;
s = int_add(1, (1ULL<<31)-1, 32, over); assert(over); s = int_add(1, (1ULL<<31)-1, 32, &over); assert(over);
s = int_add((1ULL<<31)-1, 1, 32, over); assert(over); s = int_add((1ULL<<31)-1, 1, 32, &over); assert(over);
s = int_sub(-1, (1ULL<<31), 32, over); assert(s == (1ULL<<31)-1 && !over); s = int_sub(-1, (1ULL<<31), 32, &over); assert(s == (1ULL<<31)-1 && !over);
s = int_sub((1ULL<<31), 1, 32, over); assert(over); s = int_sub((1ULL<<31), 1, 32, &over); assert(over);
s = int_add(0, 0, 32, over); assert(s == 0 && !over); s = int_add(0, 0, 32, &over); assert(s == 0 && !over);
s = int_sub(0, 0, 32, over); assert(s == 0 && !over); s = int_sub(0, 0, 32, &over); assert(s == 0 && !over);
s = int_add(0, -1, 32, over); assert(s == -1 && !over); s = int_add(0, -1, 32, &over); assert(s == -1 && !over);
s = int_sub(0, 1, 32, over); assert(s == -1 && !over); s = int_sub(0, 1, 32, &over); assert(s == -1 && !over);
s = int_add(0, (1ULL<<31), 32, over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over); s = int_add(0, (1ULL<<31), 32, &over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over);
s = int_sub(0, (1ULL<<31)-1, 32, over); assert((s & (1ULL<<32)-1) == (1ULL<<31)+1 && !over); s = int_sub(0, (1ULL<<31)-1, 32, &over); assert((s & (1ULL<<32)-1) == (1ULL<<31)+1 && !over);
s = int_add(-1, 0, 32, over); assert(s == -1 && !over); s = int_add(-1, 0, 32, &over); assert(s == -1 && !over);
s = int_add(-1, 1, 32, over); assert(s == 0 && !over); s = int_add(-1, 1, 32, &over); assert(s == 0 && !over);
s = int_sub(-1, -1, 32, over); assert(s == 0 && !over); s = int_sub(-1, -1, 32, &over); assert(s == 0 && !over);
s = int_sub(-1, (1ULL<<31)-1, 32, over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over); s = int_sub(-1, (1ULL<<31)-1, 32, &over); assert((s & (1ULL<<32)-1) == (1ULL<<31) && !over);
} }
static void test_int64() { static void test_int64() {
...@@ -232,22 +232,22 @@ static void test_int64() { ...@@ -232,22 +232,22 @@ static void test_int64() {
int64_t s; int64_t s;
bool over; bool over;
s = int_add(1, (1ULL<<63)-1, 64, over); assert(over); s = int_add(1, (1ULL<<63)-1, 64, &over); assert(over);
s = int_add((1ULL<<63)-1, 1, 64, over); assert(over); s = int_add((1ULL<<63)-1, 1, 64, &over); assert(over);
s = int_sub(-1, (1ULL<<63), 64, over); assert(s == (1ULL<<63)-1 && !over); s = int_sub(-1, (1ULL<<63), 64, &over); assert(s == (1ULL<<63)-1 && !over);
s = int_sub((1ULL<<63), 1, 64, over); assert(over); s = int_sub((1ULL<<63), 1, 64, &over); assert(over);
s = int_add(0, 0, 64, over); assert(s == 0 && !over); s = int_add(0, 0, 64, &over); assert(s == 0 && !over);
s = int_sub(0, 0, 64, over); assert(s == 0 && !over); s = int_sub(0, 0, 64, &over); assert(s == 0 && !over);
s = int_add(0, -1, 64, over); assert(s == -1 && !over); s = int_add(0, -1, 64, &over); assert(s == -1 && !over);
s = int_sub(0, 1, 64, over); assert(s == -1 && !over); s = int_sub(0, 1, 64, &over); assert(s == -1 && !over);
s = int_add(0, (1ULL<<63), 64, over); assert(s == (1ULL<<63) && !over); s = int_add(0, (1ULL<<63), 64, &over); assert(s == (1ULL<<63) && !over);
s = int_sub(0, (1ULL<<63)-1, 64, over); assert(s == (1ULL<<63)+1 && !over); s = int_sub(0, (1ULL<<63)-1, 64, &over); assert(s == (1ULL<<63)+1 && !over);
s = int_add(-1, 0, 64, over); assert(s == -1 && !over); s = int_add(-1, 0, 64, &over); assert(s == -1 && !over);
s = int_add(-1, 1, 64, over); assert(s == 0 && !over); s = int_add(-1, 1, 64, &over); assert(s == 0 && !over);
s = int_sub(-1, -1, 64, over); assert(s == 0 && !over); s = int_sub(-1, -1, 64, &over); assert(s == 0 && !over);
s = int_sub(-1, (1ULL<<63)-1, 64, over); assert(s == (1ULL<<63) && !over); s = int_sub(-1, (1ULL<<63)-1, 64, &over); assert(s == (1ULL<<63) && !over);
} }
static void test_int_sign(uint length_bits) { static void test_int_sign(uint length_bits) {
......
...@@ -12,7 +12,7 @@ static void test(int length_bits) { ...@@ -12,7 +12,7 @@ static void test(int length_bits) {
for (uint64_t x = 0; x <= max; x++) { for (uint64_t x = 0; x <= max; x++) {
for (uint64_t y = 0; y <= max; y++) { for (uint64_t y = 0; y <= max; y++) {
bool over; bool over;
uint64_t n = uint_add(x, y, max, over); uint64_t n = uint_add(x, y, max, &over);
printf("%llu %llu %llu\n", x, y, n); printf("%llu %llu %llu\n", x, y, n);
} }
} }
......
...@@ -27,24 +27,24 @@ static uint64_t uint_low_endpoint(uint length_bits) { ...@@ -27,24 +27,24 @@ static uint64_t uint_low_endpoint(uint length_bits) {
// Add two unsigned integers with max maximum value. // Add two unsigned integers with max maximum value.
// If there is an overflow then set the sum to the max. // If there is an overflow then set the sum to the max.
// Return the sum and the overflow. // Return the sum and the overflow.
static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool &over) __attribute__((unused)); static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool *over) __attribute__((unused));
static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool &over) { static uint64_t uint_add(uint64_t x, uint64_t y, uint length_bits, bool *over) {
uint64_t mask = uint_mask(length_bits); uint64_t mask = uint_mask(length_bits);
assert((x & ~mask) == 0 && (y & ~mask) == 0); assert((x & ~mask) == 0 && (y & ~mask) == 0);
uint64_t s = (x + y) & mask; uint64_t s = (x + y) & mask;
over = s < x; // check for overflow *over = s < x; // check for overflow
return s; return s;
} }
// Subtract two unsigned ints with max maximum value. // Subtract two unsigned ints with max maximum value.
// If there is an over then set the difference to 0. // If there is an over then set the difference to 0.
// Return the difference and the overflow. // Return the difference and the overflow.
static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool &over) __attribute__((unused)); static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool *over) __attribute__((unused));
static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool &over) { static uint64_t uint_sub(uint64_t x, uint64_t y, uint length_bits, bool *over) {
uint64_t mask = uint_mask(length_bits); uint64_t mask = uint_mask(length_bits);
assert((x & ~mask) == 0 && (y & ~mask) == 0); assert((x & ~mask) == 0 && (y & ~mask) == 0);
uint64_t s = (x - y) & mask; uint64_t s = (x - y) & mask;
over = s > x; // check for overflow *over = s > x; // check for overflow
return s; return s;
} }
...@@ -74,11 +74,11 @@ static int64_t int_sign_extend(int64_t n, uint length_bits) { ...@@ -74,11 +74,11 @@ static int64_t int_sign_extend(int64_t n, uint length_bits) {
// depending on the sign bit. // depending on the sign bit.
// Sign extend to 64 bits. // Sign extend to 64 bits.
// Return the sum and the overflow. // Return the sum and the overflow.
static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) __attribute__((unused)); static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool *over) __attribute__((unused));
static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) { static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool *over) {
int64_t mask = uint_mask(length_bits); int64_t mask = uint_mask(length_bits);
int64_t n = (x + y) & mask; int64_t n = (x + y) & mask;
over = (((n ^ x) & (n ^ y)) >> (length_bits-1)) & 1; // check for overflow *over = (((n ^ x) & (n ^ y)) >> (length_bits-1)) & 1; // check for overflow
if (n & (1LL<<(length_bits-1))) if (n & (1LL<<(length_bits-1)))
n |= ~mask; // sign extend n |= ~mask; // sign extend
return n; return n;
...@@ -89,11 +89,11 @@ static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) { ...@@ -89,11 +89,11 @@ static int64_t int_add(int64_t x, int64_t y, uint length_bits, bool &over) {
// depending on the sign bit. // depending on the sign bit.
// Sign extend to 64 bits. // Sign extend to 64 bits.
// Return the sum and the overflow. // Return the sum and the overflow.
static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool &over) __attribute__((unused)); static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool *over) __attribute__((unused));
static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool &over) { static int64_t int_sub(int64_t x, int64_t y, uint length_bits, bool *over) {
int64_t mask = uint_mask(length_bits); int64_t mask = uint_mask(length_bits);
int64_t n = (x - y) & mask; int64_t n = (x - y) & mask;
over = (((x ^ y) & (n ^ x)) >> (length_bits-1)) & 1; // check for overflow *over = (((x ^ y) & (n ^ x)) >> (length_bits-1)) & 1; // check for overflow
if (n & (1LL<<(length_bits-1))) if (n & (1LL<<(length_bits-1)))
n |= ~mask; // sign extend n |= ~mask; // sign extend
return n; return n;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment