Commit 41334e56 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: inter-statement shared cache

* preserve the graph in memory between statements
* keep it in a TABLE_SHARE, available for concurrent searches
* nodes are generally read-only, walking the graph doesn't change them
* distance to target is cached, calculated only once
* SIMD-optimized bloom filter detects visited nodes
* nodes are stored in an array, not List, to better utilize bloom filter
* auto-adjusting heuristic to estimate the number of visited nodes
  (to configure the bloom filter)
* many threads can concurrently walk the graph. MEM_ROOT and Hash_set
  are protected with a mutex, but walking doesn't need them
* up to 8 threads can concurrently load nodes into the cache,
  nodes are partitioned into 8 mutexes (8 is chosen arbitrarily, might
  need tuning)
* concurrent editing is not supported though
* this is fine for MyISAM, TL_WRITE protects the TABLE_SHARE and the
  graph (note that TL_WRITE_CONCURRENT_INSERT is not allowed, because an
  INSERT into the main table means multiple UPDATEs in the graph)
* InnoDB uses secondary transaction-level caches linked in a list
  in thd->ha_data via a fake handlerton
* on rollback the secondary cache is discarded, on commit nodes
  from the secondary cache are invalidated in the shared cache
  while it is exclusively locked
* on savepoint rollback both caches are flushed. this can be improved
  in the future with a row visibility callback
* graph size is controlled by @@mhnsw_cache_size, the cache is flushed
  when it reaches the threshold
