Commit 960f8ac3 authored by Sergei Golubchik

mhnsw: fix memory management

move everything into a query-local memroot which is freed at the end
parent 57db6c20
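The patch replaces per-object heap allocations (my_malloc/my_free in FVector, plus two file-level Hash_set caches) with a per-query MHNSW_Context that owns a MEM_ROOT: every ref and vector copy is placed on that root (alloc_root, memdup_root, placement new via Sql_alloc) and the whole arena is released once, in the context destructor. Below is a minimal sketch of that pattern, not part of the patch: the class and method names are invented for illustration, and the include names are an assumption about which server headers declare MEM_ROOT, init_alloc_root(), memdup_root() and free_root().

// Illustrative sketch only, not MariaDB code. Mirrors the allocation calls
// visible in the diff; the header names below are an assumption.
#include "my_global.h"   // uchar, MYF()
#include "my_sys.h"      // MEM_ROOT, init_alloc_root(), memdup_root(), free_root()

class Query_local_cache                 // hypothetical stand-in for MHNSW_Context
{
public:
  MEM_ROOT root;

  Query_local_cache()
  {
    // one arena per query, thread-specific, 8K blocks -- same flags as the patch
    init_alloc_root(PSI_INSTRUMENT_MEM, &root, 8192, 0, MYF(MY_THREAD_SPECIFIC));
  }

  ~Query_local_cache()
  {
    free_root(&root, MYF(0));           // frees every allocation made below at once
  }

  // copies that live for the duration of the query go on the root;
  // no matching my_free() is ever needed
  uchar *cache_ref(const uchar *ref, size_t len)
  {
    return (uchar*)memdup_root(&root, ref, len);
  }
};

With this layout the old ~FVector() { my_free(this->ref); } destructor and the file-level all_vector_set / all_vector_ref_set caches become unnecessary, which is exactly what the diff below removes.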
@@ -42,13 +42,13 @@ const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
 ")};
-class FVectorRef
+class FVectorRef: public Sql_alloc
 {
 public:
   // Shallow ref copy. Used for other ref lookups in HashSet
   FVectorRef(const void *ref, size_t ref_len): ref{(uchar*)ref}, ref_len{ref_len} {}
-  static const uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
+  static uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
   {
     *key_len= elem->ref_len;
     return elem->ref;
@@ -60,7 +60,7 @@ class FVectorRef
   }
   size_t get_ref_len() const { return ref_len; }
-  const uchar* get_ref() const { return ref; }
+  uchar* get_ref() const { return ref; }
 protected:
   FVectorRef() = default;
@@ -68,12 +68,6 @@ class FVectorRef
   size_t ref_len;
 };
-Hash_set<FVectorRef> all_vector_set(PSI_INSTRUMENT_MEM, &my_charset_bin,
-                                    1000, 0, 0, (my_hash_get_key)FVectorRef::get_key, 0, HASH_UNIQUE);
-Hash_set<FVectorRef> all_vector_ref_set(PSI_INSTRUMENT_MEM, &my_charset_bin,
-                                        1000, 0, 0, (my_hash_get_key)FVectorRef::get_key, NULL, HASH_UNIQUE);
 class FVector: public FVectorRef
 {
 private:
@@ -81,83 +75,90 @@ class FVector: public FVectorRef
   size_t vec_len;
 public:
   FVector(): vec(nullptr), vec_len(0) {}
-  ~FVector() { my_free(this->ref); }
-  bool init(const uchar *ref, size_t ref_len, const void *vec, size_t bytes)
+  bool init(MEM_ROOT *root, const uchar *ref_, size_t ref_len_, const void *vec_, size_t bytes)
   {
-    this->ref= (uchar*)my_malloc(PSI_NOT_INSTRUMENTED, ref_len + bytes, MYF(0));
-    if (!this->ref)
+    ref= (uchar*)alloc_root(root, ref_len_ + bytes);
+    if (!ref)
       return true;
-    this->vec= reinterpret_cast<float *>(this->ref + ref_len);
-    memcpy(this->ref, ref, ref_len);
-    memcpy(this->vec, vec, bytes);
-    this->ref_len= ref_len;
-    this->vec_len= bytes / sizeof(float);
+    vec= reinterpret_cast<float *>(ref + ref_len_);
+    memcpy(ref, ref_, ref_len_);
+    memcpy(vec, vec_, bytes);
+    ref_len= ref_len_;
+    vec_len= bytes / sizeof(float);
     return false;
   }
   size_t size_of() const { return vec_len * sizeof(float); }
-  size_t get_vec_len() const { return vec_len; }
-  const float* get_vec() const { return vec; }
   float distance_to(const FVector &other) const
   {
     DBUG_ASSERT(other.vec_len == vec_len);
     return euclidean_vec_distance(vec, other.vec, vec_len);
   }
-  static FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len)
+};
+class MHNSW_Context
+{
+public:
+  MEM_ROOT root;
+  TABLE *table;
+  Field *vec_field;
+  Hash_set<FVectorRef> vector_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
+  Hash_set<FVectorRef> vector_ref_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
+  MHNSW_Context(TABLE *table, Field *vec_field)
+    : table(table), vec_field(vec_field)
   {
-    FVectorRef tmp{ref, ref_len};
-    FVectorRef *v= all_vector_ref_set.find(&tmp);
-    if (v)
-      return v;
-    // TODO(cvicentiu) memory management.
-    uchar *buf= (uchar *)my_malloc(PSI_NOT_INSTRUMENTED, ref_len, MYF(0));
-    if (buf)
+    init_alloc_root(PSI_INSTRUMENT_MEM, &root, 8192, 0, MYF(MY_THREAD_SPECIFIC));
+  }
+  ~MHNSW_Context()
   {
-      memcpy(buf, ref, ref_len);
-      if ((v= new FVectorRef(buf, ref_len)))
-        all_vector_ref_set.insert(v);
+    free_root(&root, MYF(0));
   }
+  FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len)
+  {
+    FVectorRef tmp(ref, ref_len);
+    FVectorRef *v= vector_ref_cache.find(&tmp);
+    if (v)
+      return v;
+    uchar *buf= (uchar*)memdup_root(&root, ref, ref_len);
+    if ((v= new (&root) FVectorRef(buf, ref_len)))
+      vector_ref_cache.insert(v);
     return v;
   }
-  static FVector *get_fvector_from_source(TABLE *source, Field *vec_field,
-                                          const FVectorRef &ref)
+  FVector *get_fvector_from_source(const FVectorRef &ref)
   {
-    FVectorRef *v= all_vector_set.find(&ref);
+    FVectorRef *v= vector_cache.find(&ref);
     if (v)
       return (FVector *)v;
-    FVector *new_vector= new FVector;
-    if (!new_vector)
-      return nullptr;
-    source->file->ha_rnd_pos(source->record[0],
-                             const_cast<uchar *>(ref.get_ref()));
-    String buf, *vec;
-    vec= vec_field->val_str(&buf);
-    // TODO(cvicentiu) error checking
-    new_vector->init(ref.get_ref(), ref.get_ref_len(), vec->ptr(), vec->length());
-    all_vector_set.insert(new_vector);
+    if (table->file->ha_rnd_pos(table->record[0], ref.get_ref()))
+      return nullptr; // XXX the error code is lost
+    String buf, *vec= vec_field->val_str(&buf);
+    FVector *new_vector= new (&root) FVector;
+    new_vector->init(&root, ref.get_ref(), ref.get_ref_len(), vec->ptr(), vec->length());
+    vector_cache.insert(new_vector);
     return new_vector;
   }
 };
-static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b)
+static int cmp_vec(const FVector *target, const FVector *a, const FVector *b)
 {
-  float a_dist= reference->distance_to(*a);
-  float b_dist= reference->distance_to(*b);
+  float a_dist= a->distance_to(*target);
+  float b_dist= b->distance_to(*target);
   if (a_dist < b_dist)
     return -1;
@@ -169,10 +170,11 @@ static int cmp_vec(const FVector *reference, const FVector *a, const FVector *b)
 const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
 const bool EXTEND_CANDIDATES=true; // XXX or false?
-static int get_neighbors(TABLE *graph, size_t layer_number,
+static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
                          const FVectorRef &source_node,
                          List<FVectorRef> *neighbors)
 {
+  TABLE *graph= ctx->table->hlindex;
   uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
   graph->field[0]->store(layer_number, false);
@@ -195,10 +197,10 @@ static int get_neighbors(TABLE *graph, size_t layer_number,
   const uchar *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
   for (uint i= 0; i < number_of_neighbors; i++)
   {
-    FVectorRef *v= FVector::get_fvector_ref(pos, ref_length);
+    FVectorRef *v= ctx->get_fvector_ref(pos, ref_length);
     if (!v)
       return HA_ERR_OUT_OF_MEM;
-    neighbors->push_back(v);
+    neighbors->push_back(v, &ctx->root);
     pos+= ref_length;
   }
@@ -206,7 +208,7 @@ static int get_neighbors(TABLE *graph, size_t layer_number,
 }
-static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
+static int select_neighbors(MHNSW_Context *ctx,
                             size_t layer_number, const FVector &target,
                             const List<FVectorRef> &candidates,
                             size_t max_neighbor_connections,
@@ -217,9 +219,7 @@ static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
     no need to do additional queue build steps here.
   */
-  Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, 1000, 0,
-                               0, (my_hash_get_key)FVectorRef::get_key,
-                               NULL, HASH_UNIQUE);
+  Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key);
   Queue<FVector, const FVector> pq; // working queue
   Queue<FVector, const FVector> pq_discard; // queue for discarded candidates
@@ -234,7 +234,7 @@ static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
   for (const FVectorRef &candidate : candidates)
   {
-    FVector *v= FVector::get_fvector_from_source(source, vec_field, candidate);
+    FVector *v= ctx->get_fvector_from_source(candidate);
     if (!v)
       return HA_ERR_OUT_OF_MEM;
     visited.insert(&candidate);
@@ -246,7 +246,7 @@ static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
     for (const FVectorRef &candidate : candidates)
     {
       List<FVectorRef> candidate_neighbors;
-      if (int err= get_neighbors(graph, layer_number, candidate,
+      if (int err= get_neighbors(ctx, layer_number, candidate,
                                  &candidate_neighbors))
         return err;
       for (const FVectorRef &extra_candidate : candidate_neighbors)
@@ -254,8 +254,7 @@ static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
         if (visited.find(&extra_candidate))
           continue;
         visited.insert(&extra_candidate);
-        FVector *v= FVector::get_fvector_from_source(source, vec_field,
-                                                     extra_candidate);
+        FVector *v= ctx->get_fvector_from_source(extra_candidate);
         if (!v)
           return HA_ERR_OUT_OF_MEM;
         pq.push(v);
@@ -292,7 +291,7 @@ static int select_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
   DBUG_ASSERT(best.elements() <= max_neighbor_connections);
   while (best.elements()) // XXX why not to return best directly?
-    neighbors->push_front(best.pop());
+    neighbors->push_front(best.pop(), &ctx->root);
   return 0;
 }
@@ -337,10 +336,11 @@ static void dbug_print_hash_vec(Hash_set<FVectorRef> &h)
 }
-static int write_neighbors(TABLE *graph, size_t layer_number,
+static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
                            const FVectorRef &source_node,
                            const List<FVectorRef> &new_neighbors)
 {
+  TABLE *graph= ctx->table->hlindex;
   DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
   size_t total_size= HNSW_MAX_M_WIDTH + new_neighbors.elements * source_node.get_ref_len();
@@ -365,6 +365,7 @@ static int write_neighbors(TABLE *graph, size_t layer_number,
   uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
   key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
+  // XXX try to write first?
   int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY,
                                           HA_READ_KEY_EXACT);
@@ -386,8 +387,8 @@ static int write_neighbors(TABLE *graph, size_t layer_number,
 }
-static int update_second_degree_neighbors(TABLE *source, Field *vec_field,
-                                          TABLE *graph, size_t layer_number,
+static int update_second_degree_neighbors(MHNSW_Context *ctx,
+                                          size_t layer_number,
                                           uint max_neighbors,
                                           const FVectorRef &source_node,
                                           const List<FVectorRef> &neighbors)
@@ -397,92 +398,84 @@ static int update_second_degree_neighbors(TABLE *source, Field *vec_field,
   for (const FVectorRef &neigh: neighbors) // XXX why this loop?
   {
     List<FVectorRef> new_neighbors;
-    if (int err= get_neighbors(graph, layer_number, neigh, &new_neighbors))
+    if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
       return err;
-    new_neighbors.push_back(&source_node);
-    if (int err= write_neighbors(graph, layer_number, neigh, new_neighbors))
+    new_neighbors.push_back(&source_node, &ctx->root);
+    if (int err= write_neighbors(ctx, layer_number, neigh, new_neighbors))
       return err;
   }
   for (const FVectorRef &neigh: neighbors)
   {
     List<FVectorRef> new_neighbors;
-    if (int err= get_neighbors(graph, layer_number, neigh, &new_neighbors))
+    if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
       return err;
     if (new_neighbors.elements > max_neighbors)
     {
       // shrink the neighbors
       List<FVectorRef> selected;
-      FVector *v= FVector::get_fvector_from_source(source, vec_field, neigh);
+      FVector *v= ctx->get_fvector_from_source(neigh);
      if (!v)
        return HA_ERR_OUT_OF_MEM;
-      if (int err= select_neighbors(source, graph, vec_field, layer_number,
-                                    *v, new_neighbors, max_neighbors, &selected))
+      if (int err= select_neighbors(ctx, layer_number, *v,
+                                    new_neighbors, max_neighbors, &selected))
        return err;
-      if (int err= write_neighbors(graph, layer_number, neigh, selected))
+      if (int err= write_neighbors(ctx, layer_number, neigh, selected))
        return err;
     }
-    // release memory
-    new_neighbors.empty();
   }
   return 0;
 }
-static int update_neighbors(TABLE *source, TABLE *graph, Field *vec_field,
+static int update_neighbors(MHNSW_Context *ctx,
                             size_t layer_number, uint max_neighbors,
                             const FVectorRef &source_node,
                             const List<FVectorRef> &neighbors)
 {
   // 1. update node's neighbors
-  if (int err= write_neighbors(graph, layer_number, source_node, neighbors))
+  if (int err= write_neighbors(ctx, layer_number, source_node, neighbors))
     return err;
   // 2. update node's neighbors' neighbors (shrink before update)
-  return update_second_degree_neighbors(source, vec_field, graph, layer_number,
+  return update_second_degree_neighbors(ctx, layer_number,
                                         max_neighbors, source_node, neighbors);
 }
-static int search_layer(TABLE *source, TABLE *graph, Field *vec_field,
-                        const FVector &target,
+static int search_layer(MHNSW_Context *ctx, const FVector &target,
                         const List<FVectorRef> &start_nodes,
                         uint max_candidates_return, size_t layer,
                         List<FVectorRef> *result)
 {
   DBUG_ASSERT(start_nodes.elements > 0);
-  // Result list must be empty, otherwise there's a risk of memory leak
   DBUG_ASSERT(result->elements == 0);
   Queue<FVector, const FVector> candidates;
   Queue<FVector, const FVector> best;
-  //TODO(cvicentiu) Fix this hash method.
-  Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, &my_charset_bin, 1000, 0, 0,
-                               (my_hash_get_key)FVectorRef::get_key, NULL,
-                               HASH_UNIQUE);
+  Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key);
   candidates.init(10000, false, cmp_vec, &target);
   best.init(max_candidates_return, true, cmp_vec, &target);
   for (const FVectorRef &node : start_nodes)
   {
-    FVector *v= FVector::get_fvector_from_source(source, vec_field, node);
+    FVector *v= ctx->get_fvector_from_source(node);
     candidates.push(v);
     if (best.elements() < max_candidates_return)
       best.push(v);
-    else if (target.distance_to(*v) > target.distance_to(*best.top()))
+    else if (v->distance_to(target) > best.top()->distance_to(target))
      best.replace_top(v);
    visited.insert(v);
    dbug_print_vec_ref("INSERTING node in visited: ", layer, node);
   }
-  float furthest_best= target.distance_to(*best.top());
+  float furthest_best= best.top()->distance_to(target);
   while (candidates.elements())
   {
     const FVector &cur_vec= *candidates.pop();
-    float cur_distance= target.distance_to(cur_vec);
+    float cur_distance= cur_vec.distance_to(target);
     if (cur_distance > furthest_best && best.elements() == max_candidates_return)
     {
       break; // All possible candidates are worse than what we have.
@@ -490,7 +483,7 @@ static int search_layer(TABLE *source, TABLE *graph, Field *vec_field,
     }
     List<FVectorRef> neighbors;
-    get_neighbors(graph, layer, cur_vec, &neighbors);
+    get_neighbors(ctx, layer, cur_vec, &neighbors);
     for (const FVectorRef &neigh: neighbors)
     {
@@ -498,20 +491,19 @@ static int search_layer(TABLE *source, TABLE *graph, Field *vec_field,
       if (visited.find(&neigh))
        continue;
-      FVector *clone= FVector::get_fvector_from_source(source, vec_field, neigh);
-      // TODO(cvicentiu) mem ownership...
+      FVector *clone= ctx->get_fvector_from_source(neigh);
      visited.insert(clone);
      if (best.elements() < max_candidates_return)
      {
        candidates.push(clone);
        best.push(clone);
-        furthest_best= target.distance_to(*best.top());
+        furthest_best= best.top()->distance_to(target);
      }
-      else if (target.distance_to(*clone) < furthest_best)
+      else if (clone->distance_to(target) < furthest_best)
      {
        best.replace_top(clone);
        candidates.push(clone);
-        furthest_best= target.distance_to(*best.top());
+        furthest_best= best.top()->distance_to(target);
      }
     }
     neighbors.empty();
@@ -520,9 +512,8 @@ static int search_layer(TABLE *source, TABLE *graph, Field *vec_field,
   while (best.elements())
   {
-    // TODO(cvicentiu) FVector memory leak.
     // TODO(cvicentiu) this is n*log(n), we need a queue iterator.
-    result->push_front(best.pop());
+    result->push_front(best.pop(), &ctx->root);
   }
   return 0;
@@ -547,6 +538,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
   Field *vec_field= keyinfo->key_part->field;
   String buf, *res= vec_field->val_str(&buf);
   handler *h= table->file->lookup_handler;
+  MHNSW_Context ctx(table, vec_field);
   /* metadata are checked on open */
   DBUG_ASSERT(graph);
@@ -583,7 +575,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
     // First insert!
     h->position(table->record[0]);
-    return write_neighbors(graph, 0, {h->ref, h->ref_length}, {});
+    return write_neighbors(&ctx, 0, {h->ref, h->ref_length}, {});
   }
   longlong max_layer= graph->field[0]->val_int();
@@ -599,10 +591,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
   // TODO(cvicentiu) use a random start node in last layer.
   // XXX or may be *all* nodes in the last layer? there should be few
-  if (start_nodes.push_back(&start_node_ref))
+  if (start_nodes.push_back(&start_node_ref, &ctx.root))
     return HA_ERR_OUT_OF_MEM;
-  FVector *v= FVector::get_fvector_from_source(table, vec_field, start_node_ref);
+  FVector *v= ctx.get_fvector_from_source(start_node_ref);
   if (!v)
     return HA_ERR_OUT_OF_MEM;
@@ -610,7 +602,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
     return bad_value_on_insert(vec_field);
   FVector target;
-  target.init(h->ref, h->ref_length, res->ptr(), res->length());
+  target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length());
   double new_num= my_rnd(&thd->rand);
   double log= -std::log(new_num) * NORMALIZATION_FACTOR;
@@ -618,37 +610,33 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
   for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
   {
-    if (int err= search_layer(table, graph, vec_field, target, start_nodes,
+    if (int err= search_layer(&ctx, target, start_nodes,
                               thd->variables.hnsw_ef_constructor, cur_layer,
                              &candidates))
      return err;
    start_nodes.empty();
-    start_nodes.push_back(candidates.head()); // XXX ef=1
-    //candidates.delete_elements();
+    start_nodes.push_back(candidates.head(), &ctx.root); // XXX ef=1
    candidates.empty();
-    //TODO(cvicentiu) memory leak
   }
   for (longlong cur_layer= std::min(max_layer, new_node_layer);
        cur_layer >= 0; cur_layer--)
   {
     List<FVectorRef> neighbors;
-    if (int err= search_layer(table, graph, vec_field, target, start_nodes,
+    if (int err= search_layer(&ctx, target, start_nodes,
                               thd->variables.hnsw_ef_constructor, cur_layer,
                              &candidates))
      return err;
-    // release vectors
-    start_nodes.empty();
    uint max_neighbors= (cur_layer == 0) ? // heuristics from the paper
                        thd->variables.hnsw_max_connection_per_layer * 2
                      : thd->variables.hnsw_max_connection_per_layer;
-    if (int err= select_neighbors(table, graph, vec_field, cur_layer, target,
-                                  candidates, max_neighbors, &neighbors))
+    if (int err= select_neighbors(&ctx, cur_layer, target, candidates,
+                                  max_neighbors, &neighbors))
      return err;
-    if (int err= update_neighbors(table, graph, vec_field, cur_layer,
-                                  max_neighbors, target, neighbors))
+    if (int err= update_neighbors(&ctx, cur_layer, max_neighbors, target,
+                                  neighbors))
      return err;
    start_nodes= candidates;
   }
@@ -658,7 +646,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
   for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer;
        cur_layer++)
   {
-    if (int err= write_neighbors(graph, cur_layer, target, {}))
+    if (int err= write_neighbors(&ctx, cur_layer, target, {}))
       return err;
   }
@@ -676,6 +664,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
   Item_func_vec_distance *fun= (Item_func_vec_distance *)dist;
   String buf, *res= fun->get_const_arg()->val_str(&buf);
   handler *h= table->file;
+  MHNSW_Context ctx(table, vec_field);
   if (int err= h->ha_rnd_init(0))
     return err;
@@ -699,10 +688,10 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
   // TODO(cvicentiu) use a random start node in last layer.
   // XXX or may be *all* nodes in the last layer? there should be few
-  if (start_nodes.push_back(&start_node_ref))
+  if (start_nodes.push_back(&start_node_ref, &ctx.root))
     return HA_ERR_OUT_OF_MEM;
-  FVector *v= FVector::get_fvector_from_source(table, vec_field, start_node_ref);
+  FVector *v= ctx.get_fvector_from_source(start_node_ref);
   if (!v)
     return HA_ERR_OUT_OF_MEM;
@@ -712,10 +701,10 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
     in any order. For simplicity let's sort by the start_node.
   */
   if (!res || v->size_of() != res->length())
-    (res= &buf)->set((const char*)(v->get_vec()), v->size_of(), &my_charset_bin);
+    res= vec_field->val_str(&buf);
   FVector target;
-  if (target.init(h->ref, h->ref_length, res->ptr(), res->length()))
+  if (target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length()))
     return HA_ERR_OUT_OF_MEM;
   ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
@@ -724,19 +713,16 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
   for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
   {
     //XXX in the paper ef_search=1 here
-    if (int err= search_layer(table, graph, vec_field, target, start_nodes,
-                              ef_search, cur_layer, &candidates))
+    if (int err= search_layer(&ctx, target, start_nodes, ef_search,
+                              cur_layer, &candidates))
      return err;
    start_nodes.empty();
-    //start_nodes.delete_elements();
-    start_nodes.push_back(candidates.head()); // XXX so ef_search=1 ???
-    //candidates.delete_elements();
+    start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ???
    candidates.empty();
-    //TODO(cvicentiu) memleak.
   }
-  if (int err= search_layer(table, graph, vec_field, target, start_nodes,
-                            ef_search, 0, &candidates))
+  if (int err= search_layer(&ctx, target, start_nodes, ef_search, 0,
+                            &candidates))
     return err;
   size_t context_size=limit * h->ref_length + sizeof(ulonglong);