Commit 1dce1222 authored by Vicențiu Ciorbaru's avatar Vicențiu Ciorbaru Committed by Vicențiu Ciorbaru

MDEV-32886 Vec_FromText and Vec_ToText

This commit introduces two utility functions meant to make working with
vectors simpler.

Vec_ToText converts a binary vector into a json array of numbers
(floats).
Vec_FromText takes in a json array of numbers and converts it into a
little-endian IEEE float sequence of bytes (4 bytes per float).
parent e55f8288
......@@ -63,6 +63,7 @@ insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
select id,vec_distance(v, x'b047263c9f87233Fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 5;
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
insert t1 (v) values ('');
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
......
create table t1 (id int auto_increment primary key, v blob not null, vector index (v));
insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'f511303f72224a3fdd05fe3eb22a133ffae86a3f'),
(x'f09baa3ea172763f123def3e0c7fe53e288bf33e'),
(x'b97a523f2a193e3eb4f62e3f2d23583e9dd60d3f'),
(x'f7c5df3e984b2b3e65e59d3d7376db3eac63773e'),
(x'de01453ffa486d3f10aa4d3fdd66813c71cb163f'),
(x'76edfc3e4b57243f10f8423fb158713f020bda3e'),
(x'56926c3fdf098d3e2c8c5e3d1ad4953daa9d0b3e'),
(x'7b713f3e5258323f80d1113d673b2b3f66e3583f'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
# Error cases first.
select vec_totext(x'aabbcc');
select vec_totext(x'0000f07f');
select vec_totext(x'0000f0ff');
select vec_totext(x'0000807f');
select vec_totext(x'000080ff');
select hex(vec_fromtext('["a"]'));
select hex(vec_fromtext('[]'));
select hex(vec_fromtext('["a"]'));
select hex(vec_fromtext('[{"a": "b"}]'));
select hex(vec_fromtext('[null]'));
select hex(vec_fromtext('[1, null]'));
select hex(vec_fromtext('[1, ["a"]]'));
select hex(vec_fromtext('[1, [2]]'));
select hex(vec_fromtext('{"a":"b"}'));
select hex(vec_fromtext('[1, 2, "z", 3]'));
select hex(vec_fromtext('[1, 2, 3'));
select hex(vec_fromtext('1, 2, 3]'));
# Empty vectors ok.
select hex(vec_fromtext('[]'));
select vec_totext(x'');
select id, vec_totext(t1.v) as a, vec_totext(vec_fromtext(vec_totext(t1.v))) as b,
vec_distance(t1.v, vec_fromtext(vec_totext(t1.v))) < 0.000001
from t1;
drop table t1;
set collation_connection=utf16_general_ci;
set character_set_results=utf16;
select hex(vec_fromtext('[1,2,3]'));
select vec_totext(x'0000803F0000004000004040FFFFFFFF0000807F000080FF');
set character_set_results=default;
select vec_totext(x'0000803F0000004000004040FFFFFFFF0000807F000080FF');
......@@ -6238,7 +6238,8 @@ Create_func_year_week::create_native(THD *thd, const LEX_CSTRING *name,
class Create_func_vec_distance: public Create_func_arg2
{
public:
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2);
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2)
{ return new (thd->mem_root) Item_func_vec_distance(thd, arg1, arg2); }
static Create_func_vec_distance s_singleton;
......@@ -6250,11 +6251,37 @@ class Create_func_vec_distance: public Create_func_arg2
Create_func_vec_distance Create_func_vec_distance::s_singleton;
Item*
Create_func_vec_distance::create_2_arg(THD *thd, Item *arg1, Item *arg2)
class Create_func_vec_totext: public Create_func_arg1
{
return new (thd->mem_root) Item_func_vec_distance(thd, arg1, arg2);
}
public:
Item *create_1_arg(THD *thd, Item *arg1)
{ return new (thd->mem_root) Item_func_vec_totext(thd, arg1); }
static Create_func_vec_totext s_singleton;
protected:
Create_func_vec_totext() = default;
virtual ~Create_func_vec_totext() = default;
};
Create_func_vec_totext Create_func_vec_totext::s_singleton;
class Create_func_vec_fromtext: public Create_func_arg1
{
public:
Item *create_1_arg(THD *thd, Item *arg1)
{ return new (thd->mem_root) Item_func_vec_fromtext(thd, arg1); }
static Create_func_vec_fromtext s_singleton;
protected:
Create_func_vec_fromtext() = default;
virtual ~Create_func_vec_fromtext() = default;
};
Create_func_vec_fromtext Create_func_vec_fromtext::s_singleton;
#define BUILDER(F) & F::s_singleton
......@@ -6484,6 +6511,8 @@ const Native_func_registry func_array[] =
{ { STRING_WITH_LEN("UPPER") }, BUILDER(Create_func_ucase)},
{ { STRING_WITH_LEN("UUID_SHORT") }, BUILDER(Create_func_uuid_short)},
{ { STRING_WITH_LEN("VEC_DISTANCE") }, BUILDER(Create_func_vec_distance)},
{ { STRING_WITH_LEN("VEC_FROMTEXT") }, BUILDER(Create_func_vec_fromtext)},
{ { STRING_WITH_LEN("VEC_TOTEXT") }, BUILDER(Create_func_vec_totext)},
{ { STRING_WITH_LEN("VERSION") }, BUILDER(Create_func_version)},
{ { STRING_WITH_LEN("WEEK") }, BUILDER(Create_func_week)},
{ { STRING_WITH_LEN("WEEKDAY") }, BUILDER(Create_func_weekday)},
......
......@@ -13,7 +13,6 @@
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA */
/**
@file
......@@ -21,9 +20,15 @@
This file defines all vector functions
*/
#include <cmath>
#include <my_global.h>
#include "item.h"
#include "item_vectorfunc.h"
#include "json_lib.h"
#include "m_ctype.h"
#include "sql_const.h"
#include "sql_error.h"
key_map Item_func_vec_distance::part_of_sortkey() const
{
......@@ -47,11 +52,69 @@ double Item_func_vec_distance::val_real()
r1->length() % sizeof(float);
if (null_value)
return 0;
float *v1= (float*)r1->ptr();
float *v2= (float*)r2->ptr();
float *v1= (float *) r1->ptr();
float *v2= (float *) r2->ptr();
return euclidean_vec_distance(v1, v2, (r1->length()) / sizeof(float));
}
bool Item_func_vec_totext::fix_length_and_dec(THD *thd)
{
collation.set(default_charset());
decimals= 0;
max_length= ((args[0]->max_length / 4) *
(MAX_FLOAT_STR_LENGTH + 1 /* comma */)) + 2 /* braces */ *
collation.collation->mbmaxlen;
set_maybe_null();
return false;
}
String *Item_func_vec_totext::val_str_ascii(String *str)
{
String *r1= args[0]->val_str();
if (args[0]->null_value)
{
null_value= true;
return nullptr;
}
// Wrong size returns null
if (r1->length() % 4)
{
THD *thd= current_thd;
push_warning_printf(thd, Sql_condition::WARN_LEVEL_WARN,
ER_VECTOR_BINARY_FORMAT_INVALID,
ER_THD(thd, ER_VECTOR_BINARY_FORMAT_INVALID));
null_value= true;
return nullptr;
}
str->length(0);
str->set_charset(&my_charset_numeric);
str->reserve(r1->length() / 4 * (MAX_FLOAT_STR_LENGTH + 1) + 2);
str->append('[');
const char *ptr= r1->ptr();
for (size_t i= 0; i < r1->length(); i+= 4)
{
float val= *reinterpret_cast<const float *>(ptr);
if (std::isinf(val))
if (val < 0)
str->append(STRING_WITH_LEN("-Inf"));
else
str->append(STRING_WITH_LEN("Inf"));
else if (std::isnan(val))
str->append(STRING_WITH_LEN("NaN"));
else
str->append_float(val, FLT_DIG);
ptr+= 4;
if (r1->length() - i > 4)
str->append(',');
}
str->append(']');
return str;
}
double euclidean_vec_distance(float *v1, float *v2, size_t v_len)
{
float *p1= v1;
......@@ -64,3 +127,109 @@ double euclidean_vec_distance(float *v1, float *v2, size_t v_len)
}
return sqrt(d);
}
Item_func_vec_totext::Item_func_vec_totext(THD *thd, Item *a)
: Item_str_ascii_checksum_func(thd, a)
{
}
Item_func_vec_fromtext::Item_func_vec_fromtext(THD *thd, Item *a)
: Item_str_func(thd, a)
{
}
bool Item_func_vec_fromtext::fix_length_and_dec(THD *thd)
{
collation.set(&my_charset_bin);
decimals= 0;
/* Worst case scenario, for a valid input we have a string of the form:
[1,2,3,4,5,...] single digit numbers.
This means we can have (max_length - 2) / 2 floats.
Each float takes 4 bytes, so we do max_length - 2 * 2. */
max_length= (args[0]->max_length - 2) * 2 * my_charset_bin.mbmaxlen;
set_maybe_null();
return false;
}
String *Item_func_vec_fromtext::val_str(String *buf)
{
json_engine_t je;
bool end_ok= false;
String *value = args[0]->val_json(&tmp_js);
CHARSET_INFO *cs= value->charset();
buf->length(0);
if (!value)
{
null_value= true;
return nullptr;
}
const uchar *start= reinterpret_cast<const uchar *>(value->ptr());
const uchar *end= start + value->length();
if (json_scan_start(&je, cs, start, end) ||
json_read_value(&je))
goto error;
if (je.value_type != JSON_VALUE_ARRAY)
goto error_format;
/* Accept only arrays of floats. */
do {
switch (je.state)
{
case JST_ARRAY_START:
continue;
case JST_ARRAY_END:
end_ok = true;
break;
case JST_VALUE:
{
if (json_read_value(&je))
goto error;
if (je.value_type != JSON_VALUE_NUMBER)
goto error_format;
int error;
char *start= (char *)je.value_begin, *end;
float f= cs->strntod(start, je.value_len, &end, &error);
if (unlikely(error))
goto error_format;
uchar f_bin[4];
float4store(f_bin, f);
buf->append((char *)f_bin, 4);
break;
}
default:
goto error_format;
}
} while (json_scan_next(&je) == 0);
if (!end_ok)
goto error_format;
return buf;
error_format:
{
int position= (int)((const char *) je.s.c_str - value->ptr());
null_value= true;
THD *thd= current_thd;
push_warning_printf(thd, Sql_condition::WARN_LEVEL_WARN,
ER_VECTOR_FORMAT_INVALID,
ER_THD(thd, ER_VECTOR_FORMAT_INVALID),
position,
value->ptr());
return nullptr;
}
error:
report_json_error_ex(value->ptr(), &je, func_name(),
0, Sql_condition::WARN_LEVEL_WARN);
null_value= true;
return nullptr;
}
......@@ -43,7 +43,7 @@ class Item_func_vec_distance: public Item_real_func
double val_real() override;
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= {STRING_WITH_LEN("vec_distance") };
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_Distance") };
return name;
}
Item *get_const_arg() const
......@@ -60,5 +60,43 @@ class Item_func_vec_distance: public Item_real_func
};
class Item_func_vec_totext: public Item_str_ascii_checksum_func
{
bool check_arguments() const override
{
return check_argument_types_or_binary(NULL, 0, arg_count);
}
public:
bool fix_length_and_dec(THD *thd) override;
Item_func_vec_totext(THD *thd, Item *a);
String *val_str_ascii(String *buf) override;
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_ToText") };
return name;
}
Item *get_copy(THD *thd) override
{ return get_item_copy<Item_func_vec_totext>(thd, this); }
};
class Item_func_vec_fromtext: public Item_str_func
{
String tmp_js;
public:
bool fix_length_and_dec(THD *thd) override;
Item_func_vec_fromtext(THD *thd, Item *a);
String *val_str(String *buf) override;
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_FromText") };
return name;
}
Item *get_copy(THD *thd) override
{ return get_item_copy<Item_func_vec_fromtext>(thd, this); }
};
double euclidean_vec_distance(float *v1, float *v2, size_t v_len);
#endif
......@@ -12279,3 +12279,7 @@ ER_SEQUENCE_TABLE_ORDER_BY
eng "ORDER BY"
ER_VARIABLE_IGNORED
eng "The variable '%s' is ignored. It only exists for compatibility with old installations and will be removed in a future release"
ER_VECTOR_BINARY_FORMAT_INVALID
eng "Invalid binary vector format. Must use IEEE standard float representation in little-endian format. Use VEC_FromText() to generate it."
ER_VECTOR_FORMAT_INVALID
eng "Invalid vector format at offset: %d for '%-.100s'. Must be a valid JSON array of numbers."
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment