Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
M
MariaDB
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
nexedi
MariaDB
Commits
fe0f7d20
Commit
fe0f7d20
authored
Jun 06, 2024
by
Sergei Golubchik
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mhnsw: SIMD for euclidean distance
parent
365afe70
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
5 deletions
+42
-5
sql/vector_mhnsw.cc
sql/vector_mhnsw.cc
+42
-5
No files found.
sql/vector_mhnsw.cc
View file @
fe0f7d20
...
...
@@ -25,6 +25,13 @@
static
constexpr
float
alpha
=
1.1
f
;
static
constexpr
uint
ef_construction
=
10
;
// SIMD definitions
#define SIMD_word (256/8)
#define SIMD_floats (SIMD_word/sizeof(float))
// how many extra bytes we need to alloc to be able to convert
// sizeof(double) aligned memory to SIMD_word aligned
#define SIMD_margin (SIMD_word - sizeof(double))
class
MHNSW_Context
;
class
FVector
:
public
Sql_alloc
...
...
@@ -35,6 +42,7 @@ class FVector: public Sql_alloc
float
*
vec
;
protected:
FVector
(
MHNSW_Context
*
ctx_
)
:
ctx
(
ctx_
),
vec
(
nullptr
)
{}
void
make_vec
(
const
void
*
vec_
);
};
class
FVectorNode
:
public
FVector
...
...
@@ -64,6 +72,7 @@ class MHNSW_Context
TABLE
*
table
;
Field
*
vec_field
;
size_t
vec_len
=
0
;
size_t
byte_len
=
0
;
FVector
*
target
=
0
;
Hash_set
<
FVectorNode
>
node_cache
{
PSI_INSTRUMENT_MEM
,
FVectorNode
::
get_key
};
...
...
@@ -84,7 +93,18 @@ class MHNSW_Context
FVector
::
FVector
(
MHNSW_Context
*
ctx_
,
const
void
*
vec_
)
:
ctx
(
ctx_
)
{
vec
=
(
float
*
)
memdup_root
(
&
ctx
->
root
,
vec_
,
ctx
->
vec_len
*
sizeof
(
float
));
make_vec
(
vec_
);
}
void
FVector
::
make_vec
(
const
void
*
vec_
)
{
vec
=
(
float
*
)
alloc_root
(
&
ctx
->
root
,
ctx
->
vec_len
*
sizeof
(
float
)
+
SIMD_margin
);
if
(
int
off
=
((
intptr
)
vec
)
%
SIMD_word
)
vec
+=
(
SIMD_word
-
off
)
/
sizeof
(
float
);
memcpy
(
vec
,
vec_
,
ctx
->
byte_len
);
for
(
size_t
i
=
ctx
->
byte_len
/
sizeof
(
float
);
i
<
ctx
->
vec_len
;
i
++
)
vec
[
i
]
=
0
;
}
FVectorNode
::
FVectorNode
(
MHNSW_Context
*
ctx_
,
const
void
*
ref_
)
...
...
@@ -103,7 +123,20 @@ float FVectorNode::distance_to(const FVector &other) const
{
if
(
!
vec
)
const_cast
<
FVectorNode
*>
(
this
)
->
instantiate_vector
();
#if __GNUC__ > 7
typedef
float
v8f
__attribute__
((
vector_size
(
SIMD_word
)));
v8f
*
p1
=
(
v8f
*
)
vec
;
v8f
*
p2
=
(
v8f
*
)
other
.
vec
;
v8f
d
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
for
(
size_t
i
=
0
;
i
<
ctx
->
vec_len
/
SIMD_floats
;
p1
++
,
p2
++
,
i
++
)
{
v8f
dist
=
*
p1
-
*
p2
;
d
+=
dist
*
dist
;
}
return
d
[
0
]
+
d
[
1
]
+
d
[
2
]
+
d
[
3
]
+
d
[
4
]
+
d
[
5
]
+
d
[
6
]
+
d
[
7
];
#else
return
euclidean_vec_distance
(
vec
,
other
.
vec
,
ctx
->
vec_len
);
#endif
}
int
FVectorNode
::
instantiate_vector
()
...
...
@@ -112,8 +145,12 @@ int FVectorNode::instantiate_vector()
if
(
int
err
=
ctx
->
table
->
file
->
ha_rnd_pos
(
ctx
->
table
->
record
[
0
],
ref
))
return
err
;
String
buf
,
*
v
=
ctx
->
vec_field
->
val_str
(
&
buf
);
ctx
->
vec_len
=
v
->
length
()
/
sizeof
(
float
);
vec
=
(
float
*
)
memdup_root
(
&
ctx
->
root
,
v
->
ptr
(),
v
->
length
());
if
(
unlikely
(
ctx
->
byte_len
==
0
))
{
ctx
->
byte_len
=
v
->
length
();
ctx
->
vec_len
=
MY_ALIGN
(
ctx
->
byte_len
/
sizeof
(
float
),
SIMD_floats
);
}
make_vec
(
v
->
ptr
());
return
0
;
}
...
...
@@ -469,7 +506,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if
(
int
err
=
start_node
->
instantiate_vector
())
return
err
;
if
(
ctx
.
vec_len
*
sizeof
(
float
)
!=
res
->
length
())
if
(
ctx
.
byte_len
!=
res
->
length
())
return
bad_value_on_insert
(
vec_field
);
FVectorNode
target
(
&
ctx
,
table
->
file
->
ref
,
res
->
ptr
());
...
...
@@ -563,7 +600,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node.
*/
if
(
!
res
||
ctx
.
vec_len
*
sizeof
(
float
)
!=
res
->
length
())
if
(
!
res
||
ctx
.
byte_len
!=
res
->
length
())
res
=
vec_field
->
val_str
(
&
buf
);
FVector
target
(
&
ctx
,
res
->
ptr
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment