From 3804afad1a7cc7827c7ad6267b79f54564fa0166 Mon Sep 17 00:00:00 2001
From: Sergei Golubchik <serg@mariadb.org>
Date: Fri, 7 Jun 2024 19:12:08 +0200
Subject: [PATCH] mhnsw: cache start node too, don't push too much in
 pg_discard

---
 sql/vector_mhnsw.cc | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc
index 4b5c1c1d344..04410f6be58 100644
--- a/sql/vector_mhnsw.cc
+++ b/sql/vector_mhnsw.cc
@@ -233,10 +233,10 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
       if ((discard= vec->distance_to(neigh) < target_dist))
         break;
     }
-    if (discard)
-      pq_discard.push(vec);
-    else
+    if (!discard)
       neighbors.push_back(vec, &ctx->root);
+    else if (pq_discard.elements() + neighbors.elements < max_neighbor_connections)
+      pq_discard.push(vec);
   }
 
   while (pq_discard.elements() &&
@@ -454,12 +454,12 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
   String ref_str, *ref_ptr;
 
   ref_ptr= graph->field[1]->val_str(&ref_str);
-  FVectorNode start_node(&ctx, ref_ptr->ptr());
+  FVectorNode *start_node= ctx.get_node(ref_ptr->ptr());
 
-  if (start_nodes.push_back(&start_node, &ctx.root))
+  if (start_nodes.push_back(start_node, &ctx.root))
     return HA_ERR_OUT_OF_MEM;
 
-  if (int err= start_node.instantiate_vector())
+  if (int err= start_node->instantiate_vector())
     return err;
 
   if (ctx.vec_len * sizeof(float) != res->length())
@@ -543,14 +543,14 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
   List<FVectorNode> start_nodes;
   String ref_str, *ref_ptr= graph->field[1]->val_str(&ref_str);
 
-  FVectorNode start_node(&ctx, ref_ptr->ptr());
+  FVectorNode *start_node= ctx.get_node(ref_ptr->ptr());
 
   // one could put all max_layer nodes in start_nodes
   // but it has no effect of the recall or speed
-  if (start_nodes.push_back(&start_node, &ctx.root))
+  if (start_nodes.push_back(start_node, &ctx.root))
     return HA_ERR_OUT_OF_MEM;
 
-  if (int err= start_node.instantiate_vector())
+  if (int err= start_node->instantiate_vector())
     return err;
 
   /*
-- 
2.30.9