Commit 5c831b4d authored by Vicențiu Ciorbaru's avatar Vicențiu Ciorbaru Committed by Sergei Golubchik

MDEV-32886 Vec_FromText and Vec_ToText

This commit introduces two utility functions meant to make working with
vectors simpler.

Vec_ToText converts a binary vector into a json array of numbers
(floats).
Vec_FromText takes in a json array of numbers and converts it into a
little-endian IEEE float sequence of bytes (4 bytes per float).
parent 58f7f257
......@@ -62,18 +62,18 @@ insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'56926c3fdf098d3e2c8c5e3d1ad4953daa9d0b3e'),
(x'7b713f3e5258323f80d1113d673b2b3f66e3583f'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
select id, hex(v) from t1;
id hex(v)
1 E360D63EBE554F3FCDBC523F4522193F5236083D
2 F511303F72224A3FDD05FE3EB22A133FFAE86A3F
3 F09BAA3EA172763F123DEF3E0C7FE53E288BF33E
4 B97A523F2A193E3EB4F62E3F2D23583E9DD60D3F
5 F7C5DF3E984B2B3E65E59D3D7376DB3EAC63773E
6 DE01453FFA486D3F10AA4D3FDD66813C71CB163F
7 76EDFC3E4B57243F10F8423FB158713F020BDA3E
8 56926C3FDF098D3E2C8C5E3D1AD4953DAA9D0B3E
9 7B713F3E5258323F80D1113D673B2B3F66E3583F
10 6CA1D43E9DF91B3FE580DA3E1C247D3F147CF33E
select id, hex(v), vec_totext(v) from t1;
id hex(v) vec_totext(v)
1 E360D63EBE554F3FCDBC523F4522193F5236083D [0.418708,0.809902,0.823193,0.598179,0.033255]
2 F511303F72224A3FDD05FE3EB22A133FFAE86A3F [0.687774,0.789588,0.496138,0.574870,0.917617]
3 F09BAA3EA172763F123DEF3E0C7FE53E288BF33E [0.333221,0.962687,0.467263,0.448235,0.475671]
4 B97A523F2A193E3EB4F62E3F2D23583E9DD60D3F [0.822185,0.185643,0.683452,0.211072,0.554056]
5 F7C5DF3E984B2B3E65E59D3D7376DB3EAC63773E [0.437057,0.167281,0.077098,0.428638,0.241591]
6 DE01453FFA486D3F10AA4D3FDD66813C71CB163F [0.769560,0.926895,0.803376,0.015796,0.589042]
7 76EDFC3E4B57243F10F8423FB158713F020BDA3E [0.493999,0.641957,0.761598,0.942760,0.425865]
8 56926C3FDF098D3E2C8C5E3D1AD4953DAA9D0B3E [0.924108,0.275466,0.054333,0.073158,0.136344]
9 7B713F3E5258323F80D1113D673B2B3F66E3583F [0.186956,0.696660,0.035600,0.668875,0.847220]
10 6CA1D43E9DF91B3FE580DA3E1C247D3F147CF33E [0.415294,0.609278,0.426765,0.988832,0.475556]
flush tables;
select id,vec_distance(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
......
......@@ -24,7 +24,7 @@ insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'7b713f3e5258323f80d1113d673b2b3f66e3583f'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
select id, hex(v) from t1;
select id, hex(v), vec_totext(v) from t1;
flush tables;
# test with a valid query vector
select id,vec_distance(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
......@@ -63,6 +63,7 @@ insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
select id,vec_distance(v, x'b047263c9f87233Fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 5;
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
insert t1 (v) values ('');
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
......
create table t1 (id int auto_increment primary key, v blob not null, vector index (v));
insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'f511303f72224a3fdd05fe3eb22a133ffae86a3f'),
(x'f09baa3ea172763f123def3e0c7fe53e288bf33e'),
(x'b97a523f2a193e3eb4f62e3f2d23583e9dd60d3f'),
(x'f7c5df3e984b2b3e65e59d3d7376db3eac63773e'),
(x'de01453ffa486d3f10aa4d3fdd66813c71cb163f'),
(x'76edfc3e4b57243f10f8423fb158713f020bda3e'),
(x'56926c3fdf098d3e2c8c5e3d1ad4953daa9d0b3e'),
(x'7b713f3e5258323f80d1113d673b2b3f66e3583f'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
# Error cases first.
select vec_totext(x'aabbcc');
select vec_totext(x'0000f07f');
select vec_totext(x'0000f0ff');
select vec_totext(x'0000807f');
select vec_totext(x'000080ff');
select hex(vec_fromtext('["a"]'));
select hex(vec_fromtext('[]'));
select hex(vec_fromtext('["a"]'));
select hex(vec_fromtext('[{"a": "b"}]'));
select hex(vec_fromtext('[null]'));
select hex(vec_fromtext('[1, null]'));
select hex(vec_fromtext('[1, ["a"]]'));
select hex(vec_fromtext('[1, [2]]'));
select hex(vec_fromtext('{"a":"b"}'));
select hex(vec_fromtext('[1, 2, "z", 3]'));
select hex(vec_fromtext('[1, 2, 3'));
select hex(vec_fromtext('1, 2, 3]'));
# Empty vectors ok.
select hex(vec_fromtext('[]'));
select vec_totext(x'');
select id, vec_totext(t1.v) as a, vec_totext(vec_fromtext(vec_totext(t1.v))) as b,
vec_distance(t1.v, vec_fromtext(vec_totext(t1.v))) < 0.000001
from t1;
drop table t1;
set collation_connection=utf16_general_ci;
set character_set_results=utf16;
select hex(vec_fromtext('[1,2,3]'));
select vec_totext(x'0000803F0000004000004040FFFFFFFF0000807F000080FF');
set character_set_results=default;
select vec_totext(x'0000803F0000004000004040FFFFFFFF0000807F000080FF');
......@@ -6238,7 +6238,8 @@ Create_func_year_week::create_native(THD *thd, const LEX_CSTRING *name,
class Create_func_vec_distance: public Create_func_arg2
{
public:
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override;
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
{ return new (thd->mem_root) Item_func_vec_distance(thd, arg1, arg2); }
static Create_func_vec_distance s_singleton;
......@@ -6250,11 +6251,37 @@ class Create_func_vec_distance: public Create_func_arg2
Create_func_vec_distance Create_func_vec_distance::s_singleton;
Item*
Create_func_vec_distance::create_2_arg(THD *thd, Item *arg1, Item *arg2)
class Create_func_vec_totext: public Create_func_arg1
{
return new (thd->mem_root) Item_func_vec_distance(thd, arg1, arg2);
}
public:
Item *create_1_arg(THD *thd, Item *arg1) override
{ return new (thd->mem_root) Item_func_vec_totext(thd, arg1); }
static Create_func_vec_totext s_singleton;
protected:
Create_func_vec_totext() = default;
virtual ~Create_func_vec_totext() = default;
};
Create_func_vec_totext Create_func_vec_totext::s_singleton;
class Create_func_vec_fromtext: public Create_func_arg1
{
public:
Item *create_1_arg(THD *thd, Item *arg1) override
{ return new (thd->mem_root) Item_func_vec_fromtext(thd, arg1); }
static Create_func_vec_fromtext s_singleton;
protected:
Create_func_vec_fromtext() = default;
virtual ~Create_func_vec_fromtext() = default;
};
Create_func_vec_fromtext Create_func_vec_fromtext::s_singleton;
#define BUILDER(F) & F::s_singleton
......@@ -6484,6 +6511,8 @@ const Native_func_registry func_array[] =
{ { STRING_WITH_LEN("UPPER") }, BUILDER(Create_func_ucase)},
{ { STRING_WITH_LEN("UUID_SHORT") }, BUILDER(Create_func_uuid_short)},
{ { STRING_WITH_LEN("VEC_DISTANCE") }, BUILDER(Create_func_vec_distance)},
{ { STRING_WITH_LEN("VEC_FROMTEXT") }, BUILDER(Create_func_vec_fromtext)},
{ { STRING_WITH_LEN("VEC_TOTEXT") }, BUILDER(Create_func_vec_totext)},
{ { STRING_WITH_LEN("VERSION") }, BUILDER(Create_func_version)},
{ { STRING_WITH_LEN("WEEK") }, BUILDER(Create_func_week)},
{ { STRING_WITH_LEN("WEEKDAY") }, BUILDER(Create_func_weekday)},
......
......@@ -13,7 +13,6 @@
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */
/**
@file
......@@ -21,9 +20,15 @@
This file defines all vector functions
*/
#include <cmath>
#include <my_global.h>
#include "item.h"
#include "item_vectorfunc.h"
#include "json_lib.h"
#include "m_ctype.h"
#include "sql_const.h"
#include "sql_error.h"
key_map Item_func_vec_distance::part_of_sortkey() const
{
......@@ -47,11 +52,68 @@ double Item_func_vec_distance::val_real()
r1->length() % sizeof(float);
if (null_value)
return 0;
float *v1= (float*)r1->ptr();
float *v2= (float*)r2->ptr();
float *v1= (float *) r1->ptr();
float *v2= (float *) r2->ptr();
return euclidean_vec_distance(v1, v2, (r1->length()) / sizeof(float));
}
bool Item_func_vec_totext::fix_length_and_dec(THD *thd)
{
decimals= 0;
max_length= ((args[0]->max_length / 4) *
(MAX_FLOAT_STR_LENGTH + 1 /* comma */)) + 2 /* braces */;
fix_length_and_charset(max_length, default_charset());
set_maybe_null();
return false;
}
String *Item_func_vec_totext::val_str_ascii(String *str)
{
String *r1= args[0]->val_str();
if (args[0]->null_value)
{
null_value= true;
return nullptr;
}
// Wrong size returns null
if (r1->length() % 4)
{
THD *thd= current_thd;
push_warning_printf(thd, Sql_condition::WARN_LEVEL_WARN,
ER_VECTOR_BINARY_FORMAT_INVALID,
ER_THD(thd, ER_VECTOR_BINARY_FORMAT_INVALID));
null_value= true;
return nullptr;
}
str->length(0);
str->set_charset(&my_charset_numeric);
str->reserve(r1->length() / 4 * (MAX_FLOAT_STR_LENGTH + 1) + 2);
str->append('[');
const char *ptr= r1->ptr();
for (size_t i= 0; i < r1->length(); i+= 4)
{
float val= get_float(ptr);
if (std::isinf(val))
if (val < 0)
str->append(STRING_WITH_LEN("-Inf"));
else
str->append(STRING_WITH_LEN("Inf"));
else if (std::isnan(val))
str->append(STRING_WITH_LEN("NaN"));
else
str->append_float(val, FLT_DIG);
ptr+= 4;
if (r1->length() - i > 4)
str->append(',');
}
str->append(']');
return str;
}
double euclidean_vec_distance(float *v1, float *v2, size_t v_len)
{
float *p1= v1;
......@@ -64,3 +126,111 @@ double euclidean_vec_distance(float *v1, float *v2, size_t v_len)
}
return sqrt(d);
}
Item_func_vec_totext::Item_func_vec_totext(THD *thd, Item *a)
: Item_str_ascii_checksum_func(thd, a)
{
}
Item_func_vec_fromtext::Item_func_vec_fromtext(THD *thd, Item *a)
: Item_str_func(thd, a)
{
}
bool Item_func_vec_fromtext::fix_length_and_dec(THD *thd)
{
decimals= 0;
/* Worst case scenario, for a valid input we have a string of the form:
[1,2,3,4,5,...] single digit numbers.
This means we can have (max_length - 1) / 2 floats.
Each float takes 4 bytes, so we do (max_length - 1) * 2. */
fix_length_and_charset((args[0]->max_length - 1) * 2, &my_charset_bin);
set_maybe_null();
return false;
}
String *Item_func_vec_fromtext::val_str(String *buf)
{
json_engine_t je;
bool end_ok= false;
String *value = args[0]->val_json(&tmp_js);
CHARSET_INFO *cs= value->charset();
buf->length(0);
if (!value)
{
null_value= true;
return nullptr;
}
const uchar *start= reinterpret_cast<const uchar *>(value->ptr());
const uchar *end= start + value->length();
if (json_scan_start(&je, cs, start, end) ||
json_read_value(&je))
goto error;
if (je.value_type != JSON_VALUE_ARRAY)
goto error_format;
/* Accept only arrays of floats. */
do {
switch (je.state)
{
case JST_ARRAY_START:
continue;
case JST_ARRAY_END:
end_ok = true;
break;
case JST_VALUE:
{
if (json_read_value(&je))
goto error;
if (je.value_type != JSON_VALUE_NUMBER)
goto error_format;
int error;
char *start= (char *)je.value_begin, *end;
float f= (float)cs->strntod(start, je.value_len, &end, &error);
if (unlikely(error))
goto error_format;
char f_bin[4];
float4store(&f_bin, f);
buf->append(f_bin[0]);
buf->append(f_bin[1]);
buf->append(f_bin[2]);
buf->append(f_bin[3]);
break;
}
default:
goto error_format;
}
} while (json_scan_next(&je) == 0);
if (!end_ok)
goto error_format;
return buf;
error_format:
{
int position= (int)((const char *) je.s.c_str - value->ptr());
null_value= true;
THD *thd= current_thd;
push_warning_printf(thd, Sql_condition::WARN_LEVEL_WARN,
ER_VECTOR_FORMAT_INVALID,
ER_THD(thd, ER_VECTOR_FORMAT_INVALID),
position,
value->ptr());
return nullptr;
}
error:
report_json_error_ex(value->ptr(), &je, func_name(),
0, Sql_condition::WARN_LEVEL_WARN);
null_value= true;
return nullptr;
}
......@@ -48,7 +48,7 @@ class Item_func_vec_distance: public Item_real_func
double val_real() override;
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= {STRING_WITH_LEN("vec_distance") };
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_Distance") };
return name;
}
Item *get_const_arg() const
......@@ -65,5 +65,43 @@ class Item_func_vec_distance: public Item_real_func
};
class Item_func_vec_totext: public Item_str_ascii_checksum_func
{
bool check_arguments() const override
{
return check_argument_types_or_binary(NULL, 0, arg_count);
}
public:
bool fix_length_and_dec(THD *thd) override;
Item_func_vec_totext(THD *thd, Item *a);
String *val_str_ascii(String *buf) override;
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_ToText") };
return name;
}
Item *do_get_copy(THD *thd) const override
{ return get_item_copy<Item_func_vec_totext>(thd, this); }
};
class Item_func_vec_fromtext: public Item_str_func
{
String tmp_js;
public:
bool fix_length_and_dec(THD *thd) override;
Item_func_vec_fromtext(THD *thd, Item *a);
String *val_str(String *buf) override;
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_FromText") };
return name;
}
Item *do_get_copy(THD *thd) const override
{ return get_item_copy<Item_func_vec_fromtext>(thd, this); }
};
double euclidean_vec_distance(float *v1, float *v2, size_t v_len);
#endif
......@@ -12279,3 +12279,7 @@ ER_SEQUENCE_TABLE_ORDER_BY
eng "ORDER BY"
ER_VARIABLE_IGNORED
eng "The variable '%s' is ignored. It only exists for compatibility with old installations and will be removed in a future release"
ER_VECTOR_BINARY_FORMAT_INVALID
eng "Invalid binary vector format. Must use IEEE standard float representation in little-endian format. Use VEC_FromText() to generate it."
ER_VECTOR_FORMAT_INVALID
eng "Invalid vector format at offset: %d for '%-.100s'. Must be a valid JSON array of numbers."
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment