Commit 73d37653 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: cache neighbors too

parent 0951f3b1
...@@ -58,13 +58,17 @@ class FVectorNode: public FVector ...@@ -58,13 +58,17 @@ class FVectorNode: public FVector
{ {
private: private:
uchar *ref; uchar *ref;
List<FVectorNode> *neighbors= nullptr;
char *neighbors_read= 0;
public: public:
FVectorNode(MHNSW_Context *ctx_, const void *ref_); FVectorNode(MHNSW_Context *ctx_, const void *ref_);
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_); FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
float distance_to(const FVector &other) const; float distance_to(const FVector &other) const;
int instantiate_vector(); int instantiate_vector();
int instantiate_neighbors(size_t layer);
size_t get_ref_len() const; size_t get_ref_len() const;
uchar *get_ref() const { return ref; } uchar *get_ref() const { return ref; }
List<FVectorNode> &get_neighbors(size_t layer) const;
bool is_new() const; bool is_new() const;
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool); static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
...@@ -130,6 +134,55 @@ int FVectorNode::instantiate_vector() ...@@ -130,6 +134,55 @@ int FVectorNode::instantiate_vector()
return 0; return 0;
} }
int FVectorNode::instantiate_neighbors(size_t layer)
{
if (!neighbors)
{
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
neighbors_read= (char*)alloc_root(&ctx->root, layer+1);
bzero(neighbors_read, layer+1);
}
if (!neighbors_read[layer])
{
if (!is_new())
{
TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
const size_t ref_len= get_ref_len();
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(ref, ref_len);
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
if (int err= graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT))
return err;
String strbuf, *str= graph->field[2]->val_str(&strbuf);
const char *neigh_arr_bytes= str->ptr();
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
if (number_of_neighbors * ref_len + HNSW_MAX_M_WIDTH != str->length())
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
for (uint i= 0; i < number_of_neighbors; i++)
{
FVectorNode *neigh= ctx->get_node(pos);
neighbors[layer].push_back(neigh, &ctx->root);
pos+= ref_len;
}
}
neighbors_read[layer]= 1;
}
return 0;
}
List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const
{
const_cast<FVectorNode*>(this)->instantiate_neighbors(layer);
return neighbors[layer];
}
size_t FVectorNode::get_ref_len() const size_t FVectorNode::get_ref_len() const
{ {
return ctx->table->file->ref_length; return ctx->table->file->ref_length;
...@@ -172,44 +225,8 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod ...@@ -172,44 +225,8 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod
const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why? const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; // XXX or false? const bool EXTEND_CANDIDATES=true; // XXX or false?
static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorNode &source_node,
List<FVectorNode> *neighbors)
{
TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
graph->field[0]->store(layer_number, false);
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
if (int err= graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT))
return err;
String strbuf, *str= graph->field[2]->val_str(&strbuf);
// mhnsw_insert() guarantees that all ref have the same length
uint ref_length= source_node.get_ref_len();
const char *neigh_arr_bytes= str->ptr();
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length())
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
for (uint i= 0; i < number_of_neighbors; i++)
{
FVectorNode *neigh= ctx->get_node(pos);
neighbors->push_back(neigh, &ctx->root);
pos+= ref_length;
}
return 0;
}
static int select_neighbors(MHNSW_Context *ctx, static int select_neighbors(MHNSW_Context *ctx,
size_t layer_number, const FVector &target, size_t layer, const FVector &target,
const List<FVectorNode> &candidates, const List<FVectorNode> &candidates,
size_t max_neighbor_connections, size_t max_neighbor_connections,
List<FVectorNode> *neighbors) List<FVectorNode> *neighbors)
...@@ -242,11 +259,7 @@ static int select_neighbors(MHNSW_Context *ctx, ...@@ -242,11 +259,7 @@ static int select_neighbors(MHNSW_Context *ctx,
{ {
for (const FVectorNode &candidate : candidates) for (const FVectorNode &candidate : candidates)
{ {
List<FVectorNode> candidate_neighbors; for (const FVectorNode &extra_candidate : candidate.get_neighbors(layer))
if (int err= get_neighbors(ctx, layer_number, candidate,
&candidate_neighbors))
return err;
for (const FVectorNode &extra_candidate : candidate_neighbors)
{ {
if (visited.find(&extra_candidate)) if (visited.find(&extra_candidate))
continue; continue;
...@@ -330,7 +343,7 @@ static void dbug_print_hash_vec(Hash_set<FVectorNode> &h) ...@@ -330,7 +343,7 @@ static void dbug_print_hash_vec(Hash_set<FVectorNode> &h)
} }
static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, static int write_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &source_node, const FVectorNode &source_node,
const List<FVectorNode> &new_neighbors) const List<FVectorNode> &new_neighbors)
{ {
...@@ -353,19 +366,19 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -353,19 +366,19 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
pos+= node.get_ref_len(); pos+= node.get_ref_len();
} }
graph->field[0]->store(layer_number, false); graph->field[0]->store(layer, false);
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len()); graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size); graph->field[2]->store_binary(neighbor_array_bytes, total_size);
if (source_node.is_new()) if (source_node.is_new())
{ {
dbug_print_vec_ref("INSERT ", layer_number, source_node); dbug_print_vec_ref("INSERT ", layer, source_node);
err= graph->file->ha_write_row(graph->record[0]); err= graph->file->ha_write_row(graph->record[0]);
} }
else else
{ {
dbug_print_vec_ref("UPDATE ", layer_number, source_node); dbug_print_vec_ref("UPDATE ", layer, source_node);
dbug_print_vec_neigh(layer_number, new_neighbors); dbug_print_vec_neigh(layer, new_neighbors);
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
...@@ -381,39 +394,33 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -381,39 +394,33 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
} }
static int update_second_degree_neighbors(MHNSW_Context *ctx, static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
size_t layer_number,
uint max_neighbors, uint max_neighbors,
const FVectorNode &source_node, const FVectorNode &source_node,
const List<FVectorNode> &neighbors) const List<FVectorNode> &neighbors)
{ {
//dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node); //dbug_print_vec_ref("Updating second degree neighbors", layer, source_node);
//dbug_print_vec_neigh(layer_number, neighbors); //dbug_print_vec_neigh(layer, neighbors);
for (const FVectorNode &neigh: neighbors) // XXX why this loop? for (const FVectorNode &neigh: neighbors) // XXX why this loop?
{ {
List<FVectorNode> new_neighbors; neigh.get_neighbors(layer).push_back(&source_node, &ctx->root);
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors)) if (int err= write_neighbors(ctx, layer, neigh, neigh.get_neighbors(layer)))
return err;
new_neighbors.push_back(&source_node, &ctx->root);
if (int err= write_neighbors(ctx, layer_number, neigh, new_neighbors))
return err; return err;
} }
for (const FVectorNode &neigh: neighbors) for (const FVectorNode &neigh: neighbors)
{ {
List<FVectorNode> new_neighbors; if (neigh.get_neighbors(layer).elements > max_neighbors)
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err;
if (new_neighbors.elements > max_neighbors)
{ {
// shrink the neighbors // shrink the neighbors
List<FVectorNode> selected; List<FVectorNode> selected;
if (int err= select_neighbors(ctx, layer_number, neigh, if (int err= select_neighbors(ctx, layer, neigh,
new_neighbors, max_neighbors, &selected)) neigh.get_neighbors(layer),
max_neighbors, &selected))
return err; return err;
if (int err= write_neighbors(ctx, layer_number, neigh, selected)) if (int err= write_neighbors(ctx, layer, neigh, selected))
return err; return err;
// XXX neigh.get_neighbors(layer)= selected;
} }
} }
...@@ -422,15 +429,15 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, ...@@ -422,15 +429,15 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
static int update_neighbors(MHNSW_Context *ctx, static int update_neighbors(MHNSW_Context *ctx,
size_t layer_number, uint max_neighbors, size_t layer, uint max_neighbors,
const FVectorNode &source_node, const FVectorNode &source_node,
const List<FVectorNode> &neighbors) const List<FVectorNode> &neighbors)
{ {
// 1. update node's neighbors // 1. update node's neighbors
if (int err= write_neighbors(ctx, layer_number, source_node, neighbors)) if (int err= write_neighbors(ctx, layer, source_node, neighbors))
return err; return err;
// 2. update node's neighbors' neighbors (shrink before update) // 2. update node's neighbors' neighbors (shrink before update)
return update_second_degree_neighbors(ctx, layer_number, return update_second_degree_neighbors(ctx, layer,
max_neighbors, source_node, neighbors); max_neighbors, source_node, neighbors);
} }
...@@ -473,10 +480,7 @@ static int search_layer(MHNSW_Context *ctx, ...@@ -473,10 +480,7 @@ static int search_layer(MHNSW_Context *ctx,
// Can't get better. // Can't get better.
} }
List<FVectorNode> neighbors; for (const FVectorNode &neigh: cur_vec.get_neighbors(layer))
get_neighbors(ctx, layer, cur_vec, &neighbors);
for (const FVectorNode &neigh: neighbors)
{ {
dbug_print_hash_vec(visited); dbug_print_hash_vec(visited);
if (visited.find(&neigh)) if (visited.find(&neigh))
...@@ -496,7 +500,6 @@ static int search_layer(MHNSW_Context *ctx, ...@@ -496,7 +500,6 @@ static int search_layer(MHNSW_Context *ctx,
furthest_best= best.top()->distance_to(target); furthest_best= best.top()->distance_to(target);
} }
} }
neighbors.empty();
} }
DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements())); DBUG_PRINT("VECTOR", ("SEARCH_LAYER_END %d best", best.elements()));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment