diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index d705a9350c8819e59fb0a392ad6a21d6e33d4ceb..929cf6755ae86092ab0e822c68de88d3dcbd86bd 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -65,6 +65,7 @@ class FVectorNode: public FVector int instantiate_vector(); size_t get_ref_len() const; uchar *get_ref() const { return ref; } + bool is_new() const; static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool); }; @@ -76,6 +77,7 @@ class MHNSW_Context TABLE *table; Field *vec_field; size_t vec_len= 0; + FVector *target= 0; Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key}; @@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const return ctx->table->file->ref_length; } +bool FVectorNode::is_new() const +{ + return this == ctx->target; +} + uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool) { *key_len= elem->get_ref_len(); @@ -327,6 +334,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, const FVectorNode &source_node, const List<FVectorNode> &new_neighbors) { + int err; TABLE *graph= ctx->table->hlindex; DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M); @@ -349,25 +357,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len()); graph->field[2]->store_binary(neighbor_array_bytes, total_size); - uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); - key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); - - // XXX try to write first? 
- int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY, - HA_READ_KEY_EXACT); - - // no record - if (err == HA_ERR_KEY_NOT_FOUND) + if (source_node.is_new()) { dbug_print_vec_ref("INSERT ", layer_number, source_node); err= graph->file->ha_write_row(graph->record[0]); } - else if (!err) + else { dbug_print_vec_ref("UPDATE ", layer_number, source_node); dbug_print_vec_neigh(layer_number, new_neighbors); - err= graph->file->ha_update_row(graph->record[1], graph->record[0]); + uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); + key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length); + + err= graph->file->ha_index_read_map(graph->record[1], key, + HA_WHOLE_KEY, HA_READ_KEY_EXACT); + if (!err) + err= graph->file->ha_update_row(graph->record[1], graph->record[0]); + } my_safe_afree(neighbor_array_bytes, total_size); return err; @@ -428,7 +435,7 @@ static int update_neighbors(MHNSW_Context *ctx, } -static int search_layer(MHNSW_Context *ctx, const FVector &target, +static int search_layer(MHNSW_Context *ctx, const List<FVectorNode> &start_nodes, uint max_candidates_return, size_t layer, List<FVectorNode> *result) @@ -439,6 +446,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target, Queue<FVectorNode, const FVector> candidates; Queue<FVectorNode, const FVector> best; Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key); + const FVector &target= *ctx->target; candidates.init(10000, false, cmp_vec, &target); best.init(max_candidates_return, true, cmp_vec, &target); @@ -550,20 +558,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) SCOPE_EXIT([graph](){ graph->file->ha_index_end(); }); + h->position(table->record[0]); + if (int err= graph->file->ha_index_last(graph->record[0])) { if (err != HA_ERR_END_OF_FILE) return err; // First insert! 
- h->position(table->record[0]); - return write_neighbors(&ctx, 0, {&ctx, h->ref}, {}); + FVectorNode target(&ctx, h->ref); + ctx.target= &target; + return write_neighbors(&ctx, 0, target, {}); } longlong max_layer= graph->field[0]->val_int(); - h->position(table->record[0]); - List<FVectorNode> candidates; List<FVectorNode> start_nodes; String ref_str, *ref_ptr; @@ -583,6 +592,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) return bad_value_on_insert(vec_field); FVectorNode target(&ctx, h->ref, res->ptr()); + ctx.target= &target; double new_num= my_rnd(&thd->rand); double log= -std::log(new_num) * NORMALIZATION_FACTOR; @@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) { - if (int err= search_layer(&ctx, target, start_nodes, + if (int err= search_layer(&ctx, start_nodes, thd->variables.hnsw_ef_constructor, cur_layer, &candidates)) return err; @@ -603,7 +613,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) cur_layer >= 0; cur_layer--) { List<FVectorNode> neighbors; - if (int err= search_layer(&ctx, target, start_nodes, + if (int err= search_layer(&ctx, start_nodes, thd->variables.hnsw_ef_constructor, cur_layer, &candidates)) return err; @@ -682,6 +692,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) res= vec_field->val_str(&buf); FVector target(&ctx, res->ptr()); + ctx.target= &target; ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit? 
thd->variables.hnsw_ef_search, limit); @@ -689,16 +700,15 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) { //XXX in the paper ef_search=1 here - if (int err= search_layer(&ctx, target, start_nodes, ef_search, - cur_layer, &candidates)) + if (int err= search_layer(&ctx, start_nodes, ef_search, cur_layer, + &candidates)) return err; start_nodes.empty(); start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ??? candidates.empty(); } - if (int err= search_layer(&ctx, target, start_nodes, ef_search, 0, - &candidates)) + if (int err= search_layer(&ctx, start_nodes, ef_search, 0, &candidates)) return err; size_t context_size=limit * h->ref_length + sizeof(ulonglong);