Commit fe0f7d20 authored by Sergei Golubchik

mhnsw: SIMD for euclidean distance

parent 365afe70
...@@ -25,6 +25,13 @@ ...@@ -25,6 +25,13 @@
static constexpr float alpha = 1.1f; static constexpr float alpha = 1.1f;
static constexpr uint ef_construction= 10; static constexpr uint ef_construction= 10;
// SIMD definitions
// Size of one SIMD word in bytes (256 bits).
#define SIMD_word (256/8)
// Number of float elements that fit into one SIMD word.
#define SIMD_floats (SIMD_word/sizeof(float))
// How many extra bytes we need to allocate to be able to convert
// sizeof(double)-aligned memory to SIMD_word-aligned memory: moving a
// sizeof(double)-aligned pointer forward to the next SIMD_word boundary
// skips at most SIMD_word - sizeof(double) bytes.
#define SIMD_margin (SIMD_word - sizeof(double))
class MHNSW_Context; class MHNSW_Context;
class FVector: public Sql_alloc class FVector: public Sql_alloc
...@@ -35,6 +42,7 @@ class FVector: public Sql_alloc ...@@ -35,6 +42,7 @@ class FVector: public Sql_alloc
float *vec; float *vec;
protected: protected:
FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {} FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {}
void make_vec(const void *vec_);
}; };
class FVectorNode: public FVector class FVectorNode: public FVector
...@@ -64,6 +72,7 @@ class MHNSW_Context ...@@ -64,6 +72,7 @@ class MHNSW_Context
TABLE *table; TABLE *table;
Field *vec_field; Field *vec_field;
size_t vec_len= 0; size_t vec_len= 0;
size_t byte_len= 0;
FVector *target= 0; FVector *target= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key}; Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
...@@ -84,7 +93,18 @@ class MHNSW_Context ...@@ -84,7 +93,18 @@ class MHNSW_Context
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_) FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
{ {
vec= (float*)memdup_root(&ctx->root, vec_, ctx->vec_len * sizeof(float)); make_vec(vec_);
}
// Store a copy of vec_ in a SIMD_word-aligned buffer taken from ctx->root.
//
// The buffer holds ctx->vec_len floats; the first ctx->byte_len bytes are
// copied from vec_ and the remaining elements are set to zero, so padding
// elements contribute nothing to a squared-difference sum.
//
// vec_ : source data, ctx->byte_len bytes are read from it.
void FVector::make_vec(const void *vec_)
{
// Over-allocate by SIMD_margin bytes so the pointer can be moved forward
// to a SIMD_word boundary without running past the end of the allocation.
// NOTE(review): the result of alloc_root() is used without a NULL check;
// confirm how allocation failure is expected to be handled.
vec= (float*)alloc_root(&ctx->root,
ctx->vec_len * sizeof(float) + SIMD_margin);
// off is the distance (in bytes) past the previous SIMD_word boundary;
// a non-zero value means the pointer has to be advanced to the next one.
// NOTE(review): the shift is computed in whole floats, so this presumably
// relies on alloc_root() returning memory aligned to at least
// sizeof(float) (SIMD_margin assumes sizeof(double)) - confirm.
if (int off= ((intptr)vec) % SIMD_word)
vec += (SIMD_word - off) / sizeof(float);
memcpy(vec, vec_, ctx->byte_len);
// zero the padding between the real data and the padded length vec_len
for (size_t i=ctx->byte_len/sizeof(float); i < ctx->vec_len; i++)
vec[i]=0;
} }
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_) FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_)
...@@ -103,7 +123,20 @@ float FVectorNode::distance_to(const FVector &other) const ...@@ -103,7 +123,20 @@ float FVectorNode::distance_to(const FVector &other) const
{ {
// Load this node's vector if it has not been set yet.
// NOTE(review): instantiate_vector() returns an error code that is
// ignored here.
if (!vec) if (!vec)
const_cast<FVectorNode*>(this)->instantiate_vector(); const_cast<FVectorNode*>(this)->instantiate_vector();
// When __GNUC__ > 7, use the compiler's vector extension and process
// SIMD_floats elements per iteration. The casts below treat vec and
// other.vec as arrays of SIMD_word-sized vectors, and the loop covers
// ctx->vec_len/SIMD_floats whole words, so it relies on both buffers being
// SIMD_word-aligned and on ctx->vec_len being a multiple of SIMD_floats.
#if __GNUC__ > 7
typedef float v8f __attribute__((vector_size(SIMD_word)));
v8f *p1= (v8f*)vec;
v8f *p2= (v8f*)other.vec;
// per-lane accumulators; the eight initializers (and the eight-term sum
// below) match SIMD_floats only while SIMD_word is 256 bits
v8f d= {0,0,0,0,0,0,0,0};
for (size_t i= 0; i < ctx->vec_len/SIMD_floats; p1++, p2++, i++)
{
v8f dist= *p1 - *p2;
d+= dist * dist;
}
// Add up the lanes: the result is the sum of squared differences, no
// square root is taken in this branch.
// NOTE(review): confirm euclidean_vec_distance() in the #else branch
// returns the same quantity, otherwise the two builds differ.
return d[0] + d[1] + d[2] + d[3] + d[4] + d[5] + d[6] + d[7];
#else
return euclidean_vec_distance(vec, other.vec, ctx->vec_len); return euclidean_vec_distance(vec, other.vec, ctx->vec_len);
#endif
} }
int FVectorNode::instantiate_vector() int FVectorNode::instantiate_vector()
...@@ -112,8 +145,12 @@ int FVectorNode::instantiate_vector() ...@@ -112,8 +145,12 @@ int FVectorNode::instantiate_vector()
if (int err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref)) if (int err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref))
return err; return err;
String buf, *v= ctx->vec_field->val_str(&buf); String buf, *v= ctx->vec_field->val_str(&buf);
ctx->vec_len= v->length() / sizeof(float); if (unlikely(ctx->byte_len == 0))
vec= (float*)memdup_root(&ctx->root, v->ptr(), v->length()); {
ctx->byte_len= v->length();
ctx->vec_len= MY_ALIGN(ctx->byte_len/sizeof(float), SIMD_floats);
}
make_vec(v->ptr());
return 0; return 0;
} }
...@@ -469,7 +506,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -469,7 +506,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if (int err= start_node->instantiate_vector()) if (int err= start_node->instantiate_vector())
return err; return err;
if (ctx.vec_len * sizeof(float) != res->length()) if (ctx.byte_len != res->length())
return bad_value_on_insert(vec_field); return bad_value_on_insert(vec_field);
FVectorNode target(&ctx, table->file->ref, res->ptr()); FVectorNode target(&ctx, table->file->ref, res->ptr());
...@@ -563,7 +600,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -563,7 +600,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
NULL, so the result is basically unsorted, we can return rows NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node. in any order. For simplicity let's sort by the start_node.
*/ */
if (!res || ctx.vec_len * sizeof(float) != res->length()) if (!res || ctx.byte_len != res->length())
res= vec_field->val_str(&buf); res= vec_field->val_str(&buf);
FVector target(&ctx, res->ptr()); FVector target(&ctx, res->ptr());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment