Commit 0235ae9a authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: modify target's neighbors directly

parent 8053db57
...@@ -226,10 +226,9 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why? ...@@ -226,10 +226,9 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; // XXX or false? const bool EXTEND_CANDIDATES=true; // XXX or false?
static int select_neighbors(MHNSW_Context *ctx, static int select_neighbors(MHNSW_Context *ctx,
size_t layer, const FVector &target, size_t layer, const FVectorNode &target,
const List<FVectorNode> &candidates, const List<FVectorNode> &candidates,
size_t max_neighbor_connections, size_t max_neighbor_connections)
List<FVectorNode> *neighbors)
{ {
/* /*
TODO: If the input neighbors list is already sorted in search_layer, then TODO: If the input neighbors list is already sorted in search_layer, then
...@@ -297,8 +296,10 @@ static int select_neighbors(MHNSW_Context *ctx, ...@@ -297,8 +296,10 @@ static int select_neighbors(MHNSW_Context *ctx,
} }
DBUG_ASSERT(best.elements() <= max_neighbor_connections); DBUG_ASSERT(best.elements() <= max_neighbor_connections);
while (best.elements()) // XXX why not to return best directly? List<FVectorNode> &neighbors= target.get_neighbors(layer);
neighbors->push_front(best.pop(), &ctx->root); neighbors.empty();
while (best.elements())
neighbors.push_front(best.pop(), &ctx->root);
return 0; return 0;
} }
...@@ -344,11 +345,11 @@ static void dbug_print_hash_vec(Hash_set<FVectorNode> &h) ...@@ -344,11 +345,11 @@ static void dbug_print_hash_vec(Hash_set<FVectorNode> &h)
static int write_neighbors(MHNSW_Context *ctx, size_t layer, static int write_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &source_node, const FVectorNode &source_node)
const List<FVectorNode> &new_neighbors)
{ {
int err; int err;
TABLE *graph= ctx->table->hlindex; TABLE *graph= ctx->table->hlindex;
const List<FVectorNode> &new_neighbors= source_node.get_neighbors(layer);
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M); DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
size_t total_size= HNSW_MAX_M_WIDTH + new_neighbors.elements * source_node.get_ref_len(); size_t total_size= HNSW_MAX_M_WIDTH + new_neighbors.elements * source_node.get_ref_len();
...@@ -396,31 +397,25 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -396,31 +397,25 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer,
static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer, static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors, uint max_neighbors,
const FVectorNode &source_node, const FVectorNode &node)
const List<FVectorNode> &neighbors)
{ {
//dbug_print_vec_ref("Updating second degree neighbors", layer, source_node); for (const FVectorNode &neigh: node.get_neighbors(layer)) // XXX why this loop?
//dbug_print_vec_neigh(layer, neighbors);
for (const FVectorNode &neigh: neighbors) // XXX why this loop?
{ {
neigh.get_neighbors(layer).push_back(&source_node, &ctx->root); neigh.get_neighbors(layer).push_back(&node, &ctx->root);
if (int err= write_neighbors(ctx, layer, neigh, neigh.get_neighbors(layer))) if (int err= write_neighbors(ctx, layer, neigh))
return err; return err;
} }
for (const FVectorNode &neigh: neighbors) for (const FVectorNode &neigh: node.get_neighbors(layer))
{ {
if (neigh.get_neighbors(layer).elements > max_neighbors) if (neigh.get_neighbors(layer).elements > max_neighbors)
{ {
// shrink the neighbors // shrink the neighbors
List<FVectorNode> selected;
if (int err= select_neighbors(ctx, layer, neigh, if (int err= select_neighbors(ctx, layer, neigh,
neigh.get_neighbors(layer), neigh.get_neighbors(layer), max_neighbors))
max_neighbors, &selected))
return err; return err;
if (int err= write_neighbors(ctx, layer, neigh, selected)) if (int err= write_neighbors(ctx, layer, neigh))
return err; return err;
// XXX neigh.get_neighbors(layer)= selected;
} }
} }
...@@ -428,17 +423,14 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -428,17 +423,14 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
} }
static int update_neighbors(MHNSW_Context *ctx, static int update_neighbors(MHNSW_Context *ctx, size_t layer,
size_t layer, uint max_neighbors, uint max_neighbors, const FVectorNode &node)
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
{ {
// 1. update node's neighbors // 1. update node's neighbors
if (int err= write_neighbors(ctx, layer, source_node, neighbors)) if (int err= write_neighbors(ctx, layer, node))
return err; return err;
// 2. update node's neighbors' neighbors (shrink before update) // 2. update node's neighbors' neighbors (shrink before update)
return update_second_degree_neighbors(ctx, layer, return update_second_degree_neighbors(ctx, layer, max_neighbors, node);
max_neighbors, source_node, neighbors);
} }
...@@ -571,7 +563,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -571,7 +563,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
// First insert! // First insert!
FVectorNode target(&ctx, h->ref); FVectorNode target(&ctx, h->ref);
ctx.target= &target; ctx.target= &target;
return write_neighbors(&ctx, 0, target, {}); return write_neighbors(&ctx, 0, target);
} }
longlong max_layer= graph->field[0]->val_int(); longlong max_layer= graph->field[0]->val_int();
...@@ -601,6 +593,14 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -601,6 +593,14 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
double log= -std::log(new_num) * NORMALIZATION_FACTOR; double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= static_cast<longlong>(std::floor(log)); longlong new_node_layer= static_cast<longlong>(std::floor(log));
// XXX what is that?
for (longlong cur_layer= new_node_layer; cur_layer >= max_layer + 1;
cur_layer--)
{
if (int err= write_neighbors(&ctx, cur_layer, target))
return err;
}
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{ {
if (int err= search_layer(&ctx, start_nodes, if (int err= search_layer(&ctx, start_nodes,
...@@ -615,7 +615,6 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -615,7 +615,6 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for (longlong cur_layer= std::min(max_layer, new_node_layer); for (longlong cur_layer= std::min(max_layer, new_node_layer);
cur_layer >= 0; cur_layer--) cur_layer >= 0; cur_layer--)
{ {
List<FVectorNode> neighbors;
if (int err= search_layer(&ctx, start_nodes, if (int err= search_layer(&ctx, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer, thd->variables.hnsw_ef_constructor, cur_layer,
&candidates)) &candidates))
...@@ -626,23 +625,14 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -626,23 +625,14 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
: thd->variables.hnsw_max_connection_per_layer; : thd->variables.hnsw_max_connection_per_layer;
if (int err= select_neighbors(&ctx, cur_layer, target, candidates, if (int err= select_neighbors(&ctx, cur_layer, target, candidates,
max_neighbors, &neighbors)) max_neighbors))
return err; return err;
if (int err= update_neighbors(&ctx, cur_layer, max_neighbors, target, if (int err= update_neighbors(&ctx, cur_layer, max_neighbors, target))
neighbors))
return err; return err;
start_nodes= candidates; start_nodes= candidates;
} }
start_nodes.empty(); start_nodes.empty();
// XXX what is that?
for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer;
cur_layer++)
{
if (int err= write_neighbors(&ctx, cur_layer, target, {}))
return err;
}
dbug_tmp_restore_column_map(&table->read_set, old_map); dbug_tmp_restore_column_map(&table->read_set, old_map);
return 0; return 0;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment