Commit 3407e355 authored by Sergei Golubchik

misc changes

* sysvars should be REQUIRED_ARG
* fix a mix of US and UK spelling (use US)
* use consistent naming
* work if VEC_DISTANCE arguments are in the swapped order (const, col)
* work if VEC_DISTANCE argument is NULL/invalid or wrong length
* abort INSERT if the value is invalid or wrong length
* store the "number of neighbors" in a blob in an endianness-independent way (see the sketch below)
* use field->store(longlong, bool) not field->store(double)
* a lot more error checking everywhere
* cleanup after errors
* simplify calling conventions, remove reinterpret_cast's
* todo/XXX comments
* whitespaces
* use float consistently

memory management is still totally PoC quality
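
For reference, a minimal self-contained sketch of the endianness-independent layout mentioned above. Nothing here is taken from the server tree: store_u16_le()/read_u16_le() are hypothetical stand-ins for the int2store()/uint2korr() macros that the patch wraps in HNSW_MAX_M_store/HNSW_MAX_M_read. The neighbor count is written as two bytes in a fixed (little-endian) order, so the blob reads back identically on any host.

#include <cassert>
#include <cstdint>

// write a 16-bit count low byte first, regardless of host byte order
static void store_u16_le(unsigned char *p, uint16_t n)
{
  p[0]= (unsigned char)(n & 0xff);
  p[1]= (unsigned char)(n >> 8);
}

// read it back the same way on big- and little-endian machines alike
static uint16_t read_u16_le(const unsigned char *p)
{
  return (uint16_t)(p[0] | ((uint16_t)p[1] << 8));
}

int main()
{
  unsigned char blob[2];
  store_u16_le(blob, 100);           // e.g. 100 neighbors
  assert(read_u16_le(blob) == 100);  // identical result on any platform
  return 0;
}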

Initial HNSW implementation
parent 16b55b46
...@@ -402,10 +402,10 @@ The following specify which files/extra groups are read (specified before remain ...@@ -402,10 +402,10 @@ The following specify which files/extra groups are read (specified before remain
height-balanced, DOUBLE_PREC_HB - double precision height-balanced, DOUBLE_PREC_HB - double precision
height-balanced, JSON_HB - height-balanced, stored as height-balanced, JSON_HB - height-balanced, stored as
JSON JSON
--hnsw-ef-constructor --hnsw-ef-constructor=#
hnsw_ef_constructor hnsw_ef_constructor
--hnsw-ef-search hnsw_ef_search --hnsw-ef-search=# hnsw_ef_search
--hnsw-max-connection-per-layer --hnsw-max-connection-per-layer=#
hnsw_max_connection_per_layer hnsw_max_connection_per_layer
--host-cache-size=# How many host names should be cached to avoid resolving --host-cache-size=# How many host names should be cached to avoid resolving
(Automatically configured unless set explicitly) (Automatically configured unless set explicitly)
......
...@@ -80,6 +80,21 @@ id d ...@@ -80,6 +80,21 @@ id d
9 0.4719976290006591 9 0.4719976290006591
10 0.5069011044450041 10 0.5069011044450041
3 0.5865673124650332 3 0.5865673124650332
select id,vec_distance(x'b047263c9f87233fcfd27e3eae493e3f0329f43e', v) d from t1 order by d limit 3;
id d
9 0.4719976290006591
10 0.5069011044450041
3 0.5865673124650332
select id>0,vec_distance(v, NULL) d from t1 order by d limit 3;
id>0 d
1 NULL
1 NULL
1 NULL
select id>0,vec_distance(v, x'123456') d from t1 order by d limit 3;
id>0 d
1 NULL
1 NULL
1 NULL
select t1.id as id1, t2.id as id2, vec_distance(t1.v, t2.v) from t1, t1 as t2 order by 3,1,2; select t1.id as id1, t2.id as id2, vec_distance(t1.v, t2.v) from t1, t1 as t2 order by 3,1,2;
id1 id2 vec_distance(t1.v, t2.v) id1 id2 vec_distance(t1.v, t2.v)
1 1 0 1 1 0
...@@ -182,5 +197,11 @@ id1 id2 vec_distance(t1.v, t2.v) ...@@ -182,5 +197,11 @@ id1 id2 vec_distance(t1.v, t2.v)
9 8 1.2575258643523053 9 8 1.2575258643523053
7 8 1.288239696195716 7 8 1.288239696195716
8 7 1.288239696195716 8 7 1.288239696195716
insert t1 (v) values ('');
ERROR 22007: Incorrect vector value: '...' for column `test`.`t1`.`v` at row 1
insert t1 (v) values (x'1234');
ERROR 22007: Incorrect vector value: '...' for column `test`.`t1`.`v` at row 1
insert t1 (v) values (x'12345678');
ERROR 22007: Incorrect vector value: '...' for column `test`.`t1`.`v` at row 1
drop table t1; drop table t1;
db.opt db.opt
...@@ -10,7 +10,7 @@ create table t1 (id int auto_increment primary key, v blob not null, vector inde ...@@ -10,7 +10,7 @@ create table t1 (id int auto_increment primary key, v blob not null, vector inde
show create table t1; show create table t1;
show keys from t1; show keys from t1;
query_vertical select * from information_schema.statistics where table_name='t1'; query_vertical select * from information_schema.statistics where table_name='t1';
# print unpack(H40,pack(f5,map{rand}1..5)) # print unpack("H*",pack("f*",map{rand}1..5))
insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'), insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
(x'f511303f72224a3fdd05fe3eb22a133ffae86a3f'), (x'f511303f72224a3fdd05fe3eb22a133ffae86a3f'),
(x'f09baa3ea172763f123def3e0c7fe53e288bf33e'), (x'f09baa3ea172763f123def3e0c7fe53e288bf33e'),
...@@ -24,8 +24,23 @@ insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'), ...@@ -24,8 +24,23 @@ insert t1 (v) values (x'e360d63ebe554f3fcdbc523f4522193f5236083d'),
select id, hex(v) from t1; select id, hex(v) from t1;
flush tables; flush tables;
# test with a valid query vector
select id,vec_distance(v, x'b047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3; select id,vec_distance(v, x'b047263c9f87233fcfd27e3eae493e3f0329f43e') d from t1 order by d limit 3;
# swapped arguments
select id,vec_distance(x'b047263c9f87233fcfd27e3eae493e3f0329f43e', v) d from t1 order by d limit 3;
# test with NULL (id is unpredictable)
select id>0,vec_distance(v, NULL) d from t1 order by d limit 3;
# test with invalid query vector (id is unpredictable)
select id>0,vec_distance(v, x'123456') d from t1 order by d limit 3;
select t1.id as id1, t2.id as id2, vec_distance(t1.v, t2.v) from t1, t1 as t2 order by 3,1,2; select t1.id as id1, t2.id as id2, vec_distance(t1.v, t2.v) from t1, t1 as t2 order by 3,1,2;
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
insert t1 (v) values ('');
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
insert t1 (v) values (x'1234');
--error ER_TRUNCATED_WRONG_VALUE_FOR_FIELD
insert t1 (v) values (x'12345678');
drop table t1; drop table t1;
let $datadir=`select @@datadir`; let $datadir=`select @@datadir`;
list_files $datadir/test; list_files $datadir/test;
...@@ -1431,7 +1431,7 @@ NUMERIC_MAX_VALUE 4294967295 ...@@ -1431,7 +1431,7 @@ NUMERIC_MAX_VALUE 4294967295
NUMERIC_BLOCK_SIZE 1 NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL ENUM_VALUE_LIST NULL
READ_ONLY NO READ_ONLY NO
COMMAND_LINE_ARGUMENT NONE COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME HNSW_EF_SEARCH VARIABLE_NAME HNSW_EF_SEARCH
VARIABLE_SCOPE SESSION VARIABLE_SCOPE SESSION
VARIABLE_TYPE INT UNSIGNED VARIABLE_TYPE INT UNSIGNED
...@@ -1441,7 +1441,7 @@ NUMERIC_MAX_VALUE 4294967295 ...@@ -1441,7 +1441,7 @@ NUMERIC_MAX_VALUE 4294967295
NUMERIC_BLOCK_SIZE 1 NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL ENUM_VALUE_LIST NULL
READ_ONLY NO READ_ONLY NO
COMMAND_LINE_ARGUMENT NONE COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME HNSW_MAX_CONNECTION_PER_LAYER VARIABLE_NAME HNSW_MAX_CONNECTION_PER_LAYER
VARIABLE_SCOPE SESSION VARIABLE_SCOPE SESSION
VARIABLE_TYPE INT UNSIGNED VARIABLE_TYPE INT UNSIGNED
...@@ -1451,7 +1451,7 @@ NUMERIC_MAX_VALUE 4294967295 ...@@ -1451,7 +1451,7 @@ NUMERIC_MAX_VALUE 4294967295
NUMERIC_BLOCK_SIZE 1 NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL ENUM_VALUE_LIST NULL
READ_ONLY NO READ_ONLY NO
COMMAND_LINE_ARGUMENT NONE COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME HOSTNAME VARIABLE_NAME HOSTNAME
VARIABLE_SCOPE GLOBAL VARIABLE_SCOPE GLOBAL
VARIABLE_TYPE VARCHAR VARIABLE_TYPE VARCHAR
......
...@@ -965,6 +965,10 @@ class Field: public Value_source ...@@ -965,6 +965,10 @@ class Field: public Value_source
{ {
return store(to, length, &my_charset_bin); return store(to, length, &my_charset_bin);
} }
int store_binary(const uchar *to, size_t length)
{
return store_binary((const char*)(to), length);
}
virtual int store_hex_hybrid(const char *str, size_t length); virtual int store_hex_hybrid(const char *str, size_t length);
virtual int store(double nr)=0; virtual int store(double nr)=0;
virtual int store(longlong nr, bool unsigned_val)=0; virtual int store(longlong nr, bool unsigned_val)=0;
......
...@@ -51,10 +51,17 @@ class Item_func_vec_distance: public Item_real_func ...@@ -51,10 +51,17 @@ class Item_func_vec_distance: public Item_real_func
static LEX_CSTRING name= {STRING_WITH_LEN("vec_distance") }; static LEX_CSTRING name= {STRING_WITH_LEN("vec_distance") };
return name; return name;
} }
Item *get_const_arg() const
{
if (args[0]->type() == Item::FIELD_ITEM && args[1]->const_item())
return args[1];
if (args[1]->type() == Item::FIELD_ITEM && args[0]->const_item())
return args[0];
return NULL;
}
key_map part_of_sortkey() const override; key_map part_of_sortkey() const override;
Item *do_get_copy(THD *thd) const override Item *do_get_copy(THD *thd) const override
{ return get_item_copy<Item_func_vec_distance>(thd, this); } { return get_item_copy<Item_func_vec_distance>(thd, this); }
virtual ~Item_func_vec_distance() {};
}; };
......
...@@ -7388,18 +7388,18 @@ static Sys_var_enum Sys_block_encryption_mode( ...@@ -7388,18 +7388,18 @@ static Sys_var_enum Sys_block_encryption_mode(
static Sys_var_uint Sys_hnsw_ef_search( static Sys_var_uint Sys_hnsw_ef_search(
"hnsw_ef_search", "hnsw_ef_search",
"hnsw_ef_search", "hnsw_ef_search",
SESSION_VAR(hnsw_ef_search), CMD_LINE(NO_ARG), SESSION_VAR(hnsw_ef_search), CMD_LINE(REQUIRED_ARG),
VALID_RANGE(0, UINT_MAX), DEFAULT(10), VALID_RANGE(0, UINT_MAX), DEFAULT(10),
BLOCK_SIZE(1)); BLOCK_SIZE(1));
static Sys_var_uint Sys_hnsw_ef_constructor( static Sys_var_uint Sys_hnsw_ef_constructor(
"hnsw_ef_constructor", "hnsw_ef_constructor",
"hnsw_ef_constructor", "hnsw_ef_constructor",
SESSION_VAR(hnsw_ef_constructor), CMD_LINE(NO_ARG), SESSION_VAR(hnsw_ef_constructor), CMD_LINE(REQUIRED_ARG),
VALID_RANGE(0, UINT_MAX), DEFAULT(10), VALID_RANGE(0, UINT_MAX), DEFAULT(10),
BLOCK_SIZE(1)); BLOCK_SIZE(1));
static Sys_var_uint Sys_hnsw_max_connection_per_layer( static Sys_var_uint Sys_hnsw_max_connection_per_layer(
"hnsw_max_connection_per_layer", "hnsw_max_connection_per_layer",
"hnsw_max_connection_per_layer", "hnsw_max_connection_per_layer",
SESSION_VAR(hnsw_max_connection_per_layer), CMD_LINE(NO_ARG), SESSION_VAR(hnsw_max_connection_per_layer), CMD_LINE(REQUIRED_ARG),
VALID_RANGE(0, UINT_MAX), DEFAULT(50), VALID_RANGE(0, UINT_MAX), DEFAULT(50),
BLOCK_SIZE(1)); BLOCK_SIZE(1));
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
along with this program; if not, write to the Free Software along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335 USA
*/ */
#include <random>
#include <my_global.h> #include <my_global.h>
#include "vector_mhnsw.h" #include "vector_mhnsw.h"
...@@ -27,14 +26,18 @@ ...@@ -27,14 +26,18 @@
#include "my_base.h" #include "my_base.h"
#include "mysql/psi/psi_base.h" #include "mysql/psi/psi_base.h"
#include "sql_queue.h" #include "sql_queue.h"
#include <scope.h>
#define HNSW_MAX_M 10000 #define HNSW_MAX_M 10000 // practically the number of neighbors should be ~100
#define HNSW_MAX_M_WIDTH 2
#define HNSW_MAX_M_store int2store
#define HNSW_MAX_M_read uint2korr
const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\ const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
CREATE TABLE i ( \ CREATE TABLE i ( \
layer int not null, \ layer int not null, \
src varbinary(255) not null, \ src varbinary(255) not null, \
neighbors varbinary(10000) not null, \ neighbors blob not null, \
index (layer, src)) \ index (layer, src)) \
")}; ")};
...@@ -43,8 +46,7 @@ class FVectorRef ...@@ -43,8 +46,7 @@ class FVectorRef
{ {
public: public:
// Shallow ref copy. Used for other ref lookups in HashSet // Shallow ref copy. Used for other ref lookups in HashSet
FVectorRef(uchar *ref, size_t ref_len): ref{ref}, ref_len{ref_len} {} FVectorRef(const void *ref, size_t ref_len): ref{(uchar*)ref}, ref_len{ref_len} {}
virtual ~FVectorRef() {}
static const uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool) static const uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
{ {
...@@ -66,22 +68,11 @@ class FVectorRef ...@@ -66,22 +68,11 @@ class FVectorRef
size_t ref_len; size_t ref_len;
}; };
class FVector; Hash_set<FVectorRef> all_vector_set(PSI_INSTRUMENT_MEM, &my_charset_bin,
1000, 0, 0, (my_hash_get_key)FVectorRef::get_key, 0, HASH_UNIQUE);
Hash_set<FVectorRef> all_vector_set(
PSI_INSTRUMENT_MEM, &my_charset_bin,
1000, 0, 0,
(my_hash_get_key)FVectorRef::get_key,
NULL,
HASH_UNIQUE);
Hash_set<FVectorRef> all_vector_ref_set(
PSI_INSTRUMENT_MEM, &my_charset_bin,
1000, 0, 0,
(my_hash_get_key)FVectorRef::get_key,
NULL,
HASH_UNIQUE);
Hash_set<FVectorRef> all_vector_ref_set(PSI_INSTRUMENT_MEM, &my_charset_bin,
1000, 0, 0, (my_hash_get_key)FVectorRef::get_key, NULL, HASH_UNIQUE);
class FVector: public FVectorRef class FVector: public FVectorRef
{ {
...@@ -92,29 +83,28 @@ class FVector: public FVectorRef ...@@ -92,29 +83,28 @@ class FVector: public FVectorRef
FVector(): vec(nullptr), vec_len(0) {} FVector(): vec(nullptr), vec_len(0) {}
~FVector() { my_free(this->ref); } ~FVector() { my_free(this->ref); }
bool init(const uchar *ref, size_t ref_len, bool init(const uchar *ref, size_t ref_len, const void *vec, size_t bytes)
const float *vec, size_t vec_len)
{ {
this->ref= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, this->ref= (uchar*)my_malloc(PSI_NOT_INSTRUMENTED, ref_len + bytes, MYF(0));
ref_len + vec_len * sizeof(float),
MYF(0));
if (!this->ref) if (!this->ref)
return true; return true;
this->vec= reinterpret_cast<float *>(this->ref + ref_len); this->vec= reinterpret_cast<float *>(this->ref + ref_len);
memcpy(this->ref, ref, ref_len); memcpy(this->ref, ref, ref_len);
memcpy(this->vec, vec, vec_len * sizeof(float)); memcpy(this->vec, vec, bytes);
this->ref_len= ref_len; this->ref_len= ref_len;
this->vec_len= vec_len; this->vec_len= bytes / sizeof(float);
return false; return false;
} }
size_t size_of() const { return vec_len * sizeof(float); }
size_t get_vec_len() const { return vec_len; } size_t get_vec_len() const { return vec_len; }
const float* get_vec() const { return vec; } const float* get_vec() const { return vec; }
double distance_to(const FVector &other) const float distance_to(const FVector &other) const
{ {
DBUG_ASSERT(other.vec_len == vec_len); DBUG_ASSERT(other.vec_len == vec_len);
return euclidean_vec_distance(vec, other.vec, vec_len); return euclidean_vec_distance(vec, other.vec, vec_len);
...@@ -122,24 +112,25 @@ class FVector: public FVectorRef ...@@ -122,24 +112,25 @@ class FVector: public FVectorRef
static FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len) static FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len)
{ {
FVectorRef tmp{(uchar*)ref, ref_len}; FVectorRef tmp{ref, ref_len};
FVectorRef *v= all_vector_ref_set.find(&tmp); FVectorRef *v= all_vector_ref_set.find(&tmp);
if (v) if (v)
return v; return v;
// TODO(cvicentiu) memory management. // TODO(cvicentiu) memory management.
uchar *buf= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, ref_len, MYF(0)); uchar *buf= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, ref_len, MYF(0));
if (buf)
{
memcpy(buf, ref, ref_len); memcpy(buf, ref, ref_len);
v= new FVectorRef{buf, ref_len}; if ((v= new FVectorRef(buf, ref_len)))
all_vector_ref_set.insert(v); all_vector_ref_set.insert(v);
}
return v; return v;
} }
static FVector *get_fvector_from_source(TABLE *source, static FVector *get_fvector_from_source(TABLE *source, Field *vec_field,
Field *vect_field,
const FVectorRef &ref) const FVectorRef &ref)
{ {
FVectorRef *v= all_vector_set.find(&ref); FVectorRef *v= all_vector_set.find(&ref);
if (v) if (v)
return (FVector *)v; return (FVector *)v;
...@@ -152,12 +143,10 @@ class FVector: public FVectorRef ...@@ -152,12 +143,10 @@ class FVector: public FVectorRef
const_cast<uchar *>(ref.get_ref())); const_cast<uchar *>(ref.get_ref()));
String buf, *vec; String buf, *vec;
vec= vect_field->val_str(&buf); vec= vec_field->val_str(&buf);
// TODO(cvicentiu) error checking // TODO(cvicentiu) error checking
new_vector->init(ref.get_ref(), ref.get_ref_len(), new_vector->init(ref.get_ref(), ref.get_ref_len(), vec->ptr(), vec->length());
reinterpret_cast<const float *>(vec->ptr()),
vec->length() / sizeof(float));
all_vector_set.insert(new_vector); all_vector_set.insert(new_vector);
...@@ -165,13 +154,10 @@ class FVector: public FVectorRef ...@@ -165,13 +154,10 @@ class FVector: public FVectorRef
} }
}; };
static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b) static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b)
{ {
double a_dist= reference->distance_to(*a); float a_dist= reference->distance_to(*a);
double b_dist= reference->distance_to(*b); float b_dist= reference->distance_to(*b);
if (a_dist < b_dist) if (a_dist < b_dist)
return -1; return -1;
...@@ -180,65 +166,99 @@ static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b) ...@@ -180,65 +166,99 @@ static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b)
return 0; return 0;
} }
const bool KEEP_PRUNED_CONNECTIONS=true; const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; const bool EXTEND_CANDIDATES=true; // XXX or false?
static bool get_neighbours(TABLE *graph, static int get_neighbors(TABLE *graph, size_t layer_number,
size_t layer_number,
const FVectorRef &source_node, const FVectorRef &source_node,
List<FVectorRef> *neighbours); List<FVectorRef> *neighbors)
{
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
static bool select_neighbours(TABLE *source, TABLE *graph, graph->field[0]->store(layer_number, false);
Field *vect_field, graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
size_t layer_number, key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
const FVector &target, if (int err= graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT))
return err;
String strbuf, *str= graph->field[2]->val_str(&strbuf);
// mhnsw_insert() guarantees that all ref have the same length
uint ref_length= source_node.get_ref_len();
const uchar *neigh_arr_bytes= reinterpret_cast<const uchar *>(str->ptr());
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length())
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
const uchar *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
for (uint i= 0; i < number_of_neighbors; i++)
{
FVectorRef *v= FVector::get_fvector_ref(pos, ref_length);
if (!v)
return HA_ERR_OUT_OF_MEM;
neighbors->push_back(v);
pos+= ref_length;
}
return 0;
}
static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
size_t layer_number, const FVector &target,
const List<FVectorRef> &candidates, const List<FVectorRef> &candidates,
size_t max_neighbour_connections, size_t max_neighbor_connections,
List<FVectorRef> *neighbours) List<FVectorRef> *neighbors)
{ {
/* /*
TODO: If the input neighbours list is already sorted in search_layer, then TODO: If the input neighbors list is already sorted in search_layer, then
no need to do additional queue build steps here. no need to do additional queue build steps here.
*/ */
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, 1000, 0,
1000, 0, 0, 0, (my_hash_get_key)FVectorRef::get_key,
(my_hash_get_key)FVectorRef::get_key, NULL, HASH_UNIQUE);
NULL,
HASH_UNIQUE);
Queue<FVector, const FVector> pq; Queue<FVector, const FVector> pq; // working queue
Queue<FVector, const FVector> pq_discard; Queue<FVector, const FVector> pq_discard; // queue for discarded candidates
Queue<FVector, const FVector> best; Queue<FVector, const FVector> best; // neighbors to return
// TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size. // TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size.
// This should not be fixed. // This should not be fixed.
pq.init(10000, 0, cmp_vec, &target); if (pq.init(10000, 0, cmp_vec, &target) ||
pq_discard.init(10000, 0, cmp_vec, &target); pq_discard.init(10000, 0, cmp_vec, &target) ||
best.init(max_neighbour_connections, true, cmp_vec, &target); best.init(max_neighbor_connections, true, cmp_vec, &target))
return HA_ERR_OUT_OF_MEM;
// TODO(cvicentiu) error checking.
for (const FVectorRef &candidate : candidates) for (const FVectorRef &candidate : candidates)
{ {
pq.push(FVector::get_fvector_from_source(source, vect_field, candidate)); FVector *v= FVector::get_fvector_from_source(source, vec_field, candidate);
if (!v)
return HA_ERR_OUT_OF_MEM;
visited.insert(&candidate); visited.insert(&candidate);
pq.push(v);
} }
if (EXTEND_CANDIDATES) if (EXTEND_CANDIDATES)
{ {
for (const FVectorRef &candidate : candidates) for (const FVectorRef &candidate : candidates)
{ {
List<FVectorRef> candidate_neighbours; List<FVectorRef> candidate_neighbors;
get_neighbours(graph, layer_number, candidate, &candidate_neighbours); if (int err= get_neighbors(graph, layer_number, candidate,
for (const FVectorRef &extra_candidate : candidate_neighbours) &candidate_neighbors))
return err;
for (const FVectorRef &extra_candidate : candidate_neighbors)
{ {
if (visited.find(&extra_candidate)) if (visited.find(&extra_candidate))
continue; continue;
visited.insert(&extra_candidate); visited.insert(&extra_candidate);
pq.push(FVector::get_fvector_from_source(source, FVector *v= FVector::get_fvector_from_source(source, vec_field,
vect_field, extra_candidate);
extra_candidate)); if (!v)
return HA_ERR_OUT_OF_MEM;
pq.push(v);
} }
} }
} }
...@@ -246,16 +266,16 @@ static bool select_neighbours(TABLE *source, TABLE *graph, ...@@ -246,16 +266,16 @@ static bool select_neighbours(TABLE *source, TABLE *graph,
DBUG_ASSERT(pq.elements()); DBUG_ASSERT(pq.elements());
best.push(pq.pop()); best.push(pq.pop());
double best_top = best.top()->distance_to(target); float best_top= best.top()->distance_to(target);
while (pq.elements() && best.elements() < max_neighbour_connections) while (pq.elements() && best.elements() < max_neighbor_connections)
{ {
const FVector *vec= pq.pop(); const FVector *vec= pq.pop();
double cur_dist = vec->distance_to(target); const float cur_dist= vec->distance_to(target);
// TODO(cvicentiu) best distance can be cached. if (cur_dist < best_top)
if (cur_dist < best_top) { {
DBUG_ASSERT(0); // impossible. XXX redo the loop
best.push(vec); best.push(vec);
best_top = cur_dist; best_top= cur_dist;
} }
else else
pq_discard.push(vec); pq_discard.push(vec);
...@@ -264,61 +284,29 @@ static bool select_neighbours(TABLE *source, TABLE *graph, ...@@ -264,61 +284,29 @@ static bool select_neighbours(TABLE *source, TABLE *graph,
if (KEEP_PRUNED_CONNECTIONS) if (KEEP_PRUNED_CONNECTIONS)
{ {
while (pq_discard.elements() && while (pq_discard.elements() &&
best.elements() < max_neighbour_connections) best.elements() < max_neighbor_connections)
{ {
best.push(pq_discard.pop()); best.push(pq_discard.pop());
} }
} }
DBUG_ASSERT(best.elements() <= max_neighbour_connections); DBUG_ASSERT(best.elements() <= max_neighbor_connections);
while (best.elements()) { while (best.elements()) // XXX why not to return best directly?
neighbours->push_front(best.pop()); neighbors->push_front(best.pop());
}
return false; return 0;
} }
//static bool select_neighbours(TABLE *source, TABLE *graph,
// Field *vect_field, static void dbug_print_vec_ref(const char *prefix, uint layer,
// size_t layer_number,
// const FVector &target,
// const List<FVectorRef> &candidates,
// size_t max_neighbour_connections,
// List<FVectorRef> *neighbours)
//{
// /*
// TODO: If the input neighbours list is already sorted in search_layer, then
// no need to do additional queue build steps here.
// */
//
// Queue<FVector, const FVector> pq;
// pq.init(candidates.elements, 0, 0, cmp_vec, &target);
//
// // TODO(cvicentiu) error checking.
// for (const FVectorRef &candidate : candidates)
// pq.push(FVector::get_fvector_from_source(source, vect_field, candidate));
//
// for (size_t i = 0; i < max_neighbour_connections; i++)
// {
// if (!pq.elements())
// break;
// neighbours->push_back(pq.pop());
// }
//
// return false;
//}
static void dbug_print_vec_ref(const char *prefix,
uint layer,
const FVectorRef &ref) const FVectorRef &ref)
{ {
#ifndef DBUG_OFF #ifndef DBUG_OFF
// TODO(cvicentiu) disable this in release build. // TODO(cvicentiu) disable this in release build.
char *ref_str= (char *)alloca(ref.get_ref_len() * 2 + 1); char *ref_str= static_cast<char *>(alloca(ref.get_ref_len() * 2 + 1));
DBUG_ASSERT(ref_str); DBUG_ASSERT(ref_str);
char *ptr= ref_str; char *ptr= ref_str;
for (size_t i = 0; i < ref.get_ref_len(); ptr += 2, i++) for (size_t i= 0; i < ref.get_ref_len(); ptr += 2, i++)
{ {
snprintf(ptr, 3, "%02x", ref.get_ref()[i]); snprintf(ptr, 3, "%02x", ref.get_ref()[i]);
} }
...@@ -326,8 +314,7 @@ static void dbug_print_vec_ref(const char *prefix, ...@@ -326,8 +314,7 @@ static void dbug_print_vec_ref(const char *prefix,
#endif #endif
} }
static void dbug_print_vec_neigh(uint layer, static void dbug_print_vec_neigh(uint layer, const List<FVectorRef> &neighbors)
const List<FVectorRef> &neighbors)
{ {
#ifndef DBUG_OFF #ifndef DBUG_OFF
DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements)); DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements));
...@@ -341,190 +328,128 @@ static void dbug_print_vec_neigh(uint layer, ...@@ -341,190 +328,128 @@ static void dbug_print_vec_neigh(uint layer,
static void dbug_print_hash_vec(Hash_set<FVectorRef> &h) static void dbug_print_hash_vec(Hash_set<FVectorRef> &h)
{ {
#ifndef DBUG_OFF #ifndef DBUG_OFF
Hash_set<FVectorRef>::Iterator it(h); for (FVectorRef &ptr : h)
FVectorRef *ptr;
while ((ptr = it++))
{ {
DBUG_PRINT("VECTOR", ("HASH elem: %p", ptr)); DBUG_PRINT("VECTOR", ("HASH elem: %p", &ptr));
dbug_print_vec_ref("VISITED: ", 0, *ptr); dbug_print_vec_ref("VISITED: ", 0, ptr);
} }
#endif #endif
} }
static bool write_neighbours(TABLE *graph, static int write_neighbors(TABLE *graph, size_t layer_number,
size_t layer_number,
const FVectorRef &source_node, const FVectorRef &source_node,
const List<FVectorRef> &new_neighbours) const List<FVectorRef> &new_neighbors)
{ {
DBUG_ASSERT(new_neighbours.elements <= HNSW_MAX_M); DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
size_t total_size= HNSW_MAX_M_WIDTH + new_neighbors.elements * source_node.get_ref_len();
size_t total_size= sizeof(uint16_t) +
new_neighbours.elements * source_node.get_ref_len();
// Allocate memory for the struct and the flexible array member // Allocate memory for the struct and the flexible array member
char *neighbor_array_bytes= static_cast<char *>(alloca(total_size)); char *neighbor_array_bytes= static_cast<char *>(my_safe_alloca(total_size));
DBUG_ASSERT(new_neighbours.elements <= INT16_MAX); // XXX why bother storing it?
*(uint16_t *) neighbor_array_bytes= new_neighbours.elements; HNSW_MAX_M_store(neighbor_array_bytes, new_neighbors.elements);
char *pos= neighbor_array_bytes + sizeof(uint16_t); char *pos= neighbor_array_bytes + HNSW_MAX_M_WIDTH;
for (const auto &node: new_neighbours) for (const auto &node: new_neighbors)
{ {
DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len()); DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len());
memcpy(pos, node.get_ref(), node.get_ref_len()); memcpy(pos, node.get_ref(), node.get_ref_len());
pos+= node.get_ref_len(); pos+= node.get_ref_len();
} }
graph->field[0]->store(layer_number); graph->field[0]->store(layer_number, false);
graph->field[1]->store_binary( graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
reinterpret_cast<const char *>(source_node.get_ref()), graph->field[2]->store_binary(neighbor_array_bytes, total_size);
source_node.get_ref_len());
graph->field[2]->set_null();
uchar *key= (uchar*)alloca(graph->key_info->key_length); uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
int err= graph->file->ha_index_read_map(graph->record[1], key, int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY,
HA_WHOLE_KEY,
HA_READ_KEY_EXACT); HA_READ_KEY_EXACT);
// no record // no record
if (err == HA_ERR_KEY_NOT_FOUND) if (err == HA_ERR_KEY_NOT_FOUND)
{ {
dbug_print_vec_ref("INSERT ", layer_number, source_node); dbug_print_vec_ref("INSERT ", layer_number, source_node);
graph->field[2]->store_binary(neighbor_array_bytes, total_size); err= graph->file->ha_write_row(graph->record[0]);
graph->file->ha_write_row(graph->record[0]);
return false;
} }
dbug_print_vec_ref("UPDATE ", layer_number, source_node); else if (!err)
dbug_print_vec_neigh(layer_number, new_neighbours);
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
graph->file->ha_update_row(graph->record[1], graph->record[0]);
return false;
}
static bool get_neighbours(TABLE *graph,
size_t layer_number,
const FVectorRef &source_node,
List<FVectorRef> *neighbours)
{
// TODO(cvicentiu) This allocation need not happen in this function.
uchar *key= (uchar*)alloca(graph->key_info->key_length);
graph->field[0]->store(layer_number);
graph->field[1]->store_binary(
reinterpret_cast<const char *>(source_node.get_ref()),
source_node.get_ref_len());
graph->field[2]->set_null();
key_copy(key, graph->record[0],
graph->key_info, graph->key_info->key_length);
if ((graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY,
HA_READ_KEY_EXACT)))
return true;
//TODO This does two memcpys, one should use str's buffer.
String strbuf;
String *str= graph->field[2]->val_str(&strbuf);
// All ref should have same length
uint ref_length= source_node.get_ref_len();
const uchar *neigh_arr_bytes= reinterpret_cast<const uchar *>(str->ptr());
uint16_t number_of_neighbours=
*reinterpret_cast<const uint16_t*>(neigh_arr_bytes);
if (number_of_neighbours != (str->length() - sizeof(uint16_t)) / ref_length)
{ {
/* dbug_print_vec_ref("UPDATE ", layer_number, source_node);
neighbours number does not match the data length, dbug_print_vec_neigh(layer_number, new_neighbors);
should not happen, possible corrupted HNSW index
*/
DBUG_ASSERT(0); // TODO(cvicentiu) remove this after testing.
return true;
}
const uchar *pos = neigh_arr_bytes + sizeof(uint16_t); err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
for (uint16_t i= 0; i < number_of_neighbours; i++)
{
neighbours->push_back(FVector::get_fvector_ref(pos, ref_length));
pos+= ref_length;
} }
my_safe_afree(neighbor_array_bytes, total_size);
return false; return err;
} }
static bool update_second_degree_neighbors(TABLE *source, static int update_second_degree_neighbors(TABLE *source, Field *vec_field,
Field *vec_field, TABLE *graph, size_t layer_number,
TABLE *graph, uint max_neighbors,
size_t layer_number,
uint max_neighbours,
const FVectorRef &source_node, const FVectorRef &source_node,
const List<FVectorRef> &neighbours) const List<FVectorRef> &neighbors)
{ {
//dbug_print_vec_ref("Updating second degree neighbours", layer_number, source_node); //dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node);
//dbug_print_vec_neigh(layer_number, neighbours); //dbug_print_vec_neigh(layer_number, neighbors);
for (const FVectorRef &neigh: neighbours) for (const FVectorRef &neigh: neighbors) // XXX why this loop?
{ {
List<FVectorRef> new_neighbours; List<FVectorRef> new_neighbors;
get_neighbours(graph, layer_number, neigh, &new_neighbours); if (int err= get_neighbors(graph, layer_number, neigh, &new_neighbors))
new_neighbours.push_back(&source_node); return err;
write_neighbours(graph, layer_number, neigh, new_neighbours); new_neighbors.push_back(&source_node);
if (int err= write_neighbors(graph, layer_number, neigh, new_neighbors))
return err;
} }
for (const FVectorRef &neigh: neighbours) for (const FVectorRef &neigh: neighbors)
{ {
List<FVectorRef> new_neighbours; List<FVectorRef> new_neighbors;
get_neighbours(graph, layer_number, neigh, &new_neighbours); if (int err= get_neighbors(graph, layer_number, neigh, &new_neighbors))
// TODO(cvicentiu) get_fvector_from_source results must not need to be freed. return err;
FVector *neigh_vec = FVector::get_fvector_from_source(source, vec_field, neigh);
if (new_neighbours.elements > max_neighbours) if (new_neighbors.elements > max_neighbors)
{ {
// shrink the neighbours // shrink the neighbors
List<FVectorRef> selected; List<FVectorRef> selected;
select_neighbours(source, graph, vec_field, layer_number, FVector *v= FVector::get_fvector_from_source(source, vec_field, neigh);
*neigh_vec, new_neighbours, if (!v)
max_neighbours, &selected); return HA_ERR_OUT_OF_MEM;
write_neighbours(graph, layer_number, neigh, selected); if (int err= select_neighbors(source, graph, vec_field, layer_number,
*v, new_neighbors, max_neighbors, &selected))
return err;
if (int err= write_neighbors(graph, layer_number, neigh, selected))
return err;
} }
// release memory // release memory
new_neighbours.empty(); new_neighbors.empty();
} }
return false; return 0;
} }
static bool update_neighbours(TABLE *source, static int update_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
TABLE *graph, size_t layer_number, uint max_neighbors,
Field *vec_field,
size_t layer_number,
uint max_neighbours,
const FVectorRef &source_node, const FVectorRef &source_node,
const List<FVectorRef> &neighbours) const List<FVectorRef> &neighbors)
{ {
// 1. update node's neighbours // 1. update node's neighbors
write_neighbours(graph, layer_number, source_node, neighbours); if (int err= write_neighbors(graph, layer_number, source_node, neighbors))
// 2. update node's neighbours' neighbours (shrink before update) return err;
update_second_degree_neighbors(source, vec_field, graph, layer_number, // 2. update node's neighbors' neighbors (shrink before update)
max_neighbours, source_node, neighbours); return update_second_degree_neighbors(source, vec_field, graph, layer_number,
return false; max_neighbors, source_node, neighbors);
} }
static bool search_layer(TABLE *source, static int search_layer(TABLE *source, TABLE *graph, Field *vec_field,
TABLE *graph,
Field *vec_field,
const FVector &target, const FVector &target,
const List<FVectorRef> &start_nodes, const List<FVectorRef> &start_nodes,
uint max_candidates_return, uint max_candidates_return, size_t layer,
size_t layer,
List<FVectorRef> *result) List<FVectorRef> *result)
{ {
DBUG_ASSERT(start_nodes.elements > 0); DBUG_ASSERT(start_nodes.elements > 0);
...@@ -534,10 +459,8 @@ static bool search_layer(TABLE *source, ...@@ -534,10 +459,8 @@ static bool search_layer(TABLE *source,
Queue<FVector, const FVector> candidates; Queue<FVector, const FVector> candidates;
Queue<FVector, const FVector> best; Queue<FVector, const FVector> best;
//TODO(cvicentiu) Fix this hash method. //TODO(cvicentiu) Fix this hash method.
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, 1000, 0, 0,
1000, 0, 0, (my_hash_get_key)FVectorRef::get_key, NULL,
(my_hash_get_key)FVectorRef::get_key,
NULL,
HASH_UNIQUE); HASH_UNIQUE);
candidates.init(10000, false, cmp_vec, &target); candidates.init(10000, false, cmp_vec, &target);
...@@ -549,50 +472,49 @@ static bool search_layer(TABLE *source, ...@@ -549,50 +472,49 @@ static bool search_layer(TABLE *source,
candidates.push(v); candidates.push(v);
if (best.elements() < max_candidates_return) if (best.elements() < max_candidates_return)
best.push(v); best.push(v);
else if (target.distance_to(*v) > target.distance_to(*best.top())) { else if (target.distance_to(*v) > target.distance_to(*best.top()))
best.replace_top(v); best.replace_top(v);
}
visited.insert(v); visited.insert(v);
dbug_print_vec_ref("INSERTING node in visited: ", layer, node); dbug_print_vec_ref("INSERTING node in visited: ", layer, node);
} }
double furthest_best = target.distance_to(*best.top()); float furthest_best= target.distance_to(*best.top());
while (candidates.elements()) while (candidates.elements())
{ {
const FVector &cur_vec= *candidates.pop(); const FVector &cur_vec= *candidates.pop();
double cur_distance = target.distance_to(cur_vec); float cur_distance= target.distance_to(cur_vec);
if (cur_distance > furthest_best && best.elements() == max_candidates_return) if (cur_distance > furthest_best && best.elements() == max_candidates_return)
{ {
break; // All possible candidates are worse than what we have. break; // All possible candidates are worse than what we have.
// Can't get better. // Can't get better.
} }
List<FVectorRef> neighbours; List<FVectorRef> neighbors;
get_neighbours(graph, layer, cur_vec, &neighbours); get_neighbors(graph, layer, cur_vec, &neighbors);
for (const FVectorRef &neigh: neighbours) for (const FVectorRef &neigh: neighbors)
{ {
dbug_print_hash_vec(visited); dbug_print_hash_vec(visited);
if (visited.find(&neigh)) if (visited.find(&neigh))
continue; continue;
FVector *clone = FVector::get_fvector_from_source(source, vec_field, neigh); FVector *clone= FVector::get_fvector_from_source(source, vec_field, neigh);
// TODO(cvicentiu) mem ownershipw... // TODO(cvicentiu) mem ownership...
visited.insert(clone); visited.insert(clone);
if (best.elements() < max_candidates_return) if (best.elements() < max_candidates_return)
{ {
candidates.push(clone); candidates.push(clone);
best.push(clone); best.push(clone);
furthest_best = target.distance_to(*best.top()); furthest_best= target.distance_to(*best.top());
} }
else if (target.distance_to(*clone) < furthest_best) else if (target.distance_to(*clone) < furthest_best)
{ {
best.replace_top(clone); best.replace_top(clone);
candidates.push(clone); candidates.push(clone);
furthest_best = target.distance_to(*best.top()); furthest_best= target.distance_to(*best.top());
} }
} }
neighbours.empty(); neighbors.empty();
} }
DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements())); DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements()));
...@@ -603,21 +525,28 @@ static bool search_layer(TABLE *source, ...@@ -603,21 +525,28 @@ static bool search_layer(TABLE *source,
result->push_front(best.pop()); result->push_front(best.pop());
} }
return false; return 0;
} }
std::mt19937 gen(42); // TODO(cvicentiu) seeded with 42 for now, this should static int bad_value_on_insert(Field *f)
// use a rnd service {
my_error(ER_TRUNCATED_WRONG_VALUE_FOR_FIELD, MYF(0), "vector", "...",
f->table->s->db.str, f->table->s->table_name.str, f->field_name.str,
f->table->in_use->get_stmt_da()->current_row_for_warning());
return HA_ERR_GENERIC;
}
int mhnsw_insert(TABLE *table, KEY *keyinfo) int mhnsw_insert(TABLE *table, KEY *keyinfo)
{ {
THD *thd= table->in_use;
TABLE *graph= table->hlindex; TABLE *graph= table->hlindex;
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
Field *vec_field= keyinfo->key_part->field; Field *vec_field= keyinfo->key_part->field;
String buf, *res= vec_field->val_str(&buf); String buf, *res= vec_field->val_str(&buf);
handler *h= table->file; handler *h= table->file->lookup_handler;
int err= 0;
/* metadata are checked on open */ /* metadata are checked on open */
DBUG_ASSERT(graph); DBUG_ASSERT(graph);
...@@ -627,187 +556,191 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -627,187 +556,191 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT); DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT);
DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL
DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length); DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length);
DBUG_ASSERT(h->ref_length <= graph->field[2]->field_length);
// XXX returning an error here will rollback the insert in InnoDB
// but in MyISAM the row will stay inserted, making the index out of sync:
// invalid vector values are present in the table but cannot be found
// via an index. The easiest way to fix it is with a VECTOR(N) type
if (res->length() == 0 || res->length() % 4) if (res->length() == 0 || res->length() % 4)
return 1; return bad_value_on_insert(vec_field);
const double NORMALIZATION_FACTOR = 1 / std::log(1.0 * const double NORMALIZATION_FACTOR= 1 / std::log(thd->variables.hnsw_max_connection_per_layer);
table->in_use->variables.hnsw_max_connection_per_layer);
if ((err= h->ha_rnd_init(1))) if (int err= h->ha_rnd_init(1))
return err; return err;
SCOPE_EXIT([h](){ h->ha_rnd_end(); });
if ((err= graph->file->ha_index_init(0, 1))) if (int err= graph->file->ha_index_init(0, 1))
return err; return err;
longlong max_layer; SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
if ((err= graph->file->ha_index_last(graph->record[0])))
if (int err= graph->file->ha_index_last(graph->record[0]))
{ {
if (err != HA_ERR_END_OF_FILE) if (err != HA_ERR_END_OF_FILE)
{
graph->file->ha_index_end();
return err; return err;
}
// First insert! // First insert!
h->position(table->record[0]); h->position(table->record[0]);
write_neighbours(graph, 0, {h->ref, h->ref_length}, {}); return write_neighbors(graph, 0, {h->ref, h->ref_length}, {});
h->ha_rnd_end();
graph->file->ha_index_end();
return 0; // TODO (error during store_link)
} }
else
max_layer= graph->field[0]->val_int();
FVector target; longlong max_layer= graph->field[0]->val_int();
h->position(table->record[0]);
// TODO (cvicentiu) Error checking.
target.init(h->ref, h->ref_length,
reinterpret_cast<const float *>(res->ptr()),
res->length() / sizeof(float));
h->position(table->record[0]);
std::uniform_real_distribution<> dis(0.0, 1.0); List<FVectorRef> candidates;
double new_num= dis(gen);
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= std::floor(log);
List<FVectorRef> start_nodes; List<FVectorRef> start_nodes;
String ref_str, *ref_ptr; String ref_str, *ref_ptr;
ref_ptr= graph->field[1]->val_str(&ref_str); ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
// TODO(cvicentiu) use a random start node in last layer.
// XXX or may be *all* nodes in the last layer? there should be few
if (start_nodes.push_back(&start_node_ref))
return HA_ERR_OUT_OF_MEM;
FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()}; FVector *v= FVector::get_fvector_from_source(table, vec_field, start_node_ref);
//FVector *start_node= start_node_ref.get_fvector_from_source(table, vec_field); if (!v)
return HA_ERR_OUT_OF_MEM;
if (v->size_of() != res->length())
return bad_value_on_insert(vec_field);
FVector target;
target.init(h->ref, h->ref_length, res->ptr(), res->length());
double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= static_cast<longlong>(std::floor(log));
// TODO(cvicentiu) error checking. Also make sure we use a random start node
// in last layer.
start_nodes.push_back(&start_node_ref);
// TODO start_nodes needs to have one element in it.
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{ {
List<FVectorRef> candidates; if (int err= search_layer(table, graph, vec_field, target, start_nodes,
search_layer(table, graph, vec_field, target, start_nodes, thd->variables.hnsw_ef_constructor, cur_layer,
table->in_use->variables.hnsw_ef_constructor, cur_layer, &candidates))
&candidates); return err;
start_nodes.empty(); start_nodes.empty();
start_nodes.push_back(candidates.head()); start_nodes.push_back(candidates.head()); // XXX ef=1
//candidates.delete_elements(); //candidates.delete_elements();
candidates.empty();
//TODO(cvicentiu) memory leak //TODO(cvicentiu) memory leak
} }
for (longlong cur_layer= std::min(max_layer, new_node_layer); for (longlong cur_layer= std::min(max_layer, new_node_layer);
cur_layer >= 0; cur_layer--) cur_layer >= 0; cur_layer--)
{ {
List<FVectorRef> candidates; List<FVectorRef> neighbors;
List<FVectorRef> neighbours; if (int err= search_layer(table, graph, vec_field, target, start_nodes,
search_layer(table, graph, vec_field, target, start_nodes, thd->variables.hnsw_ef_constructor, cur_layer,
table->in_use->variables.hnsw_ef_constructor, &candidates))
cur_layer, &candidates); return err;
// release vectors // release vectors
start_nodes.empty(); start_nodes.empty();
uint max_neighbours= (cur_layer == 0) ? uint max_neighbors= (cur_layer == 0) ? // heuristics from the paper
table->in_use->variables.hnsw_max_connection_per_layer * 2 thd->variables.hnsw_max_connection_per_layer * 2
: table->in_use->variables.hnsw_max_connection_per_layer; : thd->variables.hnsw_max_connection_per_layer;
select_neighbours(table, graph, vec_field, cur_layer, if (int err= select_neighbors(table, graph, vec_field, cur_layer, target,
target, candidates, candidates, max_neighbors, &neighbors))
max_neighbours, &neighbours); return err;
update_neighbours(table, graph, vec_field, cur_layer, max_neighbours, if (int err= update_neighbors(table, graph, vec_field, cur_layer,
target, neighbours); max_neighbors, target, neighbors))
return err;
start_nodes= candidates; start_nodes= candidates;
} }
start_nodes.empty(); start_nodes.empty();
// XXX what is that?
for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer; for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer;
cur_layer++) cur_layer++)
{ {
write_neighbours(graph, cur_layer, target, {}); if (int err= write_neighbors(graph, cur_layer, target, {}))
return err;
} }
h->ha_rnd_end();
graph->file->ha_index_end();
dbug_tmp_restore_column_map(&table->read_set, old_map); dbug_tmp_restore_column_map(&table->read_set, old_map);
return err == HA_ERR_END_OF_FILE ? 0 : err; return 0;
} }
int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
{ {
THD *thd= table->in_use;
TABLE *graph= table->hlindex; TABLE *graph= table->hlindex;
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
// TODO(cvicentiu) onlye one hlindex now.
Field *vec_field= keyinfo->key_part->field; Field *vec_field= keyinfo->key_part->field;
Item_func_vec_distance *fun= (Item_func_vec_distance *)dist; Item_func_vec_distance *fun= (Item_func_vec_distance *)dist;
String buf, *res= fun->arguments()[1]->val_str(&buf); String buf, *res= fun->get_const_arg()->val_str(&buf);
handler *h= table->file; handler *h= table->file;
//TODO(scope_exit) if (int err= h->ha_rnd_init(0))
int err;
if ((err= h->ha_rnd_init(0)))
return err; return err;
if ((err= graph->file->ha_index_init(0, 1))) if (int err= graph->file->ha_index_init(0, 1))
return err; return err;
h->position(table->record[0]); SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
FVector target;
target.init(h->ref,
h->ref_length,
reinterpret_cast<const float *>(res->ptr()),
res->length() / sizeof(float));
List<FVectorRef> candidates;
List<FVectorRef> start_nodes;
longlong max_layer; if (int err= graph->file->ha_index_last(graph->record[0]))
if ((err= graph->file->ha_index_last(graph->record[0])))
{
if (err != HA_ERR_END_OF_FILE)
{
graph->file->ha_index_end();
return err; return err;
}
h->ha_rnd_end();
graph->file->ha_index_end();
return 0; // TODO (error during store_link)
}
else
max_layer= graph->field[0]->val_int();
longlong max_layer= graph->field[0]->val_int();
List<FVectorRef> candidates; // XXX List? not Queue by distance?
List<FVectorRef> start_nodes;
String ref_str, *ref_ptr; String ref_str, *ref_ptr;
ref_ptr= graph->field[1]->val_str(&ref_str); ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorRef start_node_ref{(uchar *)ref_ptr->ptr(), ref_ptr->length()}; FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
// TODO(cvicentiu) error checking. Also make sure we use a random start node
// in last layer. // TODO(cvicentiu) use a random start node in last layer.
start_nodes.push_back(&start_node_ref); // XXX or may be *all* nodes in the last layer? there should be few
if (start_nodes.push_back(&start_node_ref))
return HA_ERR_OUT_OF_MEM;
FVector *v= FVector::get_fvector_from_source(table, vec_field, start_node_ref);
if (!v)
return HA_ERR_OUT_OF_MEM;
/*
if the query vector is NULL or invalid, VEC_DISTANCE will return
NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node.
*/
if (!res || v->size_of() != res->length())
(res= &buf)->set((const char*)(v->get_vec()), v->size_of(), &my_charset_bin);
FVector target;
if (target.init(h->ref, h->ref_length, res->ptr(), res->length()))
return HA_ERR_OUT_OF_MEM;
ulonglong ef_search= MY_MAX( ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
table->in_use->variables.hnsw_ef_search, limit); thd->variables.hnsw_ef_search, limit);
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
{ {
search_layer(table, graph, vec_field, target, start_nodes, ef_search, //XXX in the paper ef_search=1 here
cur_layer, &candidates); if (int err= search_layer(table, graph, vec_field, target, start_nodes,
ef_search, cur_layer, &candidates))
return err;
start_nodes.empty(); start_nodes.empty();
//start_nodes.delete_elements(); //start_nodes.delete_elements();
start_nodes.push_back(candidates.head()); start_nodes.push_back(candidates.head()); // XXX so ef_search=1 ???
//candidates.delete_elements(); //candidates.delete_elements();
candidates.empty(); candidates.empty();
//TODO(cvicentiu) memleak. //TODO(cvicentiu) memleak.
} }
search_layer(table, graph, vec_field, target, start_nodes, if (int err= search_layer(table, graph, vec_field, target, start_nodes,
ef_search, 0, &candidates); ef_search, 0, &candidates))
return err;
// 8. return results // 8. return results
FVectorRef **context= (FVectorRef**)table->in_use->alloc( FVectorRef **context= thd->alloc<FVectorRef*>(limit + 1);
sizeof(FVectorRef*) * (limit + 1));
graph->context= context; graph->context= context;
FVectorRef **ptr= context; FVectorRef **ptr= context;
...@@ -815,13 +748,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -815,13 +748,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
*ptr++= candidates.pop(); *ptr++= candidates.pop();
*ptr= nullptr; *ptr= nullptr;
err= mhnsw_next(table); return mhnsw_next(table);
graph->file->ha_index_end();
// TODO release vectors after query
dbug_tmp_restore_column_map(&table->read_set, old_map);
return err;
} }
int mhnsw_next(TABLE *table) int mhnsw_next(TABLE *table)
......
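
One reference note on the diff above: the patch drops the fixed-seed std::mt19937 in favor of my_rnd(&thd->rand) when picking a new node's top layer. A minimal sketch of that level-assignment formula, assuming only the standard library; the value 50 is just the sysvar's default, and std::mt19937 here is merely a stand-in for the server's RNG, not the server code itself.

#include <cmath>
#include <cstdio>
#include <random>

int main()
{
  const double M= 50.0;                      // default of hnsw_max_connection_per_layer
  const double mL= 1.0 / std::log(M);        // NORMALIZATION_FACTOR in the patch
  std::mt19937 gen(std::random_device{}());  // stand-in for my_rnd(&thd->rand)
  std::uniform_real_distribution<double> dis(std::nextafter(0.0, 1.0), 1.0);

  for (int i= 0; i < 5; i++)
  {
    double u= dis(gen);                      // uniform in (0,1)
    // level = floor(-ln(u) * mL): geometric, each higher layer ~M times rarer
    long level= (long)std::floor(-std::log(u) * mL);
    std::printf("layer %ld\n", level);
  }
  return 0;
}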