Commit aad082f4 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: refactor FVector* classes

Now there's an FVector class which is a pure vector, an array of floats.
It doesn't necessarily corresponds to a row in the table, and usually
there is only one FVector instance - the one we're searching for.

And there's an FVectorNode class, which is a node in the graph.
It has a ref (identifying a row in the source table), possibly an array
of floats (or not — in which case it will be read lazily from the
source table as needed). There are many FVectorNodes and they're
cached to avoid re-reading them from the disk.
parent 5301097f
...@@ -42,64 +42,31 @@ const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\ ...@@ -42,64 +42,31 @@ const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
")}; ")};
class FVectorRef: public Sql_alloc class MHNSW_Context;
class FVector: public Sql_alloc
{ {
public: public:
// Shallow ref copy. Used for other ref lookups in HashSet MHNSW_Context *ctx;
FVectorRef(const void *ref, size_t ref_len): ref{(uchar*)ref}, ref_len{ref_len} {} FVector(MHNSW_Context *ctx_, const void *vec_);
float *vec;
static uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
{
*key_len= elem->ref_len;
return elem->ref;
}
static void free_vector(void *elem)
{
delete (FVectorRef *)elem;
}
size_t get_ref_len() const { return ref_len; }
uchar* get_ref() const { return ref; }
protected: protected:
FVectorRef() = default; FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {}
uchar *ref;
size_t ref_len;
}; };
class FVector: public FVectorRef class FVectorNode: public FVector
{ {
private: private:
float *vec; uchar *ref;
size_t vec_len;
public: public:
FVector(): vec(nullptr), vec_len(0) {} FVectorNode(MHNSW_Context *ctx_, const void *ref_);
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
bool init(MEM_ROOT *root, const uchar *ref_, size_t ref_len_, const void *vec_, size_t bytes) float distance_to(const FVector &other) const;
{ int instantiate_vector();
ref= (uchar*)alloc_root(root, ref_len_ + bytes); size_t get_ref_len() const;
if (!ref) uchar *get_ref() const { return ref; }
return true;
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
vec= reinterpret_cast<float *>(ref + ref_len_);
memcpy(ref, ref_, ref_len_);
memcpy(vec, vec_, bytes);
ref_len= ref_len_;
vec_len= bytes / sizeof(float);
return false;
}
size_t size_of() const { return vec_len * sizeof(float); }
float distance_to(const FVector &other) const
{
DBUG_ASSERT(other.vec_len == vec_len);
return euclidean_vec_distance(vec, other.vec, vec_len);
}
}; };
class MHNSW_Context class MHNSW_Context
...@@ -108,8 +75,9 @@ class MHNSW_Context ...@@ -108,8 +75,9 @@ class MHNSW_Context
MEM_ROOT root; MEM_ROOT root;
TABLE *table; TABLE *table;
Field *vec_field; Field *vec_field;
Hash_set<FVectorRef> vector_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key}; size_t vec_len= 0;
Hash_set<FVectorRef> vector_ref_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
MHNSW_Context(TABLE *table, Field *vec_field) MHNSW_Context(TABLE *table, Field *vec_field)
: table(table), vec_field(vec_field) : table(table), vec_field(vec_field)
...@@ -122,40 +90,67 @@ class MHNSW_Context ...@@ -122,40 +90,67 @@ class MHNSW_Context
free_root(&root, MYF(0)); free_root(&root, MYF(0));
} }
FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len) FVectorNode *get_node(const void *ref_);
{ };
FVectorRef tmp(ref, ref_len);
FVectorRef *v= vector_ref_cache.find(&tmp);
if (v)
return v;
uchar *buf= (uchar*)memdup_root(&root, ref, ref_len);
if ((v= new (&root) FVectorRef(buf, ref_len)))
vector_ref_cache.insert(v);
return v;
}
FVector *get_fvector_from_source(const FVectorRef &ref) FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
{ {
FVectorRef *v= vector_cache.find(&ref); vec= (float*)memdup_root(&ctx->root, vec_, ctx->vec_len * sizeof(float));
if (v) }
return (FVector *)v;
if (table->file->ha_rnd_pos(table->record[0], ref.get_ref())) FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_)
return nullptr; // XXX the error code is lost : FVector(ctx_)
{
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
}
String buf, *vec= vec_field->val_str(&buf); FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_)
: FVector(ctx_, vec_)
{
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
}
FVector *new_vector= new (&root) FVector; float FVectorNode::distance_to(const FVector &other) const
new_vector->init(&root, ref.get_ref(), ref.get_ref_len(), vec->ptr(), vec->length()); {
if (!vec)
const_cast<FVectorNode*>(this)->instantiate_vector();
return euclidean_vec_distance(vec, other.vec, ctx->vec_len);
}
vector_cache.insert(new_vector); int FVectorNode::instantiate_vector()
{
DBUG_ASSERT(vec == nullptr);
if (int err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref))
return err;
String buf, *v= ctx->vec_field->val_str(&buf);
ctx->vec_len= v->length() / sizeof(float);
vec= (float*)memdup_root(&ctx->root, v->ptr(), v->length());
return 0;
}
size_t FVectorNode::get_ref_len() const
{
return ctx->table->file->ref_length;
}
return new_vector; uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
*key_len= elem->get_ref_len();
return elem->ref;
}
FVectorNode *MHNSW_Context::get_node(const void *ref)
{
FVectorNode *node= node_cache.find(ref, table->file->ref_length);
if (!node)
{
node= new (&root) FVectorNode(this, ref);
node_cache.insert(node);
} }
}; return node;
}
static int cmp_vec(const FVector *target, const FVector *a, const FVector *b) static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNode *b)
{ {
float a_dist= a->distance_to(*target); float a_dist= a->distance_to(*target);
float b_dist= b->distance_to(*target); float b_dist= b->distance_to(*target);
...@@ -171,8 +166,8 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why? ...@@ -171,8 +166,8 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; // XXX or false? const bool EXTEND_CANDIDATES=true; // XXX or false?
static int get_neighbors(MHNSW_Context *ctx, size_t layer_number, static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorRef &source_node, const FVectorNode &source_node,
List<FVectorRef> *neighbors) List<FVectorNode> *neighbors)
{ {
TABLE *graph= ctx->table->hlindex; TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
...@@ -189,18 +184,16 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -189,18 +184,16 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
// mhnsw_insert() guarantees that all ref have the same length // mhnsw_insert() guarantees that all ref have the same length
uint ref_length= source_node.get_ref_len(); uint ref_length= source_node.get_ref_len();
const uchar *neigh_arr_bytes= reinterpret_cast<const uchar *>(str->ptr()); const char *neigh_arr_bytes= str->ptr();
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes); uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length()) if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length())
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
const uchar *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH; const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
for (uint i= 0; i < number_of_neighbors; i++) for (uint i= 0; i < number_of_neighbors; i++)
{ {
FVectorRef *v= ctx->get_fvector_ref(pos, ref_length); FVectorNode *neigh= ctx->get_node(pos);
if (!v) neighbors->push_back(neigh, &ctx->root);
return HA_ERR_OUT_OF_MEM;
neighbors->push_back(v, &ctx->root);
pos+= ref_length; pos+= ref_length;
} }
...@@ -210,20 +203,20 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -210,20 +203,20 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
static int select_neighbors(MHNSW_Context *ctx, static int select_neighbors(MHNSW_Context *ctx,
size_t layer_number, const FVector &target, size_t layer_number, const FVector &target,
const List<FVectorRef> &candidates, const List<FVectorNode> &candidates,
size_t max_neighbor_connections, size_t max_neighbor_connections,
List<FVectorRef> *neighbors) List<FVectorNode> *neighbors)
{ {
/* /*
TODO: If the input neighbors list is already sorted in search_layer, then TODO: If the input neighbors list is already sorted in search_layer, then
no need to do additional queue build steps here. no need to do additional queue build steps here.
*/ */
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key); Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
Queue<FVector, const FVector> pq; // working queue Queue<FVectorNode, const FVector> pq; // working queue
Queue<FVector, const FVector> pq_discard; // queue for discarded candidates Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates
Queue<FVector, const FVector> best; // neighbors to return Queue<FVectorNode, const FVector> best; // neighbors to return
// TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size. // TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size.
// This should not be fixed. // This should not be fixed.
...@@ -232,32 +225,26 @@ static int select_neighbors(MHNSW_Context *ctx, ...@@ -232,32 +225,26 @@ static int select_neighbors(MHNSW_Context *ctx,
best.init(max_neighbor_connections, true, cmp_vec, &target)) best.init(max_neighbor_connections, true, cmp_vec, &target))
return HA_ERR_OUT_OF_MEM; return HA_ERR_OUT_OF_MEM;
for (const FVectorRef &candidate : candidates) for (const FVectorNode &candidate : candidates)
{ {
FVector *v= ctx->get_fvector_from_source(candidate);
if (!v)
return HA_ERR_OUT_OF_MEM;
visited.insert(&candidate); visited.insert(&candidate);
pq.push(v); pq.push(&candidate);
} }
if (EXTEND_CANDIDATES) if (EXTEND_CANDIDATES)
{ {
for (const FVectorRef &candidate : candidates) for (const FVectorNode &candidate : candidates)
{ {
List<FVectorRef> candidate_neighbors; List<FVectorNode> candidate_neighbors;
if (int err= get_neighbors(ctx, layer_number, candidate, if (int err= get_neighbors(ctx, layer_number, candidate,
&candidate_neighbors)) &candidate_neighbors))
return err; return err;
for (const FVectorRef &extra_candidate : candidate_neighbors) for (const FVectorNode &extra_candidate : candidate_neighbors)
{ {
if (visited.find(&extra_candidate)) if (visited.find(&extra_candidate))
continue; continue;
visited.insert(&extra_candidate); visited.insert(&extra_candidate);
FVector *v= ctx->get_fvector_from_source(extra_candidate); pq.push(&extra_candidate);
if (!v)
return HA_ERR_OUT_OF_MEM;
pq.push(v);
} }
} }
} }
...@@ -268,7 +255,7 @@ static int select_neighbors(MHNSW_Context *ctx, ...@@ -268,7 +255,7 @@ static int select_neighbors(MHNSW_Context *ctx,
float best_top= best.top()->distance_to(target); float best_top= best.top()->distance_to(target);
while (pq.elements() && best.elements() < max_neighbor_connections) while (pq.elements() && best.elements() < max_neighbor_connections)
{ {
const FVector *vec= pq.pop(); const FVectorNode *vec= pq.pop();
const float cur_dist= vec->distance_to(target); const float cur_dist= vec->distance_to(target);
if (cur_dist < best_top) if (cur_dist < best_top)
{ {
...@@ -298,7 +285,7 @@ static int select_neighbors(MHNSW_Context *ctx, ...@@ -298,7 +285,7 @@ static int select_neighbors(MHNSW_Context *ctx,
static void dbug_print_vec_ref(const char *prefix, uint layer, static void dbug_print_vec_ref(const char *prefix, uint layer,
const FVectorRef &ref) const FVectorNode &ref)
{ {
#ifndef DBUG_OFF #ifndef DBUG_OFF
// TODO(cvicentiu) disable this in release build. // TODO(cvicentiu) disable this in release build.
...@@ -313,21 +300,21 @@ static void dbug_print_vec_ref(const char *prefix, uint layer, ...@@ -313,21 +300,21 @@ static void dbug_print_vec_ref(const char *prefix, uint layer,
#endif #endif
} }
static void dbug_print_vec_neigh(uint layer, const List<FVectorRef> &neighbors) static void dbug_print_vec_neigh(uint layer, const List<FVectorNode> &neighbors)
{ {
#ifndef DBUG_OFF #ifndef DBUG_OFF
DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements)); DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements));
for (const FVectorRef& ref : neighbors) for (const FVectorNode& ref : neighbors)
{ {
dbug_print_vec_ref("NEIGH: ", layer, ref); dbug_print_vec_ref("NEIGH: ", layer, ref);
} }
#endif #endif
} }
static void dbug_print_hash_vec(Hash_set<FVectorRef> &h) static void dbug_print_hash_vec(Hash_set<FVectorNode> &h)
{ {
#ifndef DBUG_OFF #ifndef DBUG_OFF
for (FVectorRef &ptr : h) for (FVectorNode &ptr : h)
{ {
DBUG_PRINT("VECTOR", ("HASH elem: %p", &ptr)); DBUG_PRINT("VECTOR", ("HASH elem: %p", &ptr));
dbug_print_vec_ref("VISITED: ", 0, ptr); dbug_print_vec_ref("VISITED: ", 0, ptr);
...@@ -337,8 +324,8 @@ static void dbug_print_hash_vec(Hash_set<FVectorRef> &h) ...@@ -337,8 +324,8 @@ static void dbug_print_hash_vec(Hash_set<FVectorRef> &h)
static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorRef &source_node, const FVectorNode &source_node,
const List<FVectorRef> &new_neighbors) const List<FVectorNode> &new_neighbors)
{ {
TABLE *graph= ctx->table->hlindex; TABLE *graph= ctx->table->hlindex;
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M); DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
...@@ -390,14 +377,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -390,14 +377,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
static int update_second_degree_neighbors(MHNSW_Context *ctx, static int update_second_degree_neighbors(MHNSW_Context *ctx,
size_t layer_number, size_t layer_number,
uint max_neighbors, uint max_neighbors,
const FVectorRef &source_node, const FVectorNode &source_node,
const List<FVectorRef> &neighbors) const List<FVectorNode> &neighbors)
{ {
//dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node); //dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node);
//dbug_print_vec_neigh(layer_number, neighbors); //dbug_print_vec_neigh(layer_number, neighbors);
for (const FVectorRef &neigh: neighbors) // XXX why this loop? for (const FVectorNode &neigh: neighbors) // XXX why this loop?
{ {
List<FVectorRef> new_neighbors; List<FVectorNode> new_neighbors;
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors)) if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err; return err;
new_neighbors.push_back(&source_node, &ctx->root); new_neighbors.push_back(&source_node, &ctx->root);
...@@ -405,20 +392,17 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, ...@@ -405,20 +392,17 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
return err; return err;
} }
for (const FVectorRef &neigh: neighbors) for (const FVectorNode &neigh: neighbors)
{ {
List<FVectorRef> new_neighbors; List<FVectorNode> new_neighbors;
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors)) if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err; return err;
if (new_neighbors.elements > max_neighbors) if (new_neighbors.elements > max_neighbors)
{ {
// shrink the neighbors // shrink the neighbors
List<FVectorRef> selected; List<FVectorNode> selected;
FVector *v= ctx->get_fvector_from_source(neigh); if (int err= select_neighbors(ctx, layer_number, neigh,
if (!v)
return HA_ERR_OUT_OF_MEM;
if (int err= select_neighbors(ctx, layer_number, *v,
new_neighbors, max_neighbors, &selected)) new_neighbors, max_neighbors, &selected))
return err; return err;
if (int err= write_neighbors(ctx, layer_number, neigh, selected)) if (int err= write_neighbors(ctx, layer_number, neigh, selected))
...@@ -432,8 +416,8 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, ...@@ -432,8 +416,8 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
static int update_neighbors(MHNSW_Context *ctx, static int update_neighbors(MHNSW_Context *ctx,
size_t layer_number, uint max_neighbors, size_t layer_number, uint max_neighbors,
const FVectorRef &source_node, const FVectorNode &source_node,
const List<FVectorRef> &neighbors) const List<FVectorNode> &neighbors)
{ {
// 1. update node's neighbors // 1. update node's neighbors
if (int err= write_neighbors(ctx, layer_number, source_node, neighbors)) if (int err= write_neighbors(ctx, layer_number, source_node, neighbors))
...@@ -445,36 +429,35 @@ static int update_neighbors(MHNSW_Context *ctx, ...@@ -445,36 +429,35 @@ static int update_neighbors(MHNSW_Context *ctx,
static int search_layer(MHNSW_Context *ctx, const FVector &target, static int search_layer(MHNSW_Context *ctx, const FVector &target,
const List<FVectorRef> &start_nodes, const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer, uint max_candidates_return, size_t layer,
List<FVectorRef> *result) List<FVectorNode> *result)
{ {
DBUG_ASSERT(start_nodes.elements > 0); DBUG_ASSERT(start_nodes.elements > 0);
DBUG_ASSERT(result->elements == 0); DBUG_ASSERT(result->elements == 0);
Queue<FVector, const FVector> candidates; Queue<FVectorNode, const FVector> candidates;
Queue<FVector, const FVector> best; Queue<FVectorNode, const FVector> best;
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key); Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
candidates.init(10000, false, cmp_vec, &target); candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target); best.init(max_candidates_return, true, cmp_vec, &target);
for (const FVectorRef &node : start_nodes) for (const FVectorNode &node : start_nodes)
{ {
FVector *v= ctx->get_fvector_from_source(node); candidates.push(&node);
candidates.push(v);
if (best.elements() < max_candidates_return) if (best.elements() < max_candidates_return)
best.push(v); best.push(&node);
else if (v->distance_to(target) > best.top()->distance_to(target)) else if (node.distance_to(target) > best.top()->distance_to(target))
best.replace_top(v); best.replace_top(&node);
visited.insert(v); visited.insert(&node);
dbug_print_vec_ref("INSERTING node in visited: ", layer, node); dbug_print_vec_ref("INSERTING node in visited: ", layer, node);
} }
float furthest_best= best.top()->distance_to(target); float furthest_best= best.top()->distance_to(target);
while (candidates.elements()) while (candidates.elements())
{ {
const FVector &cur_vec= *candidates.pop(); const FVectorNode &cur_vec= *candidates.pop();
float cur_distance= cur_vec.distance_to(target); float cur_distance= cur_vec.distance_to(target);
if (cur_distance > furthest_best && best.elements() == max_candidates_return) if (cur_distance > furthest_best && best.elements() == max_candidates_return)
{ {
...@@ -482,27 +465,26 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target, ...@@ -482,27 +465,26 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
// Can't get better. // Can't get better.
} }
List<FVectorRef> neighbors; List<FVectorNode> neighbors;
get_neighbors(ctx, layer, cur_vec, &neighbors); get_neighbors(ctx, layer, cur_vec, &neighbors);
for (const FVectorRef &neigh: neighbors) for (const FVectorNode &neigh: neighbors)
{ {
dbug_print_hash_vec(visited); dbug_print_hash_vec(visited);
if (visited.find(&neigh)) if (visited.find(&neigh))
continue; continue;
FVector *clone= ctx->get_fvector_from_source(neigh); visited.insert(&neigh);
visited.insert(clone);
if (best.elements() < max_candidates_return) if (best.elements() < max_candidates_return)
{ {
candidates.push(clone); candidates.push(&neigh);
best.push(clone); best.push(&neigh);
furthest_best= best.top()->distance_to(target); furthest_best= best.top()->distance_to(target);
} }
else if (clone->distance_to(target) < furthest_best) else if (neigh.distance_to(target) < furthest_best)
{ {
best.replace_top(clone); best.replace_top(&neigh);
candidates.push(clone); candidates.push(&neigh);
furthest_best= best.top()->distance_to(target); furthest_best= best.top()->distance_to(target);
} }
} }
...@@ -575,34 +557,32 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -575,34 +557,32 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
// First insert! // First insert!
h->position(table->record[0]); h->position(table->record[0]);
return write_neighbors(&ctx, 0, {h->ref, h->ref_length}, {}); return write_neighbors(&ctx, 0, {&ctx, h->ref}, {});
} }
longlong max_layer= graph->field[0]->val_int(); longlong max_layer= graph->field[0]->val_int();
h->position(table->record[0]); h->position(table->record[0]);
List<FVectorRef> candidates; List<FVectorNode> candidates;
List<FVectorRef> start_nodes; List<FVectorNode> start_nodes;
String ref_str, *ref_ptr; String ref_str, *ref_ptr;
ref_ptr= graph->field[1]->val_str(&ref_str); ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()}; FVectorNode start_node(&ctx, ref_ptr->ptr());
// TODO(cvicentiu) use a random start node in last layer. // TODO(cvicentiu) use a random start node in last layer.
// XXX or may be *all* nodes in the last layer? there should be few // XXX or may be *all* nodes in the last layer? there should be few
if (start_nodes.push_back(&start_node_ref, &ctx.root)) if (start_nodes.push_back(&start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM; return HA_ERR_OUT_OF_MEM;
FVector *v= ctx.get_fvector_from_source(start_node_ref); if (int err= start_node.instantiate_vector())
if (!v) return err;
return HA_ERR_OUT_OF_MEM;
if (v->size_of() != res->length()) if (ctx.vec_len * sizeof(float) != res->length())
return bad_value_on_insert(vec_field); return bad_value_on_insert(vec_field);
FVector target; FVectorNode target(&ctx, h->ref, res->ptr());
target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length());
double new_num= my_rnd(&thd->rand); double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR; double log= -std::log(new_num) * NORMALIZATION_FACTOR;
...@@ -622,7 +602,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -622,7 +602,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for (longlong cur_layer= std::min(max_layer, new_node_layer); for (longlong cur_layer= std::min(max_layer, new_node_layer);
cur_layer >= 0; cur_layer--) cur_layer >= 0; cur_layer--)
{ {
List<FVectorRef> neighbors; List<FVectorNode> neighbors;
if (int err= search_layer(&ctx, target, start_nodes, if (int err= search_layer(&ctx, target, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer, thd->variables.hnsw_ef_constructor, cur_layer,
&candidates)) &candidates))
...@@ -679,33 +659,29 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -679,33 +659,29 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
longlong max_layer= graph->field[0]->val_int(); longlong max_layer= graph->field[0]->val_int();
List<FVectorRef> candidates; // XXX List? not Queue by distance? List<FVectorNode> candidates; // XXX List? not Queue by distance?
List<FVectorRef> start_nodes; List<FVectorNode> start_nodes;
String ref_str, *ref_ptr; String ref_str, *ref_ptr= graph->field[1]->val_str(&ref_str);
ref_ptr= graph->field[1]->val_str(&ref_str); FVectorNode start_node(&ctx, ref_ptr->ptr());
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
// TODO(cvicentiu) use a random start node in last layer. // TODO(cvicentiu) use a random start node in last layer.
// XXX or may be *all* nodes in the last layer? there should be few // XXX or may be *all* nodes in the last layer? there should be few
if (start_nodes.push_back(&start_node_ref, &ctx.root)) if (start_nodes.push_back(&start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM; return HA_ERR_OUT_OF_MEM;
FVector *v= ctx.get_fvector_from_source(start_node_ref); if (int err= start_node.instantiate_vector())
if (!v) return err;
return HA_ERR_OUT_OF_MEM;
/* /*
if the query vector is NULL or invalid, VEC_DISTANCE will return if the query vector is NULL or invalid, VEC_DISTANCE will return
NULL, so the result is basically unsorted, we can return rows NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node. in any order. For simplicity let's sort by the start_node.
*/ */
if (!res || v->size_of() != res->length()) if (!res || ctx.vec_len * sizeof(float) != res->length())
res= vec_field->val_str(&buf); res= vec_field->val_str(&buf);
FVector target; FVector target(&ctx, res->ptr());
if (target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length()))
return HA_ERR_OUT_OF_MEM;
ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit? ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
thd->variables.hnsw_ef_search, limit); thd->variables.hnsw_ef_search, limit);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment