Commit 0951f3b1 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: don't guess whether it's insert or update

we know it every time
parent aad082f4
...@@ -65,6 +65,7 @@ class FVectorNode: public FVector ...@@ -65,6 +65,7 @@ class FVectorNode: public FVector
int instantiate_vector(); int instantiate_vector();
size_t get_ref_len() const; size_t get_ref_len() const;
uchar *get_ref() const { return ref; } uchar *get_ref() const { return ref; }
bool is_new() const;
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool); static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
}; };
...@@ -76,6 +77,7 @@ class MHNSW_Context ...@@ -76,6 +77,7 @@ class MHNSW_Context
TABLE *table; TABLE *table;
Field *vec_field; Field *vec_field;
size_t vec_len= 0; size_t vec_len= 0;
FVector *target= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key}; Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
...@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const ...@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const
return ctx->table->file->ref_length; return ctx->table->file->ref_length;
} }
bool FVectorNode::is_new() const
{
return this == ctx->target;
}
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool) uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{ {
*key_len= elem->get_ref_len(); *key_len= elem->get_ref_len();
...@@ -327,6 +334,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -327,6 +334,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorNode &source_node, const FVectorNode &source_node,
const List<FVectorNode> &new_neighbors) const List<FVectorNode> &new_neighbors)
{ {
int err;
TABLE *graph= ctx->table->hlindex; TABLE *graph= ctx->table->hlindex;
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M); DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
...@@ -349,25 +357,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number, ...@@ -349,25 +357,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len()); graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size); graph->field[2]->store_binary(neighbor_array_bytes, total_size);
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); if (source_node.is_new())
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
// XXX try to write first?
int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY,
HA_READ_KEY_EXACT);
// no record
if (err == HA_ERR_KEY_NOT_FOUND)
{ {
dbug_print_vec_ref("INSERT ", layer_number, source_node); dbug_print_vec_ref("INSERT ", layer_number, source_node);
err= graph->file->ha_write_row(graph->record[0]); err= graph->file->ha_write_row(graph->record[0]);
} }
else if (!err) else
{ {
dbug_print_vec_ref("UPDATE ", layer_number, source_node); dbug_print_vec_ref("UPDATE ", layer_number, source_node);
dbug_print_vec_neigh(layer_number, new_neighbors); dbug_print_vec_neigh(layer_number, new_neighbors);
err= graph->file->ha_update_row(graph->record[1], graph->record[0]); uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
err= graph->file->ha_index_read_map(graph->record[1], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT);
if (!err)
err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
} }
my_safe_afree(neighbor_array_bytes, total_size); my_safe_afree(neighbor_array_bytes, total_size);
return err; return err;
...@@ -428,7 +435,7 @@ static int update_neighbors(MHNSW_Context *ctx, ...@@ -428,7 +435,7 @@ static int update_neighbors(MHNSW_Context *ctx,
} }
static int search_layer(MHNSW_Context *ctx, const FVector &target, static int search_layer(MHNSW_Context *ctx,
const List<FVectorNode> &start_nodes, const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer, uint max_candidates_return, size_t layer,
List<FVectorNode> *result) List<FVectorNode> *result)
...@@ -439,6 +446,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target, ...@@ -439,6 +446,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
Queue<FVectorNode, const FVector> candidates; Queue<FVectorNode, const FVector> candidates;
Queue<FVectorNode, const FVector> best; Queue<FVectorNode, const FVector> best;
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key); Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
const FVector &target= *ctx->target;
candidates.init(10000, false, cmp_vec, &target); candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target); best.init(max_candidates_return, true, cmp_vec, &target);
...@@ -550,20 +558,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -550,20 +558,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); }); SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
h->position(table->record[0]);
if (int err= graph->file->ha_index_last(graph->record[0])) if (int err= graph->file->ha_index_last(graph->record[0]))
{ {
if (err != HA_ERR_END_OF_FILE) if (err != HA_ERR_END_OF_FILE)
return err; return err;
// First insert! // First insert!
h->position(table->record[0]); FVectorNode target(&ctx, h->ref);
return write_neighbors(&ctx, 0, {&ctx, h->ref}, {}); ctx.target= &target;
return write_neighbors(&ctx, 0, target, {});
} }
longlong max_layer= graph->field[0]->val_int(); longlong max_layer= graph->field[0]->val_int();
h->position(table->record[0]);
List<FVectorNode> candidates; List<FVectorNode> candidates;
List<FVectorNode> start_nodes; List<FVectorNode> start_nodes;
String ref_str, *ref_ptr; String ref_str, *ref_ptr;
...@@ -583,6 +592,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -583,6 +592,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
return bad_value_on_insert(vec_field); return bad_value_on_insert(vec_field);
FVectorNode target(&ctx, h->ref, res->ptr()); FVectorNode target(&ctx, h->ref, res->ptr());
ctx.target= &target;
double new_num= my_rnd(&thd->rand); double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR; double log= -std::log(new_num) * NORMALIZATION_FACTOR;
...@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--) for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{ {
if (int err= search_layer(&ctx, target, start_nodes, if (int err= search_layer(&ctx, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer, thd->variables.hnsw_ef_constructor, cur_layer,
&candidates)) &candidates))
return err; return err;
...@@ -603,7 +613,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -603,7 +613,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
cur_layer >= 0; cur_layer--) cur_layer >= 0; cur_layer--)
{ {
List<FVectorNode> neighbors; List<FVectorNode> neighbors;
if (int err= search_layer(&ctx, target, start_nodes, if (int err= search_layer(&ctx, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer, thd->variables.hnsw_ef_constructor, cur_layer,
&candidates)) &candidates))
return err; return err;
...@@ -682,6 +692,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -682,6 +692,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
res= vec_field->val_str(&buf); res= vec_field->val_str(&buf);
FVector target(&ctx, res->ptr()); FVector target(&ctx, res->ptr());
ctx.target= &target;
ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit? ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
thd->variables.hnsw_ef_search, limit); thd->variables.hnsw_ef_search, limit);
...@@ -689,16 +700,15 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -689,16 +700,15 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
{ {
//XXX in the paper ef_search=1 here //XXX in the paper ef_search=1 here
if (int err= search_layer(&ctx, target, start_nodes, ef_search, if (int err= search_layer(&ctx, start_nodes, ef_search, cur_layer,
cur_layer, &candidates)) &candidates))
return err; return err;
start_nodes.empty(); start_nodes.empty();
start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ??? start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ???
candidates.empty(); candidates.empty();
} }
if (int err= search_layer(&ctx, target, start_nodes, ef_search, 0, if (int err= search_layer(&ctx, start_nodes, ef_search, 0, &candidates))
&candidates))
return err; return err;
size_t context_size=limit * h->ref_length + sizeof(ulonglong); size_t context_size=limit * h->ref_length + sizeof(ulonglong);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment