diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc
index d705a9350c8819e59fb0a392ad6a21d6e33d4ceb..929cf6755ae86092ab0e822c68de88d3dcbd86bd 100644
--- a/sql/vector_mhnsw.cc
+++ b/sql/vector_mhnsw.cc
@@ -65,6 +65,7 @@ class FVectorNode: public FVector
   int instantiate_vector();
   size_t get_ref_len() const;
   uchar *get_ref() const { return ref; }
+  bool is_new() const;
 
   static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
 };
@@ -76,6 +77,7 @@ class MHNSW_Context
   TABLE *table;
   Field *vec_field;
   size_t vec_len= 0;
+  FVector *target= 0;
 
   Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
 
@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const
   return ctx->table->file->ref_length;
 }
 
+bool FVectorNode::is_new() const
+{
+  return this == ctx->target;
+}
+
 uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
 {
   *key_len= elem->get_ref_len();
@@ -327,6 +334,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
                             const FVectorNode &source_node,
                             const List<FVectorNode> &new_neighbors)
 {
+  int err;
   TABLE *graph= ctx->table->hlindex;
   DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
 
@@ -349,25 +357,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
   graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
   graph->field[2]->store_binary(neighbor_array_bytes, total_size);
 
-  uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
-  key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
-
-  // XXX try to write first?
-  int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY,
-                                          HA_READ_KEY_EXACT);
-
-  // no record
-  if (err == HA_ERR_KEY_NOT_FOUND)
+  if (source_node.is_new())
   {
     dbug_print_vec_ref("INSERT ", layer_number, source_node);
     err= graph->file->ha_write_row(graph->record[0]);
   }
-  else if (!err)
+  else
   {
     dbug_print_vec_ref("UPDATE ", layer_number, source_node);
     dbug_print_vec_neigh(layer_number, new_neighbors);
 
-    err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
+    uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
+    key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
+
+    err= graph->file->ha_index_read_map(graph->record[1], key,
+                                            HA_WHOLE_KEY, HA_READ_KEY_EXACT);
+    if (!err)
+      err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
+
   }
   my_safe_afree(neighbor_array_bytes, total_size);
   return err;
@@ -428,7 +435,7 @@ static int update_neighbors(MHNSW_Context *ctx,
 }
 
 
-static int search_layer(MHNSW_Context *ctx, const FVector &target,
+static int search_layer(MHNSW_Context *ctx,
                         const List<FVectorNode> &start_nodes,
                         uint max_candidates_return, size_t layer,
                         List<FVectorNode> *result)
@@ -439,6 +446,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
   Queue<FVectorNode, const FVector> candidates;
   Queue<FVectorNode, const FVector> best;
   Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
+  const FVector &target= *ctx->target;
 
   candidates.init(10000, false, cmp_vec, &target);
   best.init(max_candidates_return, true, cmp_vec, &target);
@@ -550,20 +558,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
 
   SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
 
+  h->position(table->record[0]);
+
   if (int err= graph->file->ha_index_last(graph->record[0]))
   {
     if (err != HA_ERR_END_OF_FILE)
       return err;
 
     // First insert!
-    h->position(table->record[0]);
-    return write_neighbors(&ctx, 0, {&ctx, h->ref}, {});
+    FVectorNode target(&ctx, h->ref);
+    ctx.target= &target;
+    return write_neighbors(&ctx, 0, target, {});
   }
 
   longlong max_layer= graph->field[0]->val_int();
 
-  h->position(table->record[0]);
-
   List<FVectorNode> candidates;
   List<FVectorNode> start_nodes;
   String ref_str, *ref_ptr;
@@ -583,6 +592,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
     return bad_value_on_insert(vec_field);
 
   FVectorNode target(&ctx, h->ref, res->ptr());
+  ctx.target= &target;
 
   double new_num= my_rnd(&thd->rand);
   double log= -std::log(new_num) * NORMALIZATION_FACTOR;
@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
 
   for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
   {
-    if (int err= search_layer(&ctx, target, start_nodes,
+    if (int err= search_layer(&ctx, start_nodes,
                               thd->variables.hnsw_ef_constructor, cur_layer,
                               &candidates))
       return err;
@@ -603,7 +613,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
        cur_layer >= 0; cur_layer--)
   {
     List<FVectorNode> neighbors;
-    if (int err= search_layer(&ctx, target, start_nodes,
+    if (int err= search_layer(&ctx, start_nodes,
                               thd->variables.hnsw_ef_constructor, cur_layer,
                               &candidates))
       return err;
@@ -682,6 +692,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
     res= vec_field->val_str(&buf);
 
   FVector target(&ctx, res->ptr());
+  ctx.target= &target;
 
   ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
     thd->variables.hnsw_ef_search, limit);
@@ -689,16 +700,15 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
   for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
   {
     //XXX in the paper ef_search=1 here
-    if (int err= search_layer(&ctx, target, start_nodes, ef_search,
-                              cur_layer, &candidates))
+    if (int err= search_layer(&ctx, start_nodes, ef_search, cur_layer,
+                              &candidates))
       return err;
     start_nodes.empty();
     start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ???
     candidates.empty();
   }
 
-  if (int err= search_layer(&ctx, target, start_nodes, ef_search, 0,
-                            &candidates))
+  if (int err= search_layer(&ctx, start_nodes, ef_search, 0, &candidates))
     return err;
 
   size_t context_size=limit * h->ref_length + sizeof(ulonglong);