Commit 4db6bfd7 authored by Sergei Golubchik's avatar Sergei Golubchik

cleanup search_layer()

to return only as many elements as needed, the caller no longer needs to
overallocate result arrays for throwaway nodes
parent b3c30521
...@@ -886,16 +886,29 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph, ...@@ -886,16 +886,29 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
} }
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
Neighborhood *start_nodes, uint ef, size_t layer, Neighborhood *start_nodes, uint result_size,
Neighborhood *result, bool skip_deleted) size_t layer, Neighborhood *result, bool construction)
{ {
DBUG_ASSERT(start_nodes->num > 0); DBUG_ASSERT(start_nodes->num > 0);
result->num= 0; result->num= 0;
MEM_ROOT * const root= graph->in_use->mem_root; MEM_ROOT * const root= graph->in_use->mem_root;
Queue<Visited> candidates, best;
bool skip_deleted;
uint ef= result_size;
Queue<Visited> candidates; if (construction)
Queue<Visited> best; {
skip_deleted= false;
if (ef > 1)
ef= std::max(ef_construction, ef);
}
else
{
skip_deleted= layer == 0;
if (ef > 1 || layer == 0)
ef= std::max(graph->in_use->variables.mhnsw_min_limit, ef);
}
// WARNING! heuristic here // WARNING! heuristic here
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer)); const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
...@@ -905,23 +918,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -905,23 +918,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
candidates.init(10000, false, Visited::cmp); candidates.init(10000, false, Visited::cmp);
best.init(ef, true, Visited::cmp); best.init(ef, true, Visited::cmp);
DBUG_ASSERT(start_nodes->num <= result_size);
for (size_t i=0; i < start_nodes->num; i++) for (size_t i=0; i < start_nodes->num; i++)
{ {
Visited *v= visited.create(start_nodes->links[i]); Visited *v= visited.create(start_nodes->links[i]);
candidates.push(v); candidates.push(v);
if (skip_deleted && v->node->deleted) if (skip_deleted && v->node->deleted)
continue; continue;
if (best.elements() < ef) best.push(v);
best.push(v);
else if (v->distance_to_target < best.top()->distance_to_target)
best.replace_top(v);
} }
float furthest_best= FLT_MAX; float furthest_best= FLT_MAX;
while (candidates.elements()) while (candidates.elements())
{ {
const Visited &cur= *candidates.pop(); const Visited &cur= *candidates.pop();
if (cur.distance_to_target > furthest_best && best.elements() == ef) if (cur.distance_to_target > furthest_best && best.is_full())
break; // All possible candidates are worse than what we have break; // All possible candidates are worse than what we have
visited.flush(); visited.flush();
...@@ -941,7 +952,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -941,7 +952,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
if (int err= links[i]->load(graph)) if (int err= links[i]->load(graph))
return err; return err;
Visited *v= visited.create(links[i]); Visited *v= visited.create(links[i]);
if (best.elements() < ef) if (!best.is_full())
{ {
candidates.push(v); candidates.push(v);
if (skip_deleted && v->node->deleted) if (skip_deleted && v->node->deleted)
...@@ -966,6 +977,9 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -966,6 +977,9 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok
} }
while (best.elements() > result_size)
best.pop();
result->num= best.elements(); result->num= best.elements();
for (FVectorNode **links= result->links + result->num; best.elements();) for (FVectorNode **links= result->links + result->num; best.elements();)
*--links= best.pop()->node; *--links= best.pop()->node;
...@@ -1033,9 +1047,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -1033,9 +1047,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
root_make_savepoint(thd->mem_root, &memroot_sv); root_make_savepoint(thd->mem_root, &memroot_sv);
SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); }); SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); });
const size_t max_found= ctx->max_neighbors(0);
Neighborhood candidates, start_nodes; Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction); candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
start_nodes.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction); start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
start_nodes.links[start_nodes.num++]= ctx->start; start_nodes.links[start_nodes.num++]= ctx->start;
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M); const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
...@@ -1063,7 +1078,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -1063,7 +1078,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
{ {
uint max_neighbors= ctx->max_neighbors(cur_layer); uint max_neighbors= ctx->max_neighbors(cur_layer);
if (int err= search_layer(ctx, graph, target->vec, &start_nodes, if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
ef_construction, cur_layer, &candidates, false)) max_neighbors, cur_layer, &candidates, true))
return err; return err;
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates, if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
...@@ -1106,11 +1121,9 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -1106,11 +1121,9 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
if (err) if (err)
return err; return err;
size_t ef= thd->variables.mhnsw_min_limit;
Neighborhood candidates, start_nodes; Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef); candidates.init(thd->alloc<FVectorNode*>(limit + 7), limit);
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef); start_nodes.init(thd->alloc<FVectorNode*>(limit + 7), limit);
// one could put all max_layer nodes in start_nodes // one could put all max_layer nodes in start_nodes
// but it has no effect on the recall or speed // but it has no effect on the recall or speed
...@@ -1145,8 +1158,8 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -1145,8 +1158,8 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
std::swap(start_nodes, candidates); std::swap(start_nodes, candidates);
} }
if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0, if (int err= search_layer(ctx, graph, target, &start_nodes,
&candidates, true)) static_cast<uint>(limit), 0, &candidates, false))
return err; return err;
if (limit > candidates.num) if (limit > candidates.num)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment