diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index a4665576161ba8f22a66d069afa6177e7366f7fd..d48e69015b8408ee286c8447f9e62a4c5803b9bf 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -213,9 +213,9 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why? const bool EXTEND_CANDIDATES=true; // XXX or false? -static int select_neighbors(MHNSW_Context *ctx, - size_t layer, const FVectorNode &target, - const List<FVectorNode> &candidates, +static int select_neighbors(MHNSW_Context *ctx, size_t layer, + const FVectorNode &target, + const List<FVectorNode> &candidates_unsafe, size_t max_neighbor_connections) { /* @@ -224,16 +224,20 @@ static int select_neighbors(MHNSW_Context *ctx, */ Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key); - Queue<FVectorNode, const FVector> pq; // working queue Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates - Queue<FVectorNode, const FVector> best; // neighbors to return + /* + make a copy of candidates in case it's target.get_neighbors(layer). + because we're going to modify the latter below + */ + List<FVectorNode> candidates= candidates_unsafe; + List<FVectorNode> &neighbors= target.get_neighbors(layer); + neighbors.empty(); // TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size. // This should not be fixed. if (pq.init(10000, 0, cmp_vec, &target) || - pq_discard.init(10000, 0, cmp_vec, &target) || - best.init(max_neighbor_connections, true, cmp_vec, &target)) + pq_discard.init(10000, 0, cmp_vec, &target)) return HA_ERR_OUT_OF_MEM; for (const FVectorNode &candidate : candidates) @@ -257,38 +261,33 @@ static int select_neighbors(MHNSW_Context *ctx, } DBUG_ASSERT(pq.elements()); - best.push(pq.pop()); + neighbors.push_back(pq.pop(), &ctx->root); - float best_top= best.top()->distance_to(target); - while (pq.elements() && best.elements() < max_neighbor_connections) + while (pq.elements() && neighbors.elements < max_neighbor_connections) { const FVectorNode *vec= pq.pop(); - const float cur_dist= vec->distance_to(target); - if (cur_dist < best_top) + const float target_dist= vec->distance_to(target); + bool discard= false; + for (const FVectorNode &neigh : neighbors) { - DBUG_ASSERT(0); // impossible. XXX redo the loop - best.push(vec); - best_top= cur_dist; + if ((discard= vec->distance_to(neigh) < target_dist)) + break; } - else + if (discard) pq_discard.push(vec); + else + neighbors.push_back(vec, &ctx->root); } if (KEEP_PRUNED_CONNECTIONS) { while (pq_discard.elements() && - best.elements() < max_neighbor_connections) + neighbors.elements < max_neighbor_connections) { - best.push(pq_discard.pop()); + neighbors.push_back(pq_discard.pop(), &ctx->root); } } - DBUG_ASSERT(best.elements() <= max_neighbor_connections); - List<FVectorNode> &neighbors= target.get_neighbors(layer); - neighbors.empty(); - while (best.elements()) - neighbors.push_front(best.pop(), &ctx->root); - return 0; }