parent 656845ef
......@@ -947,6 +947,12 @@ extern LEX_STRING lex_string_casedn_root(MEM_ROOT *root,
CHARSET_INFO *cs,
const char *str, size_t length);
/*
  Estimate the total amount of memory allocated in a MEM_ROOT.

  Mirrors the growth policy of alloc_root() (which bumps the block size
  once per four allocated blocks): with q = block_num/4 the total comes
  to roughly block_size * (4 + 8 + ... + 4*q) = 2*q*(q+1)*block_size.
  Keep this function in sync with the block-size logic in alloc_root().
*/
static inline size_t root_size(MEM_ROOT *root)
{
  size_t quarters= root->block_num >> 2;
  return 2 * quarters * (quarters + 1) * root->block_size;
}
extern my_bool my_compress(uchar *, size_t *, size_t *);
extern my_bool my_uncompress(uchar *, size_t , size_t *);
extern uchar *my_compress_alloc(const uchar *packet, size_t *len,
......
......@@ -688,6 +688,8 @@ The following specify which files/extra groups are read (specified before remain
Unused. Deprecated, will be removed in a future release.
--metadata-locks-hash-instances=#
Unused. Deprecated, will be removed in a future release.
--mhnsw-cache-size=#
Size of the cache for the MHNSW vector index
--mhnsw-limit-multiplier=#
Defines the number of result candidates to look for in
the vector index for ORDER BY ... LIMIT N queries.
......@@ -1791,6 +1793,7 @@ max-write-lock-count 18446744073709551615
memlock FALSE
metadata-locks-cache-size 1024
metadata-locks-hash-instances 8
mhnsw-cache-size 16777216
mhnsw-limit-multiplier 2
mhnsw-max-edges-per-node 15
min-examined-row-limit 0
......
create table t1 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb;
show create table t1;
Table Create Table
t1 CREATE TABLE `t1` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`v` blob NOT NULL,
PRIMARY KEY (`id`),
VECTOR KEY `v` (`v`)
) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_swedish_ci
insert t1 (v) values
(x'106d263fdf68ba3eb08d533f97d46e3fd1e1ec3edc4c123f984c563f621a233f'),
(x'd55bee3c56eb9e3e84e3093f838dce3eb7cd653fe32d7d3f12de133c5715d23e'),
(x'fcd5553f3822443f5dae413f2593493f7777363f5f7f113ebf12373d4d145a3f'),
(x'7493093fd9a27d3e9b13783f8c66653f0bd7d23e50db983d251b013f1dba133f'),
(x'2e30373fae331a3eba94153ee32bce3e3311b33d5bc75d3f6c25653eb769113f'),
(x'381d5f3f2781de3e4f011f3f9353483f9bb37e3edd622d3eabecb63ec246953e'),
(x'4ee5dc3e214b103f0e7e583f5f36473e79d7823ea872ec3e3ab2913d1b84433f'),
(x'8826243f7d20f03e5135593f83ba653e44572d3fa87e8e3e943e0e3f649a293f'),
(x'3859ac3e7d21823ed3f5753fc79c143e61d39c3cee39ba3eb0b0133e815c173f'),
(x'cff0d93c32941e3f64b22a3f1e4f083f4ea2563fbff4a63e12a4703f6c824b3f');
start transaction;
insert t1 values
(30, x'f8e2413ed4ff773fef8b893eba487b3febee3f3f9e6f693f5961fd3ee479303d');
savepoint foo;
insert t1 values
(31, x'6129683f90fe1f3e1437bc3ed8c8f63dd141033f21e3a93e54346c3f8c4e043f'),
(32, x'1ec8b83d398c4d3f2efb463f23947a3fa1a5093fdde6303e5580413f51569b3e');
rollback to savepoint foo;
insert t1 values
(33, x'86d1003d4262033f8086713ffc4a633e317e933c4dce013d9c4d573fca83b93e');
commit;
start transaction;
insert t1 values
(40, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f'),
(41, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f');
rollback;
select id,vec_distance(v, x'c923e33dc0da313fe7c7983e526b3d3fde63963e6eaf3a3f27fa133fe27a583f') d from t1 order by d limit 5;
id d
10 0.8856208347761952
1 0.9381363209273885
30 1.0162643974895857
7 1.026397313888122
5 1.0308161006949719
select id,vec_distance(v, x'754b5f3ea2312b3fc169f43e4604883e1d20173e8dd7443f421b703fb11e0d3e') d from t1 order by d limit 5;
id d
33 0.9477554826856
30 1.111405427702547
1 1.1154613877616022
10 1.118630286292343
8 1.1405733350751739
create table t2 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb;
insert t2 (v) values
(x'45cf153f830a313f7a0a113fb1ff533f47a1533fcf9e6e3f'),
(x'4b311d3fdd82423f35ba7d3fa041223dfd7db03e72d5833e'),
(x'f0d4123f6fc1833ea30a483fd9649d3cb94d733f4574a63d'),
(x'7ff8a53bf68e4a3e66e3563f214dea3e63372f3ec24d513f'),
(x'4709683f0d44473f8a045f3f40f3693df7f1303fdb98b73e'),
(x'09de2b3f5db80d3fb4405f3f64aadc3ecfa6183f823c733f'),
(x'a93a143f7f71e33d0cde5c3ff106373fd6f6233fc1f4fc3e'),
(x'11236e3de44a0d3f8241023d44d8383f2f70733f44d65c3f'),
(x'b5e47c3f35d3413fad8a533d5945133f66dbf33d92c6103f');
start transaction;
insert t1 values
(50, x'acae183f56ddc43e5093983d280df53e6fa2093f79c01a3eb1591f3f423a0e3d'),
(51, x'6285303f42ef6e3f355e313f3e96a53e70959b3edd720b3ec07f733e5bc8603f');
insert t2 values
(20, x'58dc7d3fc9feaa3e19e26b3f31820c3f93070b3fc4e36e3f'),
(21, x'35e05d3f18e8513fb81a3d3f8acf7d3e794a1d3c72f9613f');
commit;
select id,vec_distance(v, x'1f4d053f7056493f937da03dd8c97a3f220cbb3c926c1c3facca213ec0618a3e') d from t1 order by d limit 5;
id d
6 0.9309383181777582
5 0.9706304662574956
30 0.98144492002831
50 1.079862635421575
51 1.2403734530917931
select id,vec_distance(v, x'f618663f256be73e62cd453f8bcdbf3e16ae503c3858313f') d from t2 order by d limit 5;
id d
21 0.43559180321379337
20 0.6435053022072372
6 0.6942000623336242
2 0.7971622099055623
9 0.8298589136476077
drop table t1, t2;
source include/have_innodb.inc;
create table t1 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb;
show create table t1;
# print unpack("H*",pack("f*",map{rand}1..8))
insert t1 (v) values
(x'106d263fdf68ba3eb08d533f97d46e3fd1e1ec3edc4c123f984c563f621a233f'),
(x'd55bee3c56eb9e3e84e3093f838dce3eb7cd653fe32d7d3f12de133c5715d23e'),
(x'fcd5553f3822443f5dae413f2593493f7777363f5f7f113ebf12373d4d145a3f'),
(x'7493093fd9a27d3e9b13783f8c66653f0bd7d23e50db983d251b013f1dba133f'),
(x'2e30373fae331a3eba94153ee32bce3e3311b33d5bc75d3f6c25653eb769113f'),
(x'381d5f3f2781de3e4f011f3f9353483f9bb37e3edd622d3eabecb63ec246953e'),
(x'4ee5dc3e214b103f0e7e583f5f36473e79d7823ea872ec3e3ab2913d1b84433f'),
(x'8826243f7d20f03e5135593f83ba653e44572d3fa87e8e3e943e0e3f649a293f'),
(x'3859ac3e7d21823ed3f5753fc79c143e61d39c3cee39ba3eb0b0133e815c173f'),
(x'cff0d93c32941e3f64b22a3f1e4f083f4ea2563fbff4a63e12a4703f6c824b3f');
### savepoints and rollbacks:
start transaction;
insert t1 values
(30, x'f8e2413ed4ff773fef8b893eba487b3febee3f3f9e6f693f5961fd3ee479303d');
savepoint foo;
insert t1 values
(31, x'6129683f90fe1f3e1437bc3ed8c8f63dd141033f21e3a93e54346c3f8c4e043f'),
(32, x'1ec8b83d398c4d3f2efb463f23947a3fa1a5093fdde6303e5580413f51569b3e');
rollback to savepoint foo;
insert t1 values
(33, x'86d1003d4262033f8086713ffc4a633e317e933c4dce013d9c4d573fca83b93e');
commit;
start transaction;
insert t1 values
(40, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f'),
(41, x'71046a3e85329b3e05240e3f45c9283f1847363f98d47d3f4224b73d487b613f');
rollback;
select id,vec_distance(v, x'c923e33dc0da313fe7c7983e526b3d3fde63963e6eaf3a3f27fa133fe27a583f') d from t1 order by d limit 5;
select id,vec_distance(v, x'754b5f3ea2312b3fc169f43e4604883e1d20173e8dd7443f421b703fb11e0d3e') d from t1 order by d limit 5;
### two indexes in one transaction:
create table t2 (id int auto_increment primary key, v blob not null, vector index (v)) engine=innodb;
insert t2 (v) values
(x'45cf153f830a313f7a0a113fb1ff533f47a1533fcf9e6e3f'),
(x'4b311d3fdd82423f35ba7d3fa041223dfd7db03e72d5833e'),
(x'f0d4123f6fc1833ea30a483fd9649d3cb94d733f4574a63d'),
(x'7ff8a53bf68e4a3e66e3563f214dea3e63372f3ec24d513f'),
(x'4709683f0d44473f8a045f3f40f3693df7f1303fdb98b73e'),
(x'09de2b3f5db80d3fb4405f3f64aadc3ecfa6183f823c733f'),
(x'a93a143f7f71e33d0cde5c3ff106373fd6f6233fc1f4fc3e'),
(x'11236e3de44a0d3f8241023d44d8383f2f70733f44d65c3f'),
(x'b5e47c3f35d3413fad8a533d5945133f66dbf33d92c6103f');
start transaction;
insert t1 values
(50, x'acae183f56ddc43e5093983d280df53e6fa2093f79c01a3eb1591f3f423a0e3d'),
(51, x'6285303f42ef6e3f355e313f3e96a53e70959b3edd720b3ec07f733e5bc8603f');
insert t2 values
(20, x'58dc7d3fc9feaa3e19e26b3f31820c3f93070b3fc4e36e3f'),
(21, x'35e05d3f18e8513fb81a3d3f8acf7d3e794a1d3c72f9613f');
commit;
select id,vec_distance(v, x'1f4d053f7056493f937da03dd8c97a3f220cbb3c926c1c3facca213ec0618a3e') d from t1 order by d limit 5;
select id,vec_distance(v, x'f618663f256be73e62cd453f8bcdbf3e16ae503c3858313f') d from t2 order by d limit 5;
drop table t1, t2;
exit;
......@@ -2372,6 +2372,16 @@ NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL
READ_ONLY YES
COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME MHNSW_CACHE_SIZE
VARIABLE_SCOPE GLOBAL
VARIABLE_TYPE BIGINT UNSIGNED
VARIABLE_COMMENT Size of the cache for the MHNSW vector index
NUMERIC_MIN_VALUE 1048576
NUMERIC_MAX_VALUE 18446744073709551615
NUMERIC_BLOCK_SIZE 1
ENUM_VALUE_LIST NULL
READ_ONLY NO
COMMAND_LINE_ARGUMENT REQUIRED
VARIABLE_NAME MHNSW_LIMIT_MULTIPLIER
VARIABLE_SCOPE SESSION
VARIABLE_TYPE DOUBLE
......
......@@ -324,6 +324,7 @@ void *alloc_root(MEM_ROOT *mem_root, size_t length)
size_t alloced_length;
/* Increase block size over time if there is a lot of mallocs */
/* when changing this logic, update root_size() to match */
block_size= (MY_ALIGN(mem_root->block_size, ROOT_MIN_BLOCK_SIZE) *
(mem_root->block_num >> 2)- MALLOC_OVERHEAD);
get_size= length + ALIGN_SIZE(sizeof(USED_MEM));
......
/*
MIT License
Copyright (c) 2023 Sasha Krassovsky
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
// https://save-buffer.github.io/bloom_filter.html
#pragma once
#include <cmath>
#include <vector>
#include <algorithm>
#include <immintrin.h>
template <typename T>
struct PatternedSimdBloomFilter
{
PatternedSimdBloomFilter(int n, float eps) : n(n), epsilon(eps)
{
m = ComputeNumBits();
int log_num_blocks = 32 - __builtin_clz(m) - rotate_bits;
num_blocks = (1ULL << log_num_blocks);
bv.resize(num_blocks);
}
uint64_t ComputeNumBits()
{
double bits_per_val = -1.44 * std::log2(epsilon);
return std::max<uint64_t>(512, bits_per_val * n + 0.5);
}
__attribute__ ((target ("avx2,avx,fma")))
__m256i CalcHash(__m256i vecData)
{
// (almost) xxHash parallel version, 64bit input, 64bit output, seed=0
static constexpr __m256i rotl48={
0x0504030201000706ULL, 0x0D0C0B0A09080F0EULL,
0x1514131211101716ULL, 0x1D1C1B1A19181F1EULL
};
static constexpr __m256i rotl24={
0x0201000706050403ULL, 0x0A09080F0E0D0C0BULL,
0x1211101716151413ULL, 0x1A19181F1E1D1C1BULL,
};
static constexpr uint64_t prime_mx2= 0x9FB21C651E98DF25ULL;
static constexpr uint64_t bitflip= 0xC73AB174C5ECD5A2ULL;
__m256i step1= _mm256_xor_si256(vecData, _mm256_set1_epi64x(bitflip));
__m256i step2= _mm256_shuffle_epi8(step1, rotl48);
__m256i step3= _mm256_shuffle_epi8(step1, rotl24);
__m256i step4= _mm256_xor_si256(step1, _mm256_xor_si256(step2, step3));
__m256i step5= _mm256_mul_epi32(step4, _mm256_set1_epi64x(prime_mx2));
__m256i step6= _mm256_srli_epi64(step5, 35);
__m256i step7= _mm256_add_epi64(step6, _mm256_set1_epi64x(8));
__m256i step8= _mm256_xor_si256(step5, step7);
__m256i step9= _mm256_mul_epi32(step8, _mm256_set1_epi64x(prime_mx2));
return _mm256_xor_si256(step9, _mm256_srli_epi64(step9, 28));
}
__attribute__ ((target ("avx2,avx,fma")))
__m256i GetBlockIdx(__m256i vecHash)
{
__m256i vecNumBlocksMask = _mm256_set1_epi64x(num_blocks - 1);
__m256i vecBlockIdx = _mm256_srli_epi64(vecHash, mask_idx_bits + rotate_bits);
return _mm256_and_si256(vecBlockIdx, vecNumBlocksMask);
}
__attribute__ ((target ("avx2,avx,fma")))
__m256i ConstructMask(__m256i vecHash)
{
__m256i vecMaskIdxMask = _mm256_set1_epi64x((1 << mask_idx_bits) - 1);
__m256i vecMaskMask = _mm256_set1_epi64x((1ull << bits_per_mask) - 1);
__m256i vec64 = _mm256_set1_epi64x(64);
__m256i vecMaskIdx = _mm256_and_si256(vecHash, vecMaskIdxMask);
__m256i vecMaskByteIdx = _mm256_srli_epi64(vecMaskIdx, 3);
__m256i vecMaskBitIdx = _mm256_and_si256(vecMaskIdx, _mm256_set1_epi64x(0x7));
__m256i vecRawMasks = _mm256_i64gather_epi64((const longlong *)masks, vecMaskByteIdx, 1);
__m256i vecUnrotated = _mm256_and_si256(_mm256_srlv_epi64(vecRawMasks, vecMaskBitIdx), vecMaskMask);
__m256i vecRotation = _mm256_and_si256(_mm256_srli_epi64(vecHash, mask_idx_bits), _mm256_set1_epi64x((1 << rotate_bits) - 1));
__m256i vecShiftUp = _mm256_sllv_epi64(vecUnrotated, vecRotation);
__m256i vecShiftDown = _mm256_srlv_epi64(vecUnrotated, _mm256_sub_epi64(vec64, vecRotation));
return _mm256_or_si256(vecShiftDown, vecShiftUp);
}
__attribute__ ((target ("avx2,avx,fma")))
void Insert(const T **data)
{
__m256i vecDataA = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 0));
__m256i vecDataB = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 4));
__m256i vecHashA= CalcHash(vecDataA);
__m256i vecHashB= CalcHash(vecDataB);
__m256i vecMaskA = ConstructMask(vecHashA);
__m256i vecMaskB = ConstructMask(vecHashB);
__m256i vecBlockIdxA = GetBlockIdx(vecHashA);
__m256i vecBlockIdxB = GetBlockIdx(vecHashB);
uint64_t block0 = _mm256_extract_epi64(vecBlockIdxA, 0);
uint64_t block1 = _mm256_extract_epi64(vecBlockIdxA, 1);
uint64_t block2 = _mm256_extract_epi64(vecBlockIdxA, 2);
uint64_t block3 = _mm256_extract_epi64(vecBlockIdxA, 3);
uint64_t block4 = _mm256_extract_epi64(vecBlockIdxB, 0);
uint64_t block5 = _mm256_extract_epi64(vecBlockIdxB, 1);
uint64_t block6 = _mm256_extract_epi64(vecBlockIdxB, 2);
uint64_t block7 = _mm256_extract_epi64(vecBlockIdxB, 3);
bv[block0] |= _mm256_extract_epi64(vecMaskA, 0);
bv[block1] |= _mm256_extract_epi64(vecMaskA, 1);
bv[block2] |= _mm256_extract_epi64(vecMaskA, 2);
bv[block3] |= _mm256_extract_epi64(vecMaskA, 3);
bv[block4] |= _mm256_extract_epi64(vecMaskB, 0);
bv[block5] |= _mm256_extract_epi64(vecMaskB, 1);
bv[block6] |= _mm256_extract_epi64(vecMaskB, 2);
bv[block7] |= _mm256_extract_epi64(vecMaskB, 3);
}
__attribute__ ((target ("avx2,avx,fma")))
uint8_t Query(T **data)
{
__m256i vecDataA = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 0));
__m256i vecDataB = _mm256_loadu_si256(reinterpret_cast<__m256i *>(data + 4));
__m256i vecHashA= CalcHash(vecDataA);
__m256i vecHashB= CalcHash(vecDataB);
__m256i vecMaskA = ConstructMask(vecHashA);
__m256i vecMaskB = ConstructMask(vecHashB);
__m256i vecBlockIdxA = GetBlockIdx(vecHashA);
__m256i vecBlockIdxB = GetBlockIdx(vecHashB);
__m256i vecBloomA = _mm256_i64gather_epi64(bv.data(), vecBlockIdxA, sizeof(longlong));
__m256i vecBloomB = _mm256_i64gather_epi64(bv.data(), vecBlockIdxB, sizeof(longlong));
__m256i vecCmpA = _mm256_cmpeq_epi64(_mm256_and_si256(vecMaskA, vecBloomA), vecMaskA);
__m256i vecCmpB = _mm256_cmpeq_epi64(_mm256_and_si256(vecMaskB, vecBloomB), vecMaskB);
uint32_t res_a = static_cast<uint32_t>(_mm256_movemask_epi8(vecCmpA));
uint32_t res_b = static_cast<uint32_t>(_mm256_movemask_epi8(vecCmpB));
uint64_t res_bytes = res_a | (static_cast<uint64_t>(res_b) << 32);
uint8_t res_bits = static_cast<uint8_t>(_mm256_movemask_epi8(_mm256_set1_epi64x(res_bytes)) & 0xff);
return res_bits;
}
int n;
float epsilon;
uint64_t num_blocks;
uint64_t m;
// calculated from the upstream MaskTable and hard-coded
static constexpr int log_num_masks = 10;
static constexpr int bits_per_mask = 57;
const uint8_t masks[136]= {0x00, 0x04, 0x01, 0x04, 0x00, 0x20, 0x01, 0x00,
0x00, 0x02, 0x08, 0x00, 0x02, 0x42, 0x00, 0x00, 0x04, 0x00, 0x00, 0x84,
0x80, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x21, 0x00, 0x08, 0x00, 0x14,
0x00, 0x00, 0x40, 0x00, 0x10, 0x00, 0xa8, 0x00, 0x00, 0x00, 0x00, 0x10,
0x04, 0x40, 0x01, 0x00, 0x40, 0x00, 0x00, 0x08, 0x01, 0x02, 0x80, 0x00,
0x00, 0x01, 0x00, 0x06, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x0c, 0x10,
0x00, 0x10, 0x00, 0x00, 0x10, 0x08, 0x01, 0x10, 0x00, 0x00, 0x10, 0x20,
0x00, 0x01, 0x20, 0x00, 0x02, 0x40, 0x00, 0x00, 0x02, 0x40, 0x01, 0x00,
0x40, 0x00, 0x00, 0x0a, 0x00, 0x02, 0x01, 0x80, 0x00, 0x00, 0x10, 0x08,
0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x50, 0x00, 0x08, 0x10, 0x20, 0x00,
0x00, 0x80, 0x00, 0x10, 0x10, 0x04, 0x04, 0x00, 0x00, 0x00, 0x20, 0x20,
0x08, 0x08, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00};
std::vector<longlong> bv;
static constexpr int mask_idx_bits = log_num_masks;
static constexpr int rotate_bits = 6;
};
......@@ -559,6 +559,7 @@ enum legacy_db_type
{
/* note these numerical values are fixed and can *not* be changed */
DB_TYPE_UNKNOWN=0,
DB_TYPE_HLINDEX_HELPER=6,
DB_TYPE_HEAP=6,
DB_TYPE_MYISAM=9,
DB_TYPE_MRG_MYISAM=10,
......
......@@ -2401,6 +2401,10 @@ bool open_table(THD *thd, TABLE_LIST *table_list, Open_table_context *ot_ctx)
my_error(ER_NOT_SEQUENCE, MYF(0), table_list->db.str, table_list->alias.str);
DBUG_RETURN(true);
}
/* hlindexes don't support concurrent insert */
if (table->s->total_keys > table->s->keys &&
table_list->lock_type == TL_WRITE_CONCURRENT_INSERT)
table_list->lock_type= TL_WRITE_DEFAULT;
DBUG_ASSERT(thd->locked_tables_mode || table->file->row_logging == 0);
DBUG_RETURN(false);
......
......@@ -55,6 +55,7 @@
#include "opt_trace_context.h"
#include "log_event.h"
#include "optimizer_defaults.h"
#include "vector_mhnsw.h"
#ifdef WITH_PERFSCHEMA_STORAGE_ENGINE
#include "../storage/perfschema/pfs_server.h"
......@@ -7388,3 +7389,7 @@ static Sys_var_uint Sys_mhnsw_max_edges_per_node(
"memory consumption, but better search results. Not used for SELECT",
SESSION_VAR(mhnsw_max_edges_per_node), CMD_LINE(REQUIRED_ARG),
VALID_RANGE(2, 200), DEFAULT(15), BLOCK_SIZE(1));
static Sys_var_ulonglong Sys_mhnsw_cache_size(
"mhnsw_cache_size", "Size of the cache for the MHNSW vector index",
GLOBAL_VAR(mhnsw_cache_size), CMD_LINE(REQUIRED_ARG),
VALID_RANGE(1024*1024, SIZE_T_MAX), DEFAULT(16*1024*1024), BLOCK_SIZE(1));
......@@ -50,6 +50,7 @@
#include "sql_delete.h" // class Sql_cmd_delete
#include "rpl_rli.h" // class rpl_group_info
#include "rpl_mi.h" // class Master_info
#include "vector_mhnsw.h"
#ifdef WITH_WSREP
#include "wsrep_schema.h"
......@@ -502,7 +503,10 @@ void TABLE_SHARE::destroy()
delete sequence;
if (hlindex)
{
mhnsw_free(this);
hlindex->destroy();
}
/* The mutexes are initialized only for shares that are part of the TDC */
if (tmp_table == NO_TMP_TABLE)
......@@ -4786,6 +4790,7 @@ int closefrm(TABLE *table)
if (table->hlindex)
closefrm(table->hlindex);
if (table->db_stat)
error=table->file->ha_close();
table->alias.free();
......
......@@ -742,7 +742,11 @@ struct TABLE_SHARE
Virtual_column_info **check_constraints;
uint *blob_field; /* Index to blobs in Field arrray*/
LEX_CUSTRING vcol_defs; /* definitions of generated columns */
TABLE_SHARE *hlindex;
union {
void *hlindex_data; /* for hlindex tables */
TABLE_SHARE *hlindex; /* for normal tables */
};
/*
EITS statistics data from the last time the table was opened or ANALYZE
......
......@@ -19,6 +19,9 @@
#include "vector_mhnsw.h"
#include "item_vectorfunc.h"
#include <scope.h>
#include "bloom_filters.h"
ulonglong mhnsw_cache_size;
#define clo_nei_size 4
#define clo_nei_store float4store
......@@ -36,9 +39,6 @@ static const uint clo_nei_threshold= 10000;
// SIMD definitions
#define SIMD_word (256/8)
#define SIMD_floats (SIMD_word/sizeof(float))
// how many extra bytes we need to alloc to be able to convert
// sizeof(double) aligned memory to SIMD_word aligned
#define SIMD_margin (SIMD_word - sizeof(double))
enum Graph_table_fields {
FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS
......@@ -48,120 +48,501 @@ enum Graph_table_indices {
};
class MHNSW_Context;
class FVectorNode;
/*
One vector, an array of ctx->vec_len floats
class FVector: public Sql_alloc
Aligned on 32-byte (SIMD_word) boundary for SIMD, vector length
is zero-padded to multiples of 8, for the same reason.
*/
class FVector
{
public:
MHNSW_Context *ctx;
FVector(MHNSW_Context *ctx_, const void *vec_);
FVector(MHNSW_Context *ctx_, MEM_ROOT *root, const void *vec_);
float *vec;
protected:
FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {}
void make_vec(const void *vec_);
FVector() : vec(nullptr) {}
};
/*
An array of pointers to graph nodes
It's mainly used to store all neighbors of a given node on a given layer.
Additionally it stores the distance to the closest neighbor.
An array is fixed size, 2*M for the zero layer, M for other layers
see MHNSW_Context::max_neighbors().
Number of neighbors is zero-padded to multiples of 8 (for SIMD Bloom filter).
Also used as a simple array of nodes in search_layer, the array size
then is defined by ef or efConstruction.
*/
struct Neighborhood: public Sql_alloc
{
  FVectorNode **links;   // the neighbor pointer array itself
  size_t num;            // how many entries of links are in use
  float closest;         // distance to the closest neighbor seen so far

  // forget all neighbors
  void empty()
  {
    num= 0;
    closest= FLT_MAX;
  }

  /*
    Attach this neighborhood to a pre-allocated pointer array.
    The capacity is rounded up to a multiple of 8 (for the SIMD bloom
    filter) and zeroed out; the address just past the consumed chunk is
    returned, to be handed to the next Neighborhood in the same buffer.
  */
  FVectorNode **init(FVectorNode **ptr, size_t n)
  {
    size_t padded= MY_ALIGN(n, 8);
    empty();
    links= ptr;
    bzero(ptr, padded*sizeof(*ptr));
    return ptr + padded;
  }
};
/*
One node in a graph = one row in the graph table
stores a vector itself, ref (= position) in the graph (= hlindex)
table, a ref in the main table, and an array of Neighborhood's, one
per layer.
It's lazily initialized, may know only gref, everything else is
loaded on demand.
On the other hand, on INSERT the new node knows everything except
gref - which only becomes known after ha_write_row.
Allocated on memroot in two chunks. One is the same size for all nodes
and stores FVectorNode object, gref, tref, and vector. The second
stores neighbors, all Neighborhood's together, its size depends
on the number of layers this node is on.
There can be millions of nodes in the cache and the cache size
is constrained by mhnsw_cache_size, so every byte matters here
*/
#pragma pack(push, 1)
class FVectorNode: public FVector
{
private:
uchar *tref, *gref;
size_t max_layer;
static uchar *gref_max;
MHNSW_Context *ctx;
float *make_vec(const void *v);
int alloc_neighborhood(size_t layer);
public:
List<FVectorNode> *neighbors= nullptr;
float *closest_neighbor= 0;
Neighborhood *neighbors= nullptr;
uint8_t max_layer;
bool stored;
FVectorNode(MHNSW_Context *ctx_, const void *gref_);
FVectorNode(MHNSW_Context *ctx_, const void *tref_, size_t layer,
const void *vec_);
float distance_to(const FVector &other) const;
int load();
int load_from_record();
int save();
size_t get_tref_len() const;
uchar *get_tref() const { return tref; }
size_t get_gref_len() const;
uchar *get_gref() const { return gref; }
void update_closest_neighbor(size_t layer, float dist, const FVectorNode &v);
int load(TABLE *graph);
int load_from_record(TABLE *graph);
int save(TABLE *graph);
size_t tref_len() const;
size_t gref_len() const;
uchar *gref() const;
uchar *tref() const;
void push_neighbor(size_t layer, float dist, FVectorNode *v);
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
};
#pragma pack(pop)
// this assumes that 1) rows from graph table are never deleted,
// 2) and thus a ref for a new row is larger than refs of existing rows,
// thus we can treat the not-yet-inserted row as having max possible ref.
// oh, yes, and 3) 8 bytes ought to be enough for everyone
uchar *FVectorNode::gref_max=(uchar*)"\xff\xff\xff\xff\xff\xff\xff\xff";
class MHNSW_Context
/*
Shared algorithm context. The graph.
Stored in TABLE_SHARE and on TABLE_SHARE::mem_root.
Stores the complete graph in MHNSW_Context::root,
The mapping gref->FVectorNode is in the node_cache.
Both root and node_cache are protected by a cache_lock, but it's
needed when loading nodes and is not used when the whole graph is in memory.
Graph can be traversed concurrently by different threads, as traversal
changes neither nodes nor the ctx.
Nodes can be loaded concurrently by different threads, this is protected
by a partitioned node_lock.
reference counter allows flushing the graph without interrupting
concurrent searches.
MyISAM automatically gets exclusive write access because of the TL_WRITE,
but InnoDB has to use a dedicated ctx->commit_lock for that
*/
class MHNSW_Context : public Sql_alloc
{
public:
std::atomic<uint> refcnt;
std::atomic<double> ef_power; // for the bloom filter size heuristic
mysql_mutex_t cache_lock;
mysql_mutex_t node_lock[8];
void cache_internal(FVectorNode *node)
{
DBUG_ASSERT(node->stored);
node_cache.insert(node);
}
void *alloc_node_internal()
{
return alloc_root(&root, sizeof(FVectorNode) + gref_len + tref_len
+ vec_len * sizeof(float) + SIMD_word - 1);
}
protected:
MEM_ROOT root;
TABLE *table;
Field *vec_field;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
public:
mysql_rwlock_t commit_lock;
size_t vec_len= 0;
size_t byte_len= 0;
uint err= 0;
FVectorNode *start= 0;
const uint tref_len;
const uint gref_len;
const uint M;
MHNSW_Context(TABLE *t)
: tref_len(t->file->ref_length),
gref_len(t->hlindex->file->ref_length),
M(t->in_use->variables.mhnsw_max_edges_per_node)
{
mysql_rwlock_init(PSI_INSTRUMENT_ME, &commit_lock);
mysql_mutex_init(PSI_INSTRUMENT_ME, &cache_lock, MY_MUTEX_INIT_FAST);
for (uint i=0; i < array_elements(node_lock); i++)
mysql_mutex_init(PSI_INSTRUMENT_ME, node_lock + i, MY_MUTEX_INIT_SLOW);
init_alloc_root(PSI_INSTRUMENT_MEM, &root, 1024*1024, 0, MYF(0));
set_ef_power(0.6);
refcnt.store(0, std::memory_order_relaxed);
}
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
virtual ~MHNSW_Context()
{
free_root(&root, MYF(0));
mysql_rwlock_destroy(&commit_lock);
mysql_mutex_destroy(&cache_lock);
for (size_t i=0; i < array_elements(node_lock); i++)
mysql_mutex_destroy(node_lock + i);
}
MHNSW_Context(TABLE *table, Field *vec_field)
: table(table), vec_field(vec_field)
uint lock_node(FVectorNode *ptr)
{
init_alloc_root(PSI_INSTRUMENT_MEM, &root, 8192, 0, MYF(MY_THREAD_SPECIFIC));
ulong nr1= 1, nr2= 4;
my_hash_sort_bin(0, (uchar*)&ptr, sizeof(ptr), &nr1, &nr2);
uint ticket= nr1 % array_elements(node_lock);
mysql_mutex_lock(node_lock + ticket);
return ticket;
}
~MHNSW_Context()
void unlock_node(uint ticket)
{
free_root(&root, MYF(0));
mysql_mutex_unlock(node_lock + ticket);
}
double get_ef_power()
{
return ef_power.load(std::memory_order_relaxed);
}
void set_ef_power(double x)
{
if (x > get_ef_power()) // not atomic, but it doesn't matter
ef_power.store(x, std::memory_order_relaxed);
}
uint max_neighbors(size_t layer) const
{
return (layer ? 1 : 2) * M; // heuristic from the paper
}
FVectorNode *get_node(const void *gref);
void set_lengths(size_t len)
{
byte_len= len;
vec_len= MY_ALIGN(byte_len/sizeof(float), SIMD_floats);
}
static int acquire(MHNSW_Context **ctx, TABLE *table, bool for_update);
static MHNSW_Context *get_from_share(TABLE_SHARE *share, TABLE *table);
void reset_ctx(TABLE_SHARE *share)
{
mysql_mutex_lock(&share->LOCK_share);
if (static_cast<MHNSW_Context*>(share->hlindex->hlindex_data) == this)
{
share->hlindex->hlindex_data= nullptr;
--refcnt;
}
mysql_mutex_unlock(&share->LOCK_share);
}
void release(TABLE *table)
{
return release(table->file->has_transactions(), table->s);
}
virtual void release(bool can_commit, TABLE_SHARE *share)
{
if (can_commit)
mysql_rwlock_unlock(&commit_lock);
if (root_size(&root) > mhnsw_cache_size)
reset_ctx(share);
if (--refcnt == 0)
this->~MHNSW_Context(); // XXX reuse
}
FVectorNode *get_node(const void *gref)
{
mysql_mutex_lock(&cache_lock);
FVectorNode *node= node_cache.find(gref, gref_len);
if (!node)
{
node= new (alloc_node_internal()) FVectorNode(this, gref);
cache_internal(node);
}
mysql_mutex_unlock(&cache_lock);
return node;
}
/* used on INSERT, gref isn't known, so cannot cache the node yet */
void *alloc_node()
{
mysql_mutex_lock(&cache_lock);
auto p= alloc_node_internal();
mysql_mutex_unlock(&cache_lock);
return p;
}
/* explicitly cache the node after alloc_node() */
void cache_node(FVectorNode *node)
{
mysql_mutex_lock(&cache_lock);
cache_internal(node);
mysql_mutex_unlock(&cache_lock);
}
/* find the node without creating, only used on merging trx->ctx */
FVectorNode *find_node(const void *gref)
{
mysql_mutex_lock(&cache_lock);
FVectorNode *node= node_cache.find(gref, gref_len);
mysql_mutex_unlock(&cache_lock);
return node;
}
void *alloc_neighborhood(size_t max_layer)
{
mysql_mutex_lock(&cache_lock);
auto p= alloc_root(&root, sizeof(Neighborhood)*(max_layer+1) +
sizeof(FVectorNode*)*(MY_ALIGN(M, 4)*2 + MY_ALIGN(M,8)*max_layer));
mysql_mutex_unlock(&cache_lock);
return p;
}
};
/*
This is a non-shared context that exists within one transaction.
At the end of the transaction it's either discarded (on rollback)
or merged into the shared ctx (on commit).
trx's are stored in thd->ha_data[] in a single-linked list,
one instance of trx per TABLE_SHARE and allocated on the
thd->transaction->mem_root
*/
class MHNSW_Trx : public MHNSW_Context
{
public:
// share of the base table this transaction-local cache belongs to
TABLE_SHARE *table_share;
// set when this cache was flushed mid-transaction (see reset_trx());
// on commit the shared ctx must then be reset completely, because we
// no longer know which nodes this transaction has changed
bool list_of_nodes_is_lost= false;
// single-linked list of per-TABLE_SHARE trx caches in thd->ha_data
MHNSW_Trx *next= nullptr;
MHNSW_Trx(TABLE *table) : MHNSW_Context(table), table_share(table->s) {}
// throw away everything cached in this trx, remember the loss
void reset_trx()
{
node_cache.clear();
free_root(&root, MYF(0));
start= 0;
list_of_nodes_is_lost= true;
}
// a trx cache is not refcounted and not committed here - it lives
// until the end of the transaction; only enforce the size limit
void release(bool, TABLE_SHARE *) override
{
if (root_size(&root) > mhnsw_cache_size)
reset_trx();
}
static MHNSW_Trx *get_from_thd(THD *thd, TABLE *table);
// it's okay in a transaction-local cache, there's no concurrent access
Hash_set<FVectorNode> &get_cache() { return node_cache; }
/* fake handlerton to use thd->ha_data and to get notified of commits */
static struct MHNSW_hton : public handlerton
{
MHNSW_hton()
{
db_type= DB_TYPE_HLINDEX_HELPER;
flags = HTON_NOT_USER_SELECTABLE | HTON_HIDDEN;
savepoint_offset= 0;
// savepoints need no per-savepoint data, only the rollback callback
savepoint_set= [](handlerton *, THD *, void *){ return 0; };
savepoint_rollback_can_release_mdl= [](handlerton *, THD *){ return true; };
savepoint_rollback= do_savepoint_rollback;
commit= do_commit;
rollback= do_rollback;
}
static int do_commit(handlerton *, THD *thd, bool);
static int do_rollback(handlerton *, THD *thd, bool);
static int do_savepoint_rollback(handlerton *, THD *thd, void *);
} hton;
};
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
MHNSW_Trx::MHNSW_hton MHNSW_Trx::hton;
int MHNSW_Trx::MHNSW_hton::do_savepoint_rollback(handlerton *, THD *thd, void *)
{
make_vec(vec_);
for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton));
trx; trx= trx->next)
trx->reset_trx();
return 0;
}
/*
  Transaction rollback: all transaction-local caches are simply
  discarded, the shared cache was never touched.

  Only the destructors are invoked explicitly - the MHNSW_Trx objects
  themselves live on thd->transaction->mem_root and are freed with it.
*/
int MHNSW_Trx::MHNSW_hton::do_rollback(handlerton *, THD *thd, bool)
{
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton));
       trx; trx= trx->next)
    trx->~MHNSW_Trx();
  // use the thd we were given, no need for a current_thd TLS lookup
  thd_set_ha_data(thd, &hton, nullptr);
  return 0;
}
void FVector::make_vec(const void *vec_)
int MHNSW_Trx::MHNSW_hton::do_commit(handlerton *, THD *thd, bool)
{
DBUG_ASSERT(ctx->vec_len);
vec= (float*)alloc_root(&ctx->root,
ctx->vec_len * sizeof(float) + SIMD_margin);
if (int off= ((intptr)vec) % SIMD_word)
vec += (SIMD_word - off) / sizeof(float);
memcpy(vec, vec_, ctx->byte_len);
for (size_t i=ctx->byte_len/sizeof(float); i < ctx->vec_len; i++)
vec[i]=0;
for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton));
trx; trx= trx->next)
{
auto ctx= MHNSW_Context::get_from_share(trx->table_share, nullptr);
if (ctx)
{
mysql_rwlock_wrlock(&ctx->commit_lock);
if (trx->list_of_nodes_is_lost)
ctx->reset_ctx(trx->table_share);
else
{
// consider copying nodes from trx to shared cache when it makes sense
// for ann_benchmarks it does not
// also, consider flushing only changed nodes (a flag in the node)
for (FVectorNode &from : trx->get_cache())
if (FVectorNode *node= ctx->find_node(from.gref()))
node->vec= nullptr;
ctx->start= nullptr;
}
ctx->release(true, trx->table_share);
}
trx->~MHNSW_Trx();
}
thd_set_ha_data(current_thd, &hton, nullptr);
return 0;
}
/*
  Find, or create, the transaction-local cache for this table's share.

  Caches are kept in a single-linked list in thd->ha_data[] keyed by the
  fake hton; one MHNSW_Trx per TABLE_SHARE. When the very first cache of
  the transaction is created, the fake hton is registered with the
  transaction so that do_commit/do_rollback above get invoked.
*/
MHNSW_Trx *MHNSW_Trx::get_from_thd(THD *thd, TABLE *table)
{
  auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton));
  while (trx && trx->table_share != table->s) trx= trx->next;
  if (!trx)
  {
    // allocated on the transaction mem_root: lives until commit/rollback
    trx= new (&thd->transaction->mem_root) MHNSW_Trx(table);
    trx->next= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton));
    thd_set_ha_data(thd, &hton, trx);
    if (!trx->next)  // first MHNSW_Trx in this transaction
    {
      bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN);
      trans_register_ha(thd, all, &hton, 0);
    }
  }
  return trx;
}
/*
  Return the shared per-TABLE_SHARE graph cache, creating it on demand.

  A missing cache is only created when a TABLE is supplied; a nullptr
  table turns this into a pure lookup (may return nullptr). The share
  itself holds one reference, and every caller gets one more; both are
  taken under LOCK_share.
*/
MHNSW_Context *MHNSW_Context::get_from_share(TABLE_SHARE *share, TABLE *table)
{
  TABLE_SHARE *graph_share= share->hlindex;
  mysql_mutex_lock(&share->LOCK_share);
  auto cached= static_cast<MHNSW_Context*>(graph_share->hlindex_data);
  if (!cached && table)
  {
    // placement-new on the hlindex share's mem_root, freed with the share
    cached= new (&graph_share->mem_root) MHNSW_Context(table);
    graph_share->hlindex_data= cached;
    cached->refcnt++;            // the share's own reference
  }
  if (cached)
    cached->refcnt++;            // the caller's reference
  mysql_mutex_unlock(&share->LOCK_share);
  return cached;
}
/*
  Acquire a graph context for the current statement.

  Transactional engines get a transaction-local MHNSW_Trx when updating
  (or when one already exists for this thd); otherwise the shared
  per-TABLE_SHARE cache is used, read-locking commit_lock so concurrent
  commits can't invalidate nodes mid-search.

  If the cache has no entry point yet, prime it from the last row of the
  graph table (last in the IDX_LAYER index — presumably a node on the top
  layer; TODO confirm index ordering).
*/
int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update)
{
  TABLE *graph= table->hlindex;
  THD *thd= table->in_use;
  if (table->file->has_transactions() &&
      (for_update || thd_get_ha_data(thd, &MHNSW_Trx::hton)))
    *ctx= MHNSW_Trx::get_from_thd(thd, table);
  else
  {
    *ctx= MHNSW_Context::get_from_share(table->s, table);
    if (table->file->has_transactions())
      mysql_rwlock_rdlock(&(*ctx)->commit_lock); // released in release()
  }
  if ((*ctx)->start)   // already primed, nothing to read from disk
    return 0;
  if (int err= graph->file->ha_index_init(IDX_LAYER, 1))
    return err;
  int err= graph->file->ha_index_last(graph->record[0]);
  graph->file->ha_index_end();
  if (err)
    return err;   // includes HA_ERR_END_OF_FILE for an empty graph
  graph->file->position(graph->record[0]);
  (*ctx)->set_lengths(graph->field[FIELD_VEC]->value_length());
  (*ctx)->start= (*ctx)->get_node(graph->file->ref);
  return (*ctx)->start->load_from_record(graph);
}
/*
  Copy a raw vector into mem, aligned to a SIMD word and zero-padded up
  to a whole number of SIMD float lanes. mem must have room for the
  alignment slack plus the padding. Returns the aligned destination.
*/
static float *make_vec(void *mem, const void *src, size_t src_len)
{
  float *dst= (float*)MY_ALIGN((intptr)mem, SIMD_word);
  memcpy(dst, src, src_len);
  size_t used= src_len / sizeof(float);
  const size_t padded= MY_ALIGN(used, SIMD_floats);
  while (used < padded)
    dst[used++]= 0.0f;            // zero the tail so SIMD math ignores it
  return dst;
}
/*
  Build a standalone vector (not backed by a graph node), e.g. the search
  target; allocated on the given root with SIMD alignment slack.
*/
FVector::FVector(MHNSW_Context *ctx, MEM_ROOT *root, const void *vec_)
{
  vec= make_vec(alloc_root(root, ctx->vec_len * sizeof(float) + SIMD_word - 1),
                vec_, ctx->byte_len);
}
/*
  Copy a raw vector into this node's inline storage (the bytes right
  after tref) using the file-level make_vec() helper above.
*/
float *FVectorNode::make_vec(const void *v)
{
  return ::make_vec(tref() + tref_len(), v, ctx->byte_len);
}
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_)
: FVector(ctx_), tref(nullptr)
: FVector(), ctx(ctx_), stored(true)
{
gref= (uchar*)memdup_root(&ctx->root, gref_, get_gref_len());
memcpy(gref(), gref_, gref_len());
}
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, size_t layer,
const void *vec_)
: FVector(ctx_, vec_), gref(gref_max)
: FVector(), ctx(ctx_), stored(false)
{
tref= (uchar*)memdup_root(&ctx->root, tref_, get_tref_len());
memset(gref(), 0xff, gref_len()); // important: larger than any real gref
memcpy(tref(), tref_, tref_len());
vec= make_vec(vec_);
alloc_neighborhood(layer);
for (size_t i= 0; i <= layer; i++)
closest_neighbor[i]= FLT_MAX;
}
float FVectorNode::distance_to(const FVector &other) const
{
const_cast<FVectorNode*>(this)->load();
#if __GNUC__ > 7
typedef float v8f __attribute__((vector_size(SIMD_word)));
v8f *p1= (v8f*)vec;
......@@ -180,319 +561,363 @@ float FVectorNode::distance_to(const FVector &other) const
int FVectorNode::alloc_neighborhood(size_t layer)
{
DBUG_ASSERT(!neighbors);
if (neighbors)
return 0;
max_layer= layer;
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
closest_neighbor= (float*)alloc_root(&ctx->root, (layer+1)*sizeof(*closest_neighbor));
memset(closest_neighbor, 0xff, (layer+1)*sizeof(*closest_neighbor)); // NaN
neighbors= (Neighborhood*)ctx->alloc_neighborhood(layer);
auto ptr= (FVectorNode**)(neighbors + (layer+1));
for (size_t i= 0; i <= layer; i++)
ptr= neighbors[i].init(ptr, ctx->max_neighbors(i));
return 0;
}
int FVectorNode::load()
int FVectorNode::load(TABLE *graph)
{
DBUG_ASSERT(gref);
if (tref)
if (likely(vec))
return 0;
TABLE *graph= ctx->table->hlindex;
if ((ctx->err= graph->file->ha_rnd_pos(graph->record[0], gref)))
return ctx->err;
return load_from_record();
DBUG_ASSERT(stored);
// trx: consider loading nodes from shared, when it makes sense
// for ann_benchmarks it does not
if (int err= graph->file->ha_rnd_pos(graph->record[0], gref()))
return err;
return load_from_record(graph);
}
int FVectorNode::load_from_record()
int FVectorNode::load_from_record(TABLE *graph)
{
TABLE *graph= ctx->table->hlindex;
DBUG_ASSERT(ctx->byte_len);
uint ticket= ctx->lock_node(this);
SCOPE_EXIT([this, ticket](){ ctx->unlock_node(ticket); });
if (vec)
return 0;
String buf, *v= graph->field[FIELD_TREF]->val_str(&buf);
if (unlikely(!v || v->length() != get_tref_len()))
return ctx->err= HA_ERR_CRASHED;
tref= (uchar*)memdup_root(&ctx->root, v->ptr(), v->length());
if (unlikely(!v || v->length() != tref_len()))
return my_errno= HA_ERR_CRASHED;
memcpy(tref(), v->ptr(), v->length());
v= graph->field[FIELD_VEC]->val_str(&buf);
if (unlikely(!v))
return ctx->err= HA_ERR_CRASHED;
return my_errno= HA_ERR_CRASHED;
DBUG_ASSERT(ctx->byte_len);
if (v->length() != ctx->byte_len)
return ctx->err= HA_ERR_CRASHED;
make_vec(v->ptr());
return my_errno= HA_ERR_CRASHED;
float *vec_ptr= make_vec(v->ptr());
size_t layer= graph->field[FIELD_LAYER]->val_int();
if (layer > 100) // 10e30 nodes at M=2, more at larger M's
return ctx->err= HA_ERR_CRASHED;
return my_errno= HA_ERR_CRASHED;
if (alloc_neighborhood(layer))
return ctx->err;
if (int err= alloc_neighborhood(layer))
return err;
v= graph->field[FIELD_NEIGHBORS]->val_str(&buf);
if (unlikely(!v))
return ctx->err= HA_ERR_CRASHED;
return my_errno= HA_ERR_CRASHED;
// <N> <closest distance> <gref> <gref> ... <N> <closest distance> ...etc...
uchar *ptr= (uchar*)v->ptr(), *end= ptr + v->length();
for (size_t i=0; i <= max_layer; i++)
{
if (unlikely(ptr >= end))
return ctx->err= HA_ERR_CRASHED;
return my_errno= HA_ERR_CRASHED;
size_t grefs= *ptr++;
if (unlikely(ptr + clo_nei_size + grefs * get_gref_len() > end))
return ctx->err= HA_ERR_CRASHED;
clo_nei_read(closest_neighbor[i], ptr);
for (ptr+= clo_nei_size; grefs--; ptr+= get_gref_len())
neighbors[i].push_back(ctx->get_node(ptr), &ctx->root);
if (unlikely(ptr + clo_nei_size + grefs * gref_len() > end))
return my_errno= HA_ERR_CRASHED;
clo_nei_read(neighbors[i].closest, ptr);
ptr+= clo_nei_size;
neighbors[i].num= grefs;
for (size_t j=0; j < grefs; j++, ptr+= gref_len())
neighbors[i].links[j]= ctx->get_node(ptr);
}
vec= vec_ptr; // must be done at the very end
return 0;
}
void FVectorNode::update_closest_neighbor(size_t layer, float dist,
const FVectorNode &other)
/* note that "closest" relation is asymmetric! */
void FVectorNode::push_neighbor(size_t layer, float dist, FVectorNode *other)
{
if (memcmp(gref, other.get_gref(), get_gref_len()) < 0 &&
closest_neighbor[layer] > dist)
closest_neighbor[layer]= dist;
DBUG_ASSERT(neighbors[layer].num < ctx->max_neighbors(layer));
neighbors[layer].links[neighbors[layer].num++]= other;
if (memcmp(gref(), other->gref(), gref_len()) < 0 &&
neighbors[layer].closest > dist)
neighbors[layer].closest= dist;
}
size_t FVectorNode::get_tref_len() const
{
return ctx->table->file->ref_length;
}
size_t FVectorNode::get_gref_len() const
{
return ctx->table->hlindex->file->ref_length;
}
/* ref lengths come from the shared context — identical for all nodes */
size_t FVectorNode::tref_len() const { return ctx->tref_len; }
size_t FVectorNode::gref_len() const { return ctx->gref_len; }
/* gref and tref live in one chunk allocated right after the node itself */
uchar *FVectorNode::gref() const { return (uchar*)(this+1); }
uchar *FVectorNode::tref() const { return gref() + gref_len(); }
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
*key_len= elem->get_gref_len();
return elem->gref;
*key_len= elem->gref_len();
return elem->gref();
}
FVectorNode *MHNSW_Context::get_node(const void *gref)
/* one visited node during the search. caches the distance to target */
struct Visited : public Sql_alloc
{
FVectorNode *node= node_cache.find(gref, table->hlindex->file->ref_length);
if (!node)
FVectorNode *node;
const float distance_to_target;
Visited(FVectorNode *n, float d) : node(n), distance_to_target(d) {}
static int cmp(void *, const Visited* a, const Visited *b)
{
node= new (&root) FVectorNode(this, gref);
node_cache.insert(node);
return a->distance_to_target < b->distance_to_target ? -1 :
a->distance_to_target > b->distance_to_target ? 1 : 0;
}
return node;
}
};
/*
a factory to create Visited and keep track of already seen nodes
static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNode *b)
note that PatternedSimdBloomFilter works in blocks of 8 elements,
so on insert they're accumulated in nodes[], on search the caller
provides 8 addresses at once. we record 0x0 as "seen" so that
the caller could pad the input with nullptr's
*/
class VisitedSet
{
float a_dist= a->distance_to(*target);
float b_dist= b->distance_to(*target);
MEM_ROOT *root;
const FVector &target;
PatternedSimdBloomFilter<FVectorNode> map;
const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0};
size_t idx= 1; // to record 0 in the filter
public:
uint count= 0;
VisitedSet(MEM_ROOT *root, const FVector &target, uint size) :
root(root), target(target), map(size, 0.01) {}
Visited *create(FVectorNode *node)
{
auto *v= new (root) Visited(node, node->distance_to(target));
insert(node);
count++;
return v;
}
void insert(const FVectorNode *n)
{
nodes[idx++]= n;
if (idx == 8) flush();
}
void flush() {
if (idx) map.Insert(nodes);
idx=0;
}
uint8_t seen(FVectorNode **nodes) { return map.Query(nodes); }
};
if (a_dist < b_dist)
return -1;
if (a_dist > b_dist)
return 1;
return 0;
}
static int select_neighbors(MHNSW_Context *ctx, size_t layer,
FVectorNode &target,
const List<FVectorNode> &candidates_unsafe,
/*
selects best neighbors from the list of candidates plus one extra candidate
one extra candidate is specified separately to avoid appending it to
the Neighborhood candidates, which might be already at its max size.
*/
static int select_neighbors(MHNSW_Context *ctx, TABLE *graph, size_t layer,
FVectorNode &target, const Neighborhood &candidates,
FVectorNode *extra_candidate,
size_t max_neighbor_connections)
{
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
Queue<FVectorNode, const FVector> pq; // working queue
Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates
/*
make a copy of candidates in case it's target.neighbors[layer].
because we're going to modify the latter below
*/
List<FVectorNode> candidates= candidates_unsafe;
List<FVectorNode> &neighbors= target.neighbors[layer];
const bool do_cn= max_neighbor_connections*ctx->vec_len > clo_nei_threshold;
Queue<Visited> pq; // working queue
neighbors.empty();
target.closest_neighbor[layer]= FLT_MAX;
if (pq.init(10000, false, Visited::cmp))
return my_errno= HA_ERR_OUT_OF_MEM;
if (pq.init(10000, 0, cmp_vec, &target) ||
pq_discard.init(10000, 0, cmp_vec, &target))
return ctx->err= HA_ERR_OUT_OF_MEM;
MEM_ROOT * const root= graph->in_use->mem_root;
auto discarded= (Visited**)my_safe_alloca(sizeof(Visited**)*max_neighbor_connections);
size_t discarded_num= 0;
Neighborhood &neighbors= target.neighbors[layer];
const bool do_cn= max_neighbor_connections*ctx->vec_len > clo_nei_threshold;
for (const FVectorNode &candidate : candidates)
for (size_t i=0; i < candidates.num; i++)
{
visited.insert(&candidate);
pq.push(&candidate);
FVectorNode *node= candidates.links[i];
if (int err= node->load(graph))
return err;
pq.push(new (root) Visited(node, node->distance_to(target)));
}
if (extra_candidate)
pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target)));
DBUG_ASSERT(pq.elements());
neighbors.push_back(pq.pop(), &ctx->root);
neighbors.empty();
while (pq.elements() && neighbors.elements < max_neighbor_connections)
while (pq.elements() && neighbors.num < max_neighbor_connections)
{
const FVectorNode *vec= pq.pop();
const float target_dist= vec->distance_to(target);
const float target_dista= target_dist / alpha;
Visited *vec= pq.pop();
FVectorNode * const node= vec->node;
const float target_dista= vec->distance_to_target / alpha;
bool discard= false;
if (do_cn)
discard= vec->closest_neighbor[layer] < target_dista;
discard= node->neighbors[layer].closest < target_dista;
else
{
for (const FVectorNode &neigh : neighbors)
{
if ((discard= vec->distance_to(neigh) < target_dista))
for (size_t i=0; i < neighbors.num; i++)
if ((discard= node->distance_to(*neighbors.links[i]) < target_dista))
break;
}
}
if (!discard)
{
neighbors.push_back(vec, &ctx->root);
target.update_closest_neighbor(layer, target_dist, *vec);
}
else if (pq_discard.elements() + neighbors.elements < max_neighbor_connections)
pq_discard.push(vec);
target.push_neighbor(layer, vec->distance_to_target, node);
else if (discarded_num + neighbors.num < max_neighbor_connections)
discarded[discarded_num++]= vec;
}
while (pq_discard.elements() && neighbors.elements < max_neighbor_connections)
{
const FVectorNode *vec= pq_discard.pop();
neighbors.push_back(vec, &ctx->root);
target.update_closest_neighbor(layer, vec->distance_to(target), *vec);
}
for (size_t i=0; i < discarded_num && neighbors.num < max_neighbor_connections; i++)
target.push_neighbor(layer, discarded[i]->distance_to_target, discarded[i]->node);
my_safe_afree(discarded, sizeof(Visited**)*max_neighbor_connections);
return 0;
}
int FVectorNode::save()
int FVectorNode::save(TABLE *graph)
{
TABLE *graph= ctx->table->hlindex;
DBUG_ASSERT(tref);
DBUG_ASSERT(vec);
DBUG_ASSERT(neighbors);
graph->field[FIELD_LAYER]->store(max_layer, false);
graph->field[FIELD_TREF]->set_notnull();
graph->field[FIELD_TREF]->store_binary(tref, get_tref_len());
graph->field[FIELD_TREF]->store_binary(tref(), tref_len());
graph->field[FIELD_VEC]->store_binary((uchar*)vec, ctx->byte_len);
size_t total_size= 0;
for (size_t i=0; i <= max_layer; i++)
total_size+= 1 + clo_nei_size + get_gref_len() * neighbors[i].elements;
total_size+= 1 + clo_nei_size + gref_len() * neighbors[i].num;
uchar *neighbor_blob= static_cast<uchar *>(my_safe_alloca(total_size));
uchar *ptr= neighbor_blob;
for (size_t i= 0; i <= max_layer; i++)
{
*ptr++= (uchar)(neighbors[i].elements);
clo_nei_store(ptr, closest_neighbor[i]);
*ptr++= (uchar)(neighbors[i].num);
clo_nei_store(ptr, neighbors[i].closest);
ptr+= clo_nei_size;
for (const auto &neigh: neighbors[i])
{
memcpy(ptr, neigh.get_gref(), get_gref_len());
ptr+= neigh.get_gref_len();
}
for (size_t j= 0; j < neighbors[i].num; j++, ptr+= gref_len())
memcpy(ptr, neighbors[i].links[j]->gref(), gref_len());
}
graph->field[FIELD_NEIGHBORS]->store_binary(neighbor_blob, total_size);
if (gref != gref_max)
int err;
if (stored)
{
ctx->err= graph->file->ha_rnd_pos(graph->record[1], gref);
if (!ctx->err)
if (!(err= graph->file->ha_rnd_pos(graph->record[1], gref())))
{
ctx->err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
if (ctx->err == HA_ERR_RECORD_IS_THE_SAME)
ctx->err= 0;
err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
if (err == HA_ERR_RECORD_IS_THE_SAME)
err= 0;
}
}
else
{
ctx->err= graph->file->ha_write_row(graph->record[0]);
err= graph->file->ha_write_row(graph->record[0]);
graph->file->position(graph->record[0]);
gref= (uchar*)memdup_root(&ctx->root, graph->file->ref, get_gref_len());
memcpy(gref(), graph->file->ref, gref_len());
stored= true;
ctx->cache_node(this);
}
my_safe_afree(neighbor_blob, total_size);
return ctx->err;
return err;
}
static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors,
const FVectorNode &node)
static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
size_t layer, FVectorNode *node)
{
for (FVectorNode &neigh: node.neighbors[layer])
const uint max_neighbors= ctx->max_neighbors(layer);
// it seems that one could update nodes in the gref order
// to avoid InnoDB deadlocks, but it produces no noticeable effect
for (size_t i=0; i < node->neighbors[layer].num; i++)
{
List<FVectorNode> &neighneighbors= neigh.neighbors[layer];
neighneighbors.push_back(&node, &ctx->root);
neigh.update_closest_neighbor(layer, neigh.distance_to(node), node);
if (neighneighbors.elements > max_neighbors)
{
if (select_neighbors(ctx, layer, neigh, neighneighbors, max_neighbors))
return ctx->err;
}
if (neigh.save())
return ctx->err;
FVectorNode *neigh= node->neighbors[layer].links[i];
Neighborhood &neighneighbors= neigh->neighbors[layer];
if (neighneighbors.num < max_neighbors)
neigh->push_neighbor(layer, neigh->distance_to(*node), node);
else
if (int err= select_neighbors(ctx, graph, layer, *neigh, neighneighbors,
node, max_neighbors))
return err;
if (int err= neigh->save(graph))
return err;
}
return 0;
}
static int search_layer(MHNSW_Context *ctx, const FVector &target,
const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer,
List<FVectorNode> *result)
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector &target,
Neighborhood *start_nodes, uint ef, size_t layer,
Neighborhood *result)
{
DBUG_ASSERT(start_nodes.elements > 0);
DBUG_ASSERT(result->elements == 0);
DBUG_ASSERT(start_nodes->num > 0);
result->empty();
MEM_ROOT * const root= graph->in_use->mem_root;
Queue<FVectorNode, const FVector> candidates;
Queue<FVectorNode, const FVector> best;
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
Queue<Visited> candidates;
Queue<Visited> best;
candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target);
// WARNING! heuristic here
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
const uint est_size= est_heuristic * std::pow(ef, ctx->get_ef_power());
VisitedSet visited(root, target, est_size);
for (const FVectorNode &node : start_nodes)
candidates.init(10000, false, Visited::cmp);
best.init(ef, true, Visited::cmp);
for (size_t i=0; i < start_nodes->num; i++)
{
candidates.push(&node);
if (best.elements() < max_candidates_return)
best.push(&node);
else if (node.distance_to(target) > best.top()->distance_to(target))
best.replace_top(&node);
visited.insert(&node);
Visited *v= visited.create(start_nodes->links[i]);
candidates.push(v);
if (best.elements() < ef)
best.push(v);
else if (v->distance_to_target < best.top()->distance_to_target)
best.replace_top(v);
}
float furthest_best= best.top()->distance_to(target);
float furthest_best= best.top()->distance_to_target;
while (candidates.elements())
{
const FVectorNode &cur_vec= *candidates.pop();
float cur_distance= cur_vec.distance_to(target);
if (cur_distance > furthest_best && best.elements() == max_candidates_return)
{
break; // All possible candidates are worse than what we have.
// Can't get better.
}
const Visited &cur= *candidates.pop();
if (cur.distance_to_target > furthest_best && best.elements() == ef)
break; // All possible candidates are worse than what we have
visited.flush();
for (const FVectorNode &neigh: cur_vec.neighbors[layer])
Neighborhood &neighbors= cur.node->neighbors[layer];
FVectorNode **links= neighbors.links, **end= links + neighbors.num;
for (; links < end; links+= 8)
{
if (visited.find(&neigh))
uint8_t res= visited.seen(links);
if (res == 0xff)
continue;
visited.insert(&neigh);
if (best.elements() < max_candidates_return)
for (size_t i= 0; i < 8; i++)
{
if (res & (1 << i))
continue;
if (int err= links[i]->load(graph))
return err;
Visited *v= visited.create(links[i]);
if (best.elements() < ef)
{
candidates.push(&neigh);
best.push(&neigh);
furthest_best= best.top()->distance_to(target);
candidates.push(v);
best.push(v);
furthest_best= best.top()->distance_to_target;
}
else if (neigh.distance_to(target) < furthest_best)
else if (v->distance_to_target < furthest_best)
{
best.replace_top(&neigh);
candidates.push(&neigh);
furthest_best= best.top()->distance_to(target);
best.replace_top(v);
candidates.push(v);
furthest_best= best.top()->distance_to_target;
}
}
}
}
if (ef > 1 && visited.count*2 > est_size)
ctx->set_ef_power(std::log(visited.count*2/est_heuristic) / std::log(ef));
while (best.elements())
result->push_front(best.pop(), &ctx->root);
result->num= best.elements();
for (FVectorNode **links= result->links + result->num; best.elements();)
*--links= best.pop()->node;
return 0;
}
......@@ -503,7 +928,7 @@ static int bad_value_on_insert(Field *f)
my_error(ER_TRUNCATED_WRONG_VALUE_FOR_FIELD, MYF(0), "vector", "...",
f->table->s->db.str, f->table->s->table_name.str, f->field_name.str,
f->table->in_use->get_stmt_da()->current_row_for_warning());
return HA_ERR_GENERIC;
return my_errno= HA_ERR_GENERIC;
}
......@@ -514,7 +939,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
Field *vec_field= keyinfo->key_part->field;
String buf, *res= vec_field->val_str(&buf);
MHNSW_Context ctx(table, vec_field);
MHNSW_Context *ctx;
/* metadata are checked on open */
DBUG_ASSERT(graph);
......@@ -532,93 +957,78 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if (res->length() == 0 || res->length() % 4)
return bad_value_on_insert(vec_field);
const double NORMALIZATION_FACTOR= 1 / std::log(thd->variables.mhnsw_max_edges_per_node);
table->file->position(table->record[0]);
if (int err= graph->file->ha_index_init(IDX_LAYER, 1))
return err;
ctx.err= graph->file->ha_index_last(graph->record[0]);
graph->file->ha_index_end();
if (ctx.err)
int err= MHNSW_Context::acquire(&ctx, table, true);
SCOPE_EXIT([ctx, table](){ ctx->release(table); });
if (err)
{
if (ctx.err != HA_ERR_END_OF_FILE)
return ctx.err;
ctx.err= 0;
if (err != HA_ERR_END_OF_FILE)
return err;
// First insert!
ctx.set_lengths(res->length());
FVectorNode target(&ctx, table->file->ref, 0, res->ptr());
return target.save();
ctx->set_lengths(res->length());
FVectorNode *target= new (ctx->alloc_node())
FVectorNode(ctx, table->file->ref, 0, res->ptr());
if (!((err= target->save(graph))))
ctx->start= target;
return err;
}
longlong max_layer= graph->field[FIELD_LAYER]->val_int();
List<FVectorNode> candidates;
List<FVectorNode> start_nodes;
graph->file->position(graph->record[0]);
FVectorNode *start_node= ctx.get_node(graph->file->ref);
if (ctx->byte_len != res->length())
return bad_value_on_insert(vec_field);
if (start_nodes.push_back(start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM;
size_t ef= ctx->max_neighbors(0) * ef_construction_multiplier;
Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef);
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef);
start_nodes.links[start_nodes.num++]= ctx->start;
ctx.set_lengths(graph->field[FIELD_VEC]->value_length());
if (int err= start_node->load_from_record())
return err;
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
double log= -std::log(my_rnd(&thd->rand)) * NORMALIZATION_FACTOR;
const longlong max_layer= start_nodes.links[0]->max_layer;
longlong target_layer= std::min<longlong>(std::floor(log), max_layer + 1);
longlong cur_layer;
if (ctx.byte_len != res->length())
return bad_value_on_insert(vec_field);
FVectorNode *target= new (ctx->alloc_node())
FVectorNode(ctx, table->file->ref, target_layer, res->ptr());
if (int err= graph->file->ha_rnd_init(0))
return err;
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= std::min<longlong>(std::floor(log), max_layer + 1);
longlong cur_layer;
FVectorNode target(&ctx, table->file->ref, new_node_layer, res->ptr());
for (cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
for (cur_layer= max_layer; cur_layer > target_layer; cur_layer--)
{
if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates))
return ctx.err;
start_nodes= candidates;
candidates.empty();
if (int err= search_layer(ctx, graph, *target, &start_nodes, 1, cur_layer,
&candidates))
return err;
std::swap(start_nodes, candidates);
}
for (; cur_layer >= 0; cur_layer--)
{
uint max_neighbors= (cur_layer == 0) // heuristics from the paper
? thd->variables.mhnsw_max_edges_per_node * 2
: thd->variables.mhnsw_max_edges_per_node;
if (search_layer(&ctx, target, start_nodes,
static_cast<uint>(ef_construction_multiplier * max_neighbors),
uint max_neighbors= ctx->max_neighbors(cur_layer);
if (int err= search_layer(ctx, graph, *target, &start_nodes,
ef_construction_multiplier * max_neighbors,
cur_layer, &candidates))
return ctx.err;
return err;
if (select_neighbors(&ctx, cur_layer, target, candidates, max_neighbors))
return ctx.err;
start_nodes= candidates;
candidates.empty();
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
0, max_neighbors))
return err;
std::swap(start_nodes, candidates);
}
if (target.save())
return ctx.err;
if (int err= target->save(graph))
return err;
if (target_layer > max_layer)
ctx->start= target;
for (longlong cur_layer= new_node_layer; cur_layer >= 0; cur_layer--)
for (cur_layer= target_layer; cur_layer >= 0; cur_layer--)
{
uint max_neighbors= (cur_layer == 0) // heuristics from the paper
? thd->variables.mhnsw_max_edges_per_node * 2
: thd->variables.mhnsw_max_edges_per_node;
// XXX do only one ha_update_row() per node
if (update_second_degree_neighbors(&ctx, cur_layer, max_neighbors, target))
return ctx.err;
if (int err= update_second_degree_neighbors(ctx, graph, cur_layer, target))
return err;
}
dbug_tmp_restore_column_map(&table->read_set, old_map);
......@@ -631,81 +1041,69 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
{
THD *thd= table->in_use;
TABLE *graph= table->hlindex;
Field *vec_field= keyinfo->key_part->field;
Item_func_vec_distance *fun= (Item_func_vec_distance *)dist;
String buf, *res= fun->get_const_arg()->val_str(&buf);
handler *h= table->file;
MHNSW_Context ctx(table, vec_field);
MHNSW_Context *ctx;
if (int err= h->ha_rnd_init(0))
if (int err= table->file->ha_rnd_init(0))
return err;
if (int err= graph->file->ha_index_init(0, 1))
if (int err= MHNSW_Context::acquire(&ctx, table, false))
return err;
ctx.err= graph->file->ha_index_last(graph->record[0]);
graph->file->ha_index_end();
SCOPE_EXIT([ctx, table](){ ctx->release(table); });
if (ctx.err)
return ctx.err;
longlong max_layer= graph->field[FIELD_LAYER]->val_int();
// this auto-scales ef with the limit, providing more adequate
// behavior than a fixed ef
size_t ef= limit * thd->variables.mhnsw_limit_multiplier;
List<FVectorNode> candidates;
List<FVectorNode> start_nodes;
graph->file->position(graph->record[0]);
FVectorNode *start_node= ctx.get_node(graph->file->ref);
Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef);
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef);
// one could put all max_layer nodes in start_nodes
// but it has no effect of the recall or speed
if (start_nodes.push_back(start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM;
ctx.set_lengths(graph->field[FIELD_VEC]->value_length());
if (int err= start_node->load_from_record())
return err;
start_nodes.links[start_nodes.num++]= ctx->start;
/*
if the query vector is NULL or invalid, VEC_DISTANCE will return
NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node.
*/
if (!res || ctx.byte_len != res->length())
(res= &buf)->set((char*)start_node->vec, ctx.byte_len, &my_charset_bin);
if (!res || ctx->byte_len != res->length())
(res= &buf)->set((char*)start_nodes.links[0]->vec, ctx->byte_len, &my_charset_bin);
const longlong max_layer= start_nodes.links[0]->max_layer;
FVector target(ctx, thd->mem_root, res->ptr());
if (int err= graph->file->ha_rnd_init(0))
return err;
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
FVector target(&ctx, res->ptr());
// this auto-scales ef with the limit, providing more adequate
// behavior than a fixed ef_search
uint ef_search= static_cast<uint>(limit * thd->variables.mhnsw_limit_multiplier);
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
{
if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates))
return ctx.err;
start_nodes= candidates;
candidates.empty();
if (int err= search_layer(ctx, graph, target, &start_nodes, 1, cur_layer,
&candidates))
return err;
std::swap(start_nodes, candidates);
}
if (search_layer(&ctx, target, start_nodes, ef_search, 0, &candidates))
return ctx.err;
if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0,
&candidates))
return err;
size_t context_size=limit * h->ref_length + sizeof(ulonglong);
if (limit > candidates.num)
limit= candidates.num;
size_t context_size=limit * ctx->tref_len + sizeof(ulonglong);
char *context= thd->alloc(context_size);
graph->context= context;
*(ulonglong*)context= limit;
context+= context_size;
while (limit--)
for (size_t i=0; limit--; i++)
{
context-= h->ref_length;
memcpy(context, candidates.pop()->get_tref(), h->ref_length);
context-= ctx->tref_len;
memcpy(context, candidates.links[i]->tref(), ctx->tref_len);
}
DBUG_ASSERT(context - sizeof(ulonglong) == graph->context);
......@@ -720,7 +1118,17 @@ int mhnsw_next(TABLE *table)
ref+= sizeof(ulonglong) + (--*limit) * table->file->ref_length;
return table->file->ha_rnd_pos(table->record[0], ref);
}
return HA_ERR_END_OF_FILE;
return my_errno= HA_ERR_END_OF_FILE;
}
/*
  Destroy the shared graph cache when its TABLE_SHARE goes away.

  The context was placement-new'ed on the hlindex share's mem_root
  (see get_from_share()), so only the destructor is run here — the
  memory itself is released together with the mem_root.
*/
void mhnsw_free(TABLE_SHARE *share)
{
  TABLE_SHARE *graph_share= share->hlindex;
  if (!graph_share->hlindex_data)
    return;
  static_cast<MHNSW_Context*>(graph_share->hlindex_data)->~MHNSW_Context();
  graph_share->hlindex_data= nullptr;  // nullptr, matching file convention
}
const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
......
......@@ -25,3 +25,6 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length);
int mhnsw_insert(TABLE *table, KEY *keyinfo);
int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit);
int mhnsw_next(TABLE *table);
void mhnsw_free(TABLE_SHARE *share);
extern ulonglong mhnsw_cache_size;
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment