Commit 3358c653 authored by Sergei Golubchik's avatar Sergei Golubchik

VEC_Distance_Cosine()

parent e999e6b2
......@@ -692,6 +692,9 @@ The following specify which files/extra groups are read (specified before remain
Unused. Deprecated, will be removed in a future release.
--mhnsw-cache-size=#
Size of the cache for the MHNSW vector index
--mhnsw-distance-function=name
Distance function to build the vector index for. One of:
euclidean, cosine
--mhnsw-max-edges-per-node=#
Larger values means slower INSERT, larger index size and
higher memory consumption, but better search results
......@@ -1800,6 +1803,7 @@ memlock FALSE
metadata-locks-cache-size 1024
metadata-locks-hash-instances 8
mhnsw-cache-size 16777216
mhnsw-distance-function euclidean
mhnsw-max-edges-per-node 6
mhnsw-min-limit 20
min-examined-row-limit 0
......
......@@ -235,6 +235,32 @@ from t1 where id < 10
id d
9 0.4719976290006591
3 0.5865673124650332
flush session status;
select id,vec_distance_euclidean(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
9 0.4719976290006591
10 0.5069011044450041
3 0.5865673124650332
show status like 'handler_read_rnd_next';
Variable_name Value
Handler_read_rnd_next 0
select id,vec_distance_euclidean(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 use index () order by d limit 3;
id d
9 0.4719976290006591
10 0.5069011044450041
3 0.5865673124650332
show status like 'handler_read_rnd_next';
Variable_name Value
Handler_read_rnd_next 11
flush session status;
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
10 0.05905546376032378
9 0.06546887818344715
3 0.10750282439505232
show status like 'handler_read_rnd_next';
Variable_name Value
Handler_read_rnd_next 11
delete from t1 where v = x'7b713f3e5258323f80d1113d673b2b3f66e3583f';
select id,vec_distance_euclidean(v, x'B047263C9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
......@@ -382,3 +408,58 @@ t1.frm
t1.ibd
drop database test1;
db.opt
#
# Cosine distance
#
create table t1 (id int auto_increment primary key, v blob not null,
vector index (v) distance_function=cosine);
show create table t1;
Table Create Table
t1 CREATE TABLE `t1` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`v` blob NOT NULL,
PRIMARY KEY (`id`),
VECTOR KEY `v` (`v`) `distance_function`=cosine
) ENGINE=MyISAM DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_uca1400_ai_ci
insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'f511303f72224a3fdd05fe3eb22a133ffae86a3f'),
(x'f09baa3ea172763f123def3e0c7fe53e288bf33e'),
(x'b97a523f2a193e3eb4f62e3f2d23583e9dd60d3f'),
(x'f7c5df3e984b2b3e65e59d3d7376db3eac63773e'),
(x'de01453ffa486d3f10aa4d3fdd66813c71cb163f'),
(x'76edfc3e4b57243f10f8423fb158713f020bda3e'),
(x'56926c3fdf098d3e2c8c5e3d1ad4953daa9d0b3e'),
(x'7b713f3e5258323f80d1113d673b2b3f66e3583f'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
10 0.05905546376032378
9 0.06546887818344715
3 0.10750282439505232
flush session status;
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
10 0.05905546376032378
9 0.06546887818344715
3 0.10750282439505232
show status like 'handler_read_rnd_next';
Variable_name Value
Handler_read_rnd_next 0
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 use index () order by d limit 3;
id d
10 0.05905546376032378
9 0.06546887818344715
3 0.10750282439505232
show status like 'handler_read_rnd_next';
Variable_name Value
Handler_read_rnd_next 11
flush session status;
select id,vec_distance_euclidean(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
id d
9 0.4719976290006591
10 0.5069011044450041
3 0.5865673124650332
show status like 'handler_read_rnd_next';
Variable_name Value
Handler_read_rnd_next 11
drop table t1;
......@@ -56,6 +56,20 @@ select * from (
from t1 where id < 10
) u order by d limit 3;
# see if order by uses index:
--disable_view_protocol
--disable_ps2_protocol
flush session status;
select id,vec_distance_euclidean(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
show status like 'handler_read_rnd_next'; # used
select id,vec_distance_euclidean(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 use index () order by d limit 3;
show status like 'handler_read_rnd_next'; # not used
flush session status;
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
show status like 'handler_read_rnd_next'; # not used, wrong distance metric
--enable_ps2_protocol
--enable_view_protocol
# test delete
delete from t1 where v = x'7b713f3e5258323f80d1113d673b2b3f66e3583f';
select id,vec_distance_euclidean(v, x'B047263C9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
......@@ -162,3 +176,38 @@ rename table test1.t1 to test1.t2;
list_files $datadir/test1;
drop database test1;
list_files $datadir/test;
--echo #
--echo # Cosine distance
--echo #
create table t1 (id int auto_increment primary key, v blob not null,
vector index (v) distance_function=cosine);
replace_result InnoDB MyISAM;
show create table t1;
insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'f511303f72224a3fdd05fe3eb22a133ffae86a3f'),
(x'f09baa3ea172763f123def3e0c7fe53e288bf33e'),
(x'b97a523f2a193e3eb4f62e3f2d23583e9dd60d3f'),
(x'f7c5df3e984b2b3e65e59d3d7376db3eac63773e'),
(x'de01453ffa486d3f10aa4d3fdd66813c71cb163f'),
(x'76edfc3e4b57243f10f8423fb158713f020bda3e'),
(x'56926c3fdf098d3e2c8c5e3d1ad4953daa9d0b3e'),
(x'7b713f3e5258323f80d1113d673b2b3f66e3583f'),
(x'6ca1d43e9df91b3fe580da3e1c247d3f147cf33e');
# make sure the graph is loaded
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
--disable_view_protocol
--disable_ps2_protocol
flush session status;
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
show status like 'handler_read_rnd_next';
select id,vec_distance_cosine(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 use index () order by d limit 3;
show status like 'handler_read_rnd_next';
flush session status;
select id,vec_distance_euclidean(v, x'B047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
show status like 'handler_read_rnd_next';
--enable_ps2_protocol
--enable_view_protocol
drop table t1;
......@@ -2172,6 +2172,16 @@ NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL
READ_ONLY NO
COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME MHNSW_DISTANCE_FUNCTION
VARIABLE_SCOPE SESSION
VARIABLE_TYPE ENUM
VARIABLE_COMMENT Distance function to build the vector index for
NUMERIC_MIN_VALUE NULL
NUMERIC_MAX_VALUE NULL
NUMERIC_BLOCK_SIZE NULL
ENUM_VALUE_LIST euclidean,cosine
READ_ONLY NO
COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME MHNSW_MAX_EDGES_PER_NODE
VARIABLE_SCOPE SESSION
VARIABLE_TYPE INT UNSIGNED
......
......@@ -2382,6 +2382,16 @@ NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL
READ_ONLY NO
COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME MHNSW_DISTANCE_FUNCTION
VARIABLE_SCOPE SESSION
VARIABLE_TYPE ENUM
VARIABLE_COMMENT Distance function to build the vector index for
NUMERIC_MIN_VALUE NULL
NUMERIC_MAX_VALUE NULL
NUMERIC_BLOCK_SIZE NULL
ENUM_VALUE_LIST euclidean,cosine
READ_ONLY NO
COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME MHNSW_MAX_EDGES_PER_NODE
VARIABLE_SCOPE SESSION
VARIABLE_TYPE INT UNSIGNED
......
......@@ -6251,6 +6251,21 @@ class Create_func_vec_distance_euclidean: public Create_func_arg2
Create_func_vec_distance_euclidean Create_func_vec_distance_euclidean::s_singleton;
class Create_func_vec_distance_cosine: public Create_func_arg2
{
public:
Item *create_2_arg(THD *thd, Item *arg1, Item *arg2) override
{ return new (thd->mem_root) Item_func_vec_distance_cosine(thd, arg1, arg2); }
static Create_func_vec_distance_cosine s_singleton;
protected:
Create_func_vec_distance_cosine() = default;
virtual ~Create_func_vec_distance_cosine() = default;
};
Create_func_vec_distance_cosine Create_func_vec_distance_cosine::s_singleton;
class Create_func_vec_totext: public Create_func_arg1
{
public:
......@@ -6511,6 +6526,7 @@ const Native_func_registry func_array[] =
{ { STRING_WITH_LEN("UPPER") }, BUILDER(Create_func_ucase)},
{ { STRING_WITH_LEN("UUID_SHORT") }, BUILDER(Create_func_uuid_short)},
{ { STRING_WITH_LEN("VEC_DISTANCE_EUCLIDEAN") }, BUILDER(Create_func_vec_distance_euclidean)},
{ { STRING_WITH_LEN("VEC_DISTANCE_COSINE") }, BUILDER(Create_func_vec_distance_cosine)},
{ { STRING_WITH_LEN("VEC_FROMTEXT") }, BUILDER(Create_func_vec_fromtext)},
{ { STRING_WITH_LEN("VEC_TOTEXT") }, BUILDER(Create_func_vec_totext)},
{ { STRING_WITH_LEN("VERSION") }, BUILDER(Create_func_version)},
......
......@@ -21,6 +21,7 @@
*/
#include "item_vectorfunc.h"
#include "vector_mhnsw.h"
key_map Item_func_vec_distance_common::part_of_sortkey() const
{
......@@ -28,9 +29,10 @@ key_map Item_func_vec_distance_common::part_of_sortkey() const
if (Item_field *item= get_field_arg())
{
Field *f= item->field;
KEY *keyinfo= f->table->s->key_info;
for (uint i= f->table->s->keys; i < f->table->s->total_keys; i++)
if (f->table->s->key_info[i].algorithm == HA_KEY_ALG_VECTOR &&
f->key_start.is_set(i))
if (keyinfo[i].algorithm == HA_KEY_ALG_VECTOR && f->key_start.is_set(i)
&& mhnsw_uses_distance(f->table, keyinfo + i, this))
map.set_bit(i);
}
return map;
......
......@@ -85,6 +85,33 @@ class Item_func_vec_distance_euclidean: public Item_func_vec_distance_common
};
class Item_func_vec_distance_cosine: public Item_func_vec_distance_common
{
double calc_distance(float *v1, float *v2, size_t v_len) override
{
double dotp=0, abs1=0, abs2=0;
for (size_t i= 0; i < v_len; i++, v1++, v2++)
{
abs1+= *v1 * *v1;
abs2+= *v2 * *v2;
dotp+= *v1 * *v2;
}
return 1 - dotp/sqrt(abs1*abs2);
}
public:
Item_func_vec_distance_cosine(THD *thd, Item *a, Item *b)
:Item_func_vec_distance_common(thd, a, b) {}
LEX_CSTRING func_name_cstring() const override
{
static LEX_CSTRING name= { STRING_WITH_LEN("VEC_DISTANCE_COSINE") };
return name;
}
Item *do_get_copy(THD *thd) const override
{ return get_item_copy<Item_func_vec_distance_cosine>(thd, this); }
};
class Item_func_vec_totext: public Item_str_ascii_checksum_func
{
bool check_arguments() const override
......
......@@ -44,9 +44,17 @@ static MYSQL_THDVAR_UINT(max_edges_per_node, PLUGIN_VAR_RQCMDARG,
"memory consumption, but better search results",
nullptr, nullptr, 6, 3, 200, 1);
enum metric_type : uint { EUCLIDEAN, COSINE };
static const char *distance_function_names[]= { "euclidean", "cosine", nullptr };
static TYPELIB distance_functions= CREATE_TYPELIB_FOR(distance_function_names);
static MYSQL_THDVAR_ENUM(distance_function, PLUGIN_VAR_RQCMDARG,
"Distance function to build the vector index for",
nullptr, nullptr, EUCLIDEAN, &distance_functions);
struct ha_index_option_struct
{
uint M;
metric_type metric;
};
enum Graph_table_fields {
......@@ -79,7 +87,7 @@ struct FVector
static size_t data_to_value_size(size_t data_size)
{ return (data_size - data_header)*2; }
static const FVector *create(void *mem, const void *src, size_t src_len)
static const FVector *create(metric_type metric, void *mem, const void *src, size_t src_len)
{
float scale=0, *v= (float *)src;
size_t vec_len= src_len / sizeof(float);
......@@ -92,6 +100,12 @@ struct FVector
for (size_t i= 0; i < vec_len; i++)
vec->dims[i] = static_cast<int16_t>(std::round(v[i] / vec->scale));
vec->postprocess(vec_len);
if (metric == COSINE)
{
if (vec->abs2 > 0.0f)
vec->scale/= std::sqrt(vec->abs2);
vec->abs2= 1.0f;
}
return vec;
}
......@@ -292,11 +306,13 @@ class MHNSW_Context : public Sql_alloc
const uint tref_len;
const uint gref_len;
const uint M;
metric_type metric;
MHNSW_Context(TABLE *t)
: tref_len(t->file->ref_length),
gref_len(t->hlindex->file->ref_length),
M(t->s->key_info[t->s->keys].option_struct->M)
M(t->s->key_info[t->s->keys].option_struct->M),
metric(t->s->key_info[t->s->keys].option_struct->metric)
{
mysql_rwlock_init(PSI_INSTRUMENT_ME, &commit_lock);
mysql_mutex_init(PSI_INSTRUMENT_ME, &cache_lock, MY_MUTEX_INIT_FAST);
......@@ -601,7 +617,7 @@ int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update)
/* copy the vector, preprocessed as needed */
const FVector *FVectorNode::make_vec(const void *v)
{
return FVector::create(tref() + tref_len(), v, ctx->byte_len);
return FVector::create(ctx->metric, tref() + tref_len(), v, ctx->byte_len);
}
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_)
......@@ -1132,7 +1148,9 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
{
THD *thd= table->in_use;
TABLE *graph= table->hlindex;
auto *fun= (Item_func_vec_distance_euclidean *)(dist->real_item());
auto *fun= static_cast<Item_func_vec_distance_common*>(dist->real_item());
DBUG_ASSERT(fun);
String buf, *res= fun->get_const_arg()->val_str(&buf);
MHNSW_Context *ctx;
......@@ -1165,7 +1183,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
}
const longlong max_layer= start_nodes.links[0]->max_layer;
auto target= FVector::create(thd->alloc(FVector::alloc_size(ctx->vec_len)),
auto target= FVector::create(ctx->metric, thd->alloc(FVector::alloc_size(ctx->vec_len)),
res->ptr(), res->length());
if (int err= graph->file->ha_rnd_init(0))
......@@ -1303,6 +1321,13 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
return {s, len};
}
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist)
{
if (keyinfo->option_struct->metric == EUCLIDEAN)
return dynamic_cast<const Item_func_vec_distance_euclidean*>(dist) != NULL;
return dynamic_cast<const Item_func_vec_distance_cosine*>(dist) != NULL;
}
/*
Declare the plugin and index options
*/
......@@ -1310,6 +1335,7 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
ha_create_table_option mhnsw_index_options[]=
{
HA_IOPTION_SYSVAR("max_edges_per_node", M, max_edges_per_node),
HA_IOPTION_SYSVAR("distance_function", metric, distance_function),
HA_IOPTION_END
};
......@@ -1338,6 +1364,7 @@ static struct st_mysql_sys_var *mhnsw_sys_vars[]=
{
MYSQL_SYSVAR(cache_size),
MYSQL_SYSVAR(max_edges_per_node),
MYSQL_SYSVAR(distance_function),
MYSQL_SYSVAR(min_limit),
NULL
};
......
......@@ -21,6 +21,10 @@
#include "structs.h"
#include "table.h"
/*
This will become a vector index plugin API, or, perhaps,
a hlindex plugin API. When we'll have more than one implementation.
*/
const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length);
int mhnsw_insert(TABLE *table, KEY *keyinfo);
int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit);
......@@ -28,6 +32,7 @@ int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo);
int mhnsw_delete_all(TABLE *table, KEY *keyinfo);
int mhnsw_next(TABLE *table);
void mhnsw_free(TABLE_SHARE *share);
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist);
extern ha_create_table_option mhnsw_index_options[];
extern st_plugin_int *mhnsw_plugin;
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment