vector_mhnsw.cc 44.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/*
   Copyright (c) 2024, MariaDB plc

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; version 2 of the License.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1335  USA
*/

#include <my_global.h>
19
#include "key.h"                                // key_copy()
20
#include "create_options.h"
21
#include "vector_mhnsw.h"
22
#include "item_vectorfunc.h"
Sergei Golubchik's avatar
Sergei Golubchik committed
23
#include <scope.h>
24 25 26
#include <my_atomic_wrapper.h>
#include "bloom_filters.h"

Sergei Golubchik's avatar
Sergei Golubchik committed
27 28 29 30
#include <random>
#include <eigen3/Eigen/Dense>
using namespace Eigen;

31 32
// Algorithm parameters
static constexpr float alpha = 1.1f;
static constexpr float generosity = 1.1f;
static constexpr uint ef_construction= 10;

/*
  Subdistance heuristic: the first subdist_part coordinates of a vector
  are used to estimate the full distance, and if that estimate exceeds
  the threshold by more than subdist_margin the rest of the dot product
  is skipped. See FVector::distance_greater_than().
*/
static constexpr size_t subdist_part= 192;
static constexpr float subdist_margin= 1.1f;

/*
  Decide whether the subdistance heuristic should be used for an index.

  @param M        M option of the index (max edges per node)
  @param vec_len  number of dimensions of the vectors
  @param rows     expected number of rows; values below 100000 are
                  treated as 100000

  @return false if vec_len < 2*subdist_part, otherwise whether
          M >= 8e5 / log(rows)^2 / (vec_len - subdist_part)
*/
static inline bool use_subdist_heuristic(uint M, size_t vec_len, ha_rows rows)
{
  if (vec_len < subdist_part * 2)
    return false;
  // the row estimate is clamped from below, for safety
  double logrows= rows < 100000 ? std::log(100000) : std::log(rows);
  return M >= 8e5/logrows/logrows/(vec_len - subdist_part);
}
45

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
/* global limit for the memory used by a shared graph cache */
static ulonglong mhnsw_cache_size;
static MYSQL_SYSVAR_ULONGLONG(cache_size, mhnsw_cache_size, PLUGIN_VAR_RQCMDARG,
       "Size of the cache for the MHNSW vector index",
       nullptr, nullptr,
       16*1024*1024,   /* default */
       1024*1024,      /* minimum */
       SIZE_T_MAX,     /* maximum */
       1);             /* block size */

/* per-session lower bound for the number of candidates in a search */
static MYSQL_THDVAR_UINT(min_limit, PLUGIN_VAR_RQCMDARG,
       "Defines the minimal number of result candidates to look for in the "
       "vector index for ORDER BY ... LIMIT N queries. The search will never "
       "search for less rows than that, even if LIMIT is smaller. "
       "This notably improves the search quality at low LIMIT values, "
       "at the expense of search time",
       nullptr, nullptr,
       20,             /* default */
       1,              /* minimum */
       65535,          /* maximum */
       1);             /* block size */

/* per-session M used for newly created vector indexes */
static MYSQL_THDVAR_UINT(max_edges_per_node, PLUGIN_VAR_RQCMDARG,
       "Larger values means slower INSERT, larger index size and higher "
       "memory consumption, but better search results",
       nullptr, nullptr,
       6,              /* default */
       3,              /* minimum */
       200,            /* maximum */
       1);             /* block size */

Sergei Golubchik's avatar
Sergei Golubchik committed
61 62 63 64 65 66 67
/* distance metric of a vector index; order matches distance_function_names[] */
enum metric_type : uint { EUCLIDEAN, COSINE };
static const char *distance_function_names[]= { "euclidean", "cosine", nullptr };
static TYPELIB distance_functions= CREATE_TYPELIB_FOR(distance_function_names);

/* per-session default metric for new vector indexes */
static MYSQL_THDVAR_ENUM(distance_function, PLUGIN_VAR_RQCMDARG,
       "Distance function to build the vector index for",
       nullptr, nullptr,
       EUCLIDEAN,              /* default */
       &distance_functions);   /* allowed values */

68 69 70
struct ha_index_option_struct
{
  uint M;
Sergei Golubchik's avatar
Sergei Golubchik committed
71
  metric_type metric;
72 73
};

74 75 76 77
/* column numbers in the graph (hlindex) table */
enum Graph_table_fields {
  FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS
};
/* index numbers in the graph (hlindex) table */
enum Graph_table_indices {
  IDX_TREF, IDX_LAYER
};

class MHNSW_Context;
class FVectorNode;

/*
85
  One vector, an array of coordinates in ctx->vec_len dimensions
86
*/
87 88
#pragma pack(push, 1)
struct FVector
89
{
90
  static constexpr size_t data_header= sizeof(float);
Sergei Golubchik's avatar
Sergei Golubchik committed
91
  static constexpr size_t alloc_header= data_header + sizeof(float)*2;
92

Sergei Golubchik's avatar
Sergei Golubchik committed
93
  float abs2, subabs2, scale;
94 95 96 97 98 99 100 101 102 103
  int16_t dims[4];

  uchar *data() const { return (uchar*)(&scale); }

  static size_t data_size(size_t n)
  { return data_header + n*2; }

  static size_t data_to_value_size(size_t data_size)
  { return (data_size - data_header)*2; }

Sergei Golubchik's avatar
Sergei Golubchik committed
104
  static const FVector *create(const MHNSW_Context *ctx, void *mem, const void *src);
105

Sergei Golubchik's avatar
Sergei Golubchik committed
106
  void postprocess(bool use_subdist, size_t vec_len)
107
  {
Sergei Golubchik's avatar
Sergei Golubchik committed
108
    int16_t *d= dims;
109
    fix_tail(vec_len);
Sergei Golubchik's avatar
Sergei Golubchik committed
110 111 112 113 114 115 116 117 118
    if (use_subdist)
    {
      subabs2= scale * scale * dot_product(d, d, subdist_part) / 2;
      d+= subdist_part;
      vec_len-= subdist_part;
    }
    else
      subabs2= 0;
    abs2= subabs2 + scale * scale * dot_product(d, d, vec_len) / 2;
119 120 121 122 123 124
  }

#ifdef INTEL_SIMD_IMPLEMENTATION
  /************* AVX2 *****************************************************/
  static constexpr size_t AVX2_bytes= 256/8;
  static constexpr size_t AVX2_dims= AVX2_bytes/sizeof(int16_t);
Sergei Golubchik's avatar
Sergei Golubchik committed
125
  static_assert(subdist_part % AVX2_dims == 0);
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156

  INTEL_SIMD_IMPLEMENTATION
  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
  {
    typedef float v8f __attribute__((vector_size(AVX2_bytes)));
    union { v8f v; __m256 i; } tmp;
    __m256i *p1= (__m256i*)v1;
    __m256i *p2= (__m256i*)v2;
    v8f d= {0};
    for (size_t i= 0; i < (len + AVX2_dims-1)/AVX2_dims; p1++, p2++, i++)
    {
      tmp.i= _mm256_cvtepi32_ps(_mm256_madd_epi16(*p1, *p2));
      d+= tmp.v;
    }
    return d[0] + d[1] + d[2] + d[3] + d[4] + d[5] + d[6] + d[7];
  }

  INTEL_SIMD_IMPLEMENTATION
  static size_t alloc_size(size_t n)
  { return alloc_header + MY_ALIGN(n*2, AVX2_bytes) + AVX2_bytes - 1; }

  INTEL_SIMD_IMPLEMENTATION
  static FVector *align_ptr(void *ptr)
  { return (FVector*)(MY_ALIGN(((intptr)ptr) + alloc_header, AVX2_bytes)
                      - alloc_header); }

  INTEL_SIMD_IMPLEMENTATION
  void fix_tail(size_t vec_len)
  {
    bzero(dims + vec_len, (MY_ALIGN(vec_len, AVX2_dims) - vec_len)*2);
  }
Sergei Golubchik's avatar
Sergei Golubchik committed
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187

  /************* AVX512 ****************************************************/
  static constexpr size_t AVX512_bytes= 512/8;
  static constexpr size_t AVX512_dims= AVX512_bytes/sizeof(int16_t);
  static_assert(subdist_part % AVX512_dims == 0);

  __attribute__ ((target ("avx512f,avx512bw")))
  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
  {
    __m512i *p1= (__m512i*)v1;
    __m512i *p2= (__m512i*)v2;
    __m512 d= _mm512_setzero_ps();
    for (size_t i= 0; i < (len + AVX512_dims-1)/AVX512_dims; p1++, p2++, i++)
      d= _mm512_add_ps(d, _mm512_cvtepi32_ps(_mm512_madd_epi16(*p1, *p2)));
    return _mm512_reduce_add_ps(d);
  }

  __attribute__ ((target ("avx512f,avx512bw")))
  static size_t alloc_size(size_t n)
  { return alloc_header + MY_ALIGN(n*2, AVX512_bytes) + AVX512_bytes - 1; }

  __attribute__ ((target ("avx512f,avx512bw")))
  static FVector *align_ptr(void *ptr)
  { return (FVector*)(MY_ALIGN(((intptr)ptr) + alloc_header, AVX512_bytes)
                      - alloc_header); }

  __attribute__ ((target ("avx512f,avx512bw")))
  void fix_tail(size_t vec_len)
  {
    bzero(dims + vec_len, (MY_ALIGN(vec_len, AVX512_dims) - vec_len)*2);
  }
188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
#endif

  /************* no-SIMD default ******************************************/
  DEFAULT_IMPLEMENTATION
  static float dot_product(const int16_t *v1, const int16_t *v2, size_t len)
  {
    int64_t d= 0;
    for (size_t i= 0; i < len; i++)
      d+= int32_t(v1[i]) * int32_t(v2[i]);
    return static_cast<float>(d);
  }

  DEFAULT_IMPLEMENTATION
  static size_t alloc_size(size_t n) { return alloc_header + n*2; }

  DEFAULT_IMPLEMENTATION
  static FVector *align_ptr(void *ptr) { return (FVector*)ptr; }

  DEFAULT_IMPLEMENTATION
  void fix_tail(size_t) {  }

Sergei Golubchik's avatar
Sergei Golubchik committed
209 210 211 212 213 214 215 216 217 218 219 220
  float distance_greater_than(const FVector *other, size_t vec_len, float than) const
  {
    float k = scale * other->scale;
    float dp= dot_product(dims, other->dims, subdist_part);
    float subdist= (subabs2 + other->subabs2 - k * dp)/subdist_part*vec_len;
    if (subdist > than*subdist_margin)
      return subdist;
    dp+= dot_product(dims+subdist_part, other->dims+subdist_part,
                     vec_len - subdist_part);
    return abs2 + other->abs2 - k * dp;
  }

221 222 223 224 225
  float distance_to(const FVector *other, size_t vec_len) const
  {
    return abs2 + other->abs2 - scale * other->scale *
           dot_product(dims, other->dims, vec_len);
  }
226
};
227
#pragma pack(pop)
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253

/*
  An array of pointers to graph nodes

  It's mainly used to store all neighbors of a given node on a given layer.

  An array is fixed size, 2*M for the zero layer, M for other layers
  see MHNSW_Context::max_neighbors().

  The array capacity is rounded up to a multiple of 8 and zero-filled by
  init() (for SIMD Bloom filter).

  Also used as a simple array of nodes in search_layer, the array size
  then is defined by ef or efConstruction.
*/
struct Neighborhood: public Sql_alloc
{
  FVectorNode **links;   // the array
  size_t num;            // number of used elements in links[]

  /*
    Attach to the array at ptr with capacity n (rounded up to 8),
    zero it and mark it empty.

    @return pointer right past the array, where the next array may start
  */
  FVectorNode **init(FVectorNode **ptr, size_t n)
  {
    num= 0;
    links= ptr;
    n= MY_ALIGN(n, 8);
    bzero(ptr, n*sizeof(*ptr));
    return ptr + n;
  }
};

256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278

/*
  One node in a graph = one row in the graph table

  stores a vector itself, ref (= position) in the graph (= hlindex)
  table, a ref in the main table, and an array of Neighborhood's, one
  per layer.

  It's lazily initialized, may know only gref, everything else is
  loaded on demand.

  On the other hand, on INSERT the new node knows everything except
  gref - which only becomes known after ha_write_row.

  Allocated on memroot in two chunks. One is the same size for all nodes
  and stores FVectorNode object, gref, tref, and vector. The second
  stores neighbors, all Neighborhood's together, its size depends
  on the number of layers this node is on.

  There can be millions of nodes in the cache and the cache size
  is constrained by mhnsw_cache_size, so every byte matters here
*/
#pragma pack(push, 1)
279
class FVectorNode
280 281
{
private:
282
  MHNSW_Context *ctx;
283

284
  const FVector *make_vec(const void *v);
285
  int alloc_neighborhood(uint8_t layer);
286
public:
287
  const FVector *vec= nullptr;
288 289
  Neighborhood *neighbors= nullptr;
  uint8_t max_layer;
290
  bool stored:1, deleted:1;
291 292 293 294

  FVectorNode(MHNSW_Context *ctx_, const void *gref_);
  FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
              const void *vec_);
295
  float distance_to(const FVector *other) const;
Sergei Golubchik's avatar
Sergei Golubchik committed
296
  float distance_greater_than(const FVector *other, float than) const;
297 298 299 300 301 302 303 304
  int load(TABLE *graph);
  int load_from_record(TABLE *graph);
  int save(TABLE *graph);
  size_t tref_len() const;
  size_t gref_len() const;
  uchar *gref() const;
  uchar *tref() const;
  void push_neighbor(size_t layer, FVectorNode *v);
305 306

  static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
307
};
308
#pragma pack(pop)
309

310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
/*
  Shared algorithm context. The graph.

  Stored in TABLE_SHARE and on TABLE_SHARE::mem_root.
  Stores the complete graph in MHNSW_Context::root,
  The mapping gref->FVectorNode is in the node_cache.
  Both root and node_cache are protected by a cache_lock, but it's
  needed when loading nodes and is not used when the whole graph is in memory.
  Graph can be traversed concurrently by different threads, as traversal
  changes neither nodes nor the ctx.
  Nodes can be loaded concurrently by different threads, this is protected
  by a partitioned node_lock.
  reference counter allows flushing the graph without interrupting
  concurrent searches.
  MyISAM automatically gets exclusive write access because of the TL_WRITE,
  but InnoDB has to use a dedicated ctx->commit_lock for that
*/
class MHNSW_Context : public Sql_alloc
328
{
329 330 331 332 333 334 335 336 337 338 339 340
  std::atomic<uint> refcnt{0};
  mysql_mutex_t cache_lock;
  mysql_mutex_t node_lock[8];

  void cache_internal(FVectorNode *node)
  {
    DBUG_ASSERT(node->stored);
    node_cache.insert(node);
  }
  void *alloc_node_internal()
  {
    return alloc_root(&root, sizeof(FVectorNode) + gref_len + tref_len
341
                      + FVector::alloc_size(vec_len));
342 343
  }

Sergei Golubchik's avatar
Sergei Golubchik committed
344 345 346 347 348 349 350 351 352 353 354 355 356
  /*
    Despite the name, the matrix isn't random, it's deterministic, because
    the random value generator is seeded with Q.rows().
  */
  static void generate_random_orthogonal_matrix(Map<MatrixXf> &Q)
  {
    std::mt19937 rnd((unsigned int)Q.rows());
    std::normal_distribution<float> gauss(0, 1);
    MatrixXf A(MatrixXf::NullaryExpr(Q.rows(), Q.rows(), [&](){ return gauss(rnd); }));
    HouseholderQR<MatrixXf> qr(A);
    Q = qr.householderQ();
  }

357
protected:
358
  MEM_ROOT root;
359 360 361 362
  Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};

public:
  mysql_rwlock_t commit_lock;
363
  size_t vec_len= 0;
364
  size_t byte_len= 0;
365 366
  Atomic_relaxed<double> ef_power{0.6}; // for the bloom filter size heuristic
  FVectorNode *start= 0;
Sergei Golubchik's avatar
Sergei Golubchik committed
367
  Map<MatrixXf> randomizer;
368 369 370
  const uint tref_len;
  const uint gref_len;
  const uint M;
Sergei Golubchik's avatar
Sergei Golubchik committed
371
  metric_type metric;
Sergei Golubchik's avatar
Sergei Golubchik committed
372
  bool use_subdist;
373 374

  MHNSW_Context(TABLE *t)
Sergei Golubchik's avatar
Sergei Golubchik committed
375 376
    : randomizer(nullptr, 1, 1),
      tref_len(t->file->ref_length),
377
      gref_len(t->hlindex->file->ref_length),
Sergei Golubchik's avatar
Sergei Golubchik committed
378 379
      M(t->s->key_info[t->s->keys].option_struct->M),
      metric(t->s->key_info[t->s->keys].option_struct->metric)
380 381 382 383 384 385 386
  {
    mysql_rwlock_init(PSI_INSTRUMENT_ME, &commit_lock);
    mysql_mutex_init(PSI_INSTRUMENT_ME, &cache_lock, MY_MUTEX_INIT_FAST);
    for (uint i=0; i < array_elements(node_lock); i++)
      mysql_mutex_init(PSI_INSTRUMENT_ME, node_lock + i, MY_MUTEX_INIT_SLOW);
    init_alloc_root(PSI_INSTRUMENT_MEM, &root, 1024*1024, 0, MYF(0));
  }
387

388 389 390 391 392 393 394 395
  virtual ~MHNSW_Context()
  {
    free_root(&root, MYF(0));
    mysql_rwlock_destroy(&commit_lock);
    mysql_mutex_destroy(&cache_lock);
    for (size_t i=0; i < array_elements(node_lock); i++)
      mysql_mutex_destroy(node_lock + i);
  }
396

397
  uint lock_node(FVectorNode *ptr)
398
  {
399 400 401 402 403
    ulong nr1= 1, nr2= 4;
    my_hash_sort_bin(0, (uchar*)&ptr, sizeof(ptr), &nr1, &nr2);
    uint ticket= nr1 % array_elements(node_lock);
    mysql_mutex_lock(node_lock + ticket);
    return ticket;
404
  }
405

406
  void unlock_node(uint ticket)
407
  {
408 409 410 411 412 413
    mysql_mutex_unlock(node_lock + ticket);
  }

  uint max_neighbors(size_t layer) const
  {
    return (layer ? 1 : 2) * M; // heuristic from the paper
414 415
  }

Sergei Golubchik's avatar
Sergei Golubchik committed
416
  void set_lengths(size_t len, ha_rows min_rows)
417
  {
Sergei Golubchik's avatar
Sergei Golubchik committed
418 419 420 421 422 423 424 425 426 427 428 429 430 431 432
    mysql_mutex_lock(node_lock); // let's hijack this mutex, just once
    if (!byte_len)
    {
      byte_len= len;
      vec_len= len / sizeof(float);
      if ((use_subdist= use_subdist_heuristic(M, vec_len, min_rows)))
      {
        mysql_mutex_lock(&cache_lock);
        void *data= alloc_root(&root, sizeof(float)*vec_len*vec_len);
        mysql_mutex_unlock(&cache_lock);
        new (&randomizer) Map<MatrixXf>((float*)data, vec_len, vec_len);
        generate_random_orthogonal_matrix(randomizer);
      }
    }
    mysql_mutex_unlock(node_lock);
433
  }
434 435 436 437

  static int acquire(MHNSW_Context **ctx, TABLE *table, bool for_update);
  static MHNSW_Context *get_from_share(TABLE_SHARE *share, TABLE *table);

Sergei Golubchik's avatar
Sergei Golubchik committed
438
  virtual void reset(TABLE_SHARE *share)
439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458
  {
    mysql_mutex_lock(&share->LOCK_share);
    if (static_cast<MHNSW_Context*>(share->hlindex->hlindex_data) == this)
    {
      share->hlindex->hlindex_data= nullptr;
      --refcnt;
    }
    mysql_mutex_unlock(&share->LOCK_share);
  }

  void release(TABLE *table)
  {
    return release(table->file->has_transactions(), table->s);
  }

  virtual void release(bool can_commit, TABLE_SHARE *share)
  {
    if (can_commit)
      mysql_rwlock_unlock(&commit_lock);
    if (root_size(&root) > mhnsw_cache_size)
Sergei Golubchik's avatar
Sergei Golubchik committed
459
      reset(share);
460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
    if (--refcnt == 0)
      this->~MHNSW_Context(); // XXX reuse
  }

  FVectorNode *get_node(const void *gref)
  {
    mysql_mutex_lock(&cache_lock);
    FVectorNode *node= node_cache.find(gref, gref_len);
    if (!node)
    {
      node= new (alloc_node_internal()) FVectorNode(this, gref);
      cache_internal(node);
    }
    mysql_mutex_unlock(&cache_lock);
    return node;
  }

  /* used on INSERT, gref isn't known, so cannot cache the node yet */
  void *alloc_node()
  {
    mysql_mutex_lock(&cache_lock);
    auto p= alloc_node_internal();
    mysql_mutex_unlock(&cache_lock);
    return p;
  }

  /* explicitly cache the node after alloc_node() */
  void cache_node(FVectorNode *node)
  {
    mysql_mutex_lock(&cache_lock);
    cache_internal(node);
    mysql_mutex_unlock(&cache_lock);
  }

  /* find the node without creating, only used on merging trx->ctx */
  FVectorNode *find_node(const void *gref)
  {
    mysql_mutex_lock(&cache_lock);
    FVectorNode *node= node_cache.find(gref, gref_len);
    mysql_mutex_unlock(&cache_lock);
    return node;
  }

  void *alloc_neighborhood(size_t max_layer)
  {
    mysql_mutex_lock(&cache_lock);
    auto p= alloc_root(&root, sizeof(Neighborhood)*(max_layer+1) +
             sizeof(FVectorNode*)*(MY_ALIGN(M, 4)*2 + MY_ALIGN(M,8)*max_layer));
    mysql_mutex_unlock(&cache_lock);
    return p;
  }
511
};
512

513 514 515 516 517 518 519 520 521 522 523
/*
  This is a non-shared context that exists within one transaction.

  At the end of the transaction it's either discarded (on rollback)
  or merged into the shared ctx (on commit).

  trx's are stored in thd->ha_data[] in a single-linked list,
  one instance of trx per TABLE_SHARE and allocated on the
  thd->transaction->mem_root
*/
class MHNSW_Trx : public MHNSW_Context
524
{
525 526 527 528 529 530
public:
  TABLE_SHARE *table_share;
  bool list_of_nodes_is_lost= false;
  MHNSW_Trx *next= nullptr;

  MHNSW_Trx(TABLE *table) : MHNSW_Context(table), table_share(table->s) {}
Sergei Golubchik's avatar
Sergei Golubchik committed
531
  void reset(TABLE_SHARE *) override
532 533 534 535 536 537 538 539 540
  {
    node_cache.clear();
    free_root(&root, MYF(0));
    start= 0;
    list_of_nodes_is_lost= true;
  }
  void release(bool, TABLE_SHARE *) override
  {
    if (root_size(&root) > mhnsw_cache_size)
Sergei Golubchik's avatar
Sergei Golubchik committed
541
      reset(nullptr);
542 543
  }

544
  static MHNSW_Trx *get_from_thd(TABLE *table, bool for_update);
545 546 547 548

  // it's okay in a transaction-local cache, there's no concurrent access
  Hash_set<FVectorNode> &get_cache() { return node_cache; }

549 550 551 552
  static transaction_participant tp;
  static int do_commit(THD *thd, bool);
  static int do_savepoint_rollback(THD *thd, void *);
  static int do_rollback(THD *thd, bool);
553 554
};

555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570
/* transaction participant callbacks of MHNSW_Trx (positional initializer) */
struct transaction_participant MHNSW_Trx::tp=
{
  0, 0, 0,
  nullptr,                           /* close_connection */
  [](THD *, void *) { return 0; },   /* savepoint_set */
  MHNSW_Trx::do_savepoint_rollback,  /* savepoint_rollback */
  [](THD *) { return true; },        /* savepoint_rollback_can_release_mdl */
  nullptr,                           /* savepoint_release */
  MHNSW_Trx::do_commit,              /* commit */
  MHNSW_Trx::do_rollback,            /* rollback */
  nullptr,                           /* prepare */
  nullptr,                           /* recover */
  nullptr,                           /* commit_by_xid */
  nullptr,                           /* rollback_by_xid */
  nullptr,                           /* recover_rollback_by_xid */
  nullptr,                           /* recovery_done */
  nullptr,                           /* snapshot */
  nullptr,                           /* commit_ordered */
  nullptr,                           /* prepare_ordered */
  nullptr,                           /* checkpoint */
  nullptr                            /* versioned */
};
571

572
/*
  Rollback to savepoint: every trx context of this THD is reset, which
  drops its cached nodes and sets list_of_nodes_is_lost, so that
  do_commit() will reset the corresponding shared context.
*/
int MHNSW_Trx::do_savepoint_rollback(THD *thd, void *)
{
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
       trx; trx= trx->next)
    trx->reset(nullptr);
  return 0;
}

580
/*
  Rollback: destroy all trx contexts of this THD without touching the
  shared contexts, and clear the per-THD list.
*/
int MHNSW_Trx::do_rollback(THD *thd, bool)
{
  MHNSW_Trx *trx_next;
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
       trx; trx= trx_next)
  {
    trx_next= trx->next;   // read before the object is destroyed
    trx->~MHNSW_Trx();     // memory is on thd->transaction->mem_root
  }
  // use the THD we were called for, not current_thd
  thd_set_ha_data(thd, &tp, nullptr);
  return 0;
}

593
/*
  Commit: for every trx context of this THD, invalidate the affected
  nodes of the shared context (under its commit_lock in write mode) so
  they get reloaded from the table, then destroy the trx context.
*/
int MHNSW_Trx::do_commit(THD *thd, bool)
{
  MHNSW_Trx *trx_next;
  for (auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
       trx; trx= trx_next)
  {
    trx_next= trx->next;
    auto ctx= MHNSW_Context::get_from_share(trx->table_share, nullptr);
    if (ctx)
    {
      mysql_rwlock_wrlock(&ctx->commit_lock);
      if (trx->list_of_nodes_is_lost)
        ctx->reset(trx->table_share); // don't know what changed, drop it all
      else
      {
        // consider copying nodes from trx to shared cache when it makes sense
        // for ann_benchmarks it does not
        // also, consider flushing only changed nodes (a flag in the node)
        for (FVectorNode &from : trx->get_cache())
          if (FVectorNode *node= ctx->find_node(from.gref()))
            node->vec= nullptr;   // forces a reload on next access
        ctx->start= nullptr;
      }
      ctx->release(true, trx->table_share); // also unlocks commit_lock
    }
    trx->~MHNSW_Trx();
  }
  // use the THD we were called for, not current_thd
  thd_set_ha_data(thd, &tp, nullptr);
  return 0;
}

624
/*
  Find the transaction-local context of table->s in this THD.

  @return NULL for non-transactional tables, or when the THD has no trx
          contexts at all and for_update is false. Otherwise the context
          of table->s, created (and the THD registered as a transaction
          participant, once per list) if it doesn't exist yet.
*/
MHNSW_Trx *MHNSW_Trx::get_from_thd(TABLE *table, bool for_update)
{
  if (!table->file->has_transactions())
    return NULL;

  THD *thd= table->in_use;
  auto trx= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
  if (!for_update && !trx)
    return NULL;

  while (trx && trx->table_share != table->s) trx= trx->next;
  if (!trx)
  {
    // NOTE(review): a failed mem_root allocation is not handled here
    trx= new (&thd->transaction->mem_root) MHNSW_Trx(table);
    trx->next= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &tp));
    thd_set_ha_data(thd, &tp, trx);
    if (!trx->next)
    {
      // first context in this THD: register with the transaction
      bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN);
      trans_register_ha(thd, all, &tp, 0);
    }
  }
  return trx;
}

/*
  Get the shared context stored in the share, creating it if there is
  none and table is given. A newly created context gets an extra
  reference for the share itself (dropped by reset()). The returned
  context has its refcnt incremented for the caller.

  @return the context, or nullptr if there's none and table is nullptr,
          or if it could not be allocated
*/
MHNSW_Context *MHNSW_Context::get_from_share(TABLE_SHARE *share, TABLE *table)
{
  if (share->tmp_table == NO_TMP_TABLE)
    mysql_mutex_lock(&share->LOCK_share);
  auto ctx= static_cast<MHNSW_Context*>(share->hlindex->hlindex_data);
  if (!ctx && table)
  {
    ctx= new (&share->hlindex->mem_root) MHNSW_Context(table);
    // on OOM fall through, so that LOCK_share is unlocked below
    if (ctx)
    {
      share->hlindex->hlindex_data= ctx;
      ctx->refcnt++;   // reference held by the share
    }
  }
  if (ctx)
    ctx->refcnt++;     // reference returned to the caller
  if (share->tmp_table == NO_TMP_TABLE)
    mysql_mutex_unlock(&share->LOCK_share);
  return ctx;
}

/*
  Get the context to work with and make sure its start node is loaded.

  The transaction-local context is used if MHNSW_Trx::get_from_thd()
  returns one, otherwise the shared context, whose commit_lock is then
  read-locked for transactional tables (unlocked in release()).
  The start node is the last row of the graph table in IDX_LAYER order.

  @return 0, HA_ERR_OUT_OF_MEM, or a handler error (e.g. from
          ha_index_last() when the graph table is empty)
*/
int MHNSW_Context::acquire(MHNSW_Context **ctx, TABLE *table, bool for_update)
{
  TABLE *graph= table->hlindex;

  if (!(*ctx= MHNSW_Trx::get_from_thd(table, for_update)))
  {
    if (!(*ctx= MHNSW_Context::get_from_share(table->s, table)))
      return HA_ERR_OUT_OF_MEM;
    if (table->file->has_transactions())
      mysql_rwlock_rdlock(&(*ctx)->commit_lock);
  }

  if ((*ctx)->start)
    return 0;

  if (int err= graph->file->ha_index_init(IDX_LAYER, 1))
    return err;

  int err= graph->file->ha_index_last(graph->record[0]);
  graph->file->ha_index_end();
  if (err)
    return err;

  graph->file->position(graph->record[0]);
  (*ctx)->set_lengths(FVector::data_to_value_size(graph->field[FIELD_VEC]->value_length()),
                      table->s->min_rows);
  (*ctx)->start= (*ctx)->get_node(graph->file->ref);
  if (!(*ctx)->start)
    return HA_ERR_OUT_OF_MEM;
  return (*ctx)->start->load_from_record(graph);
}

697
/* copy the vector, preprocessed as needed */
Sergei Golubchik's avatar
Sergei Golubchik committed
698 699 700 701 702 703 704 705 706 707 708 709 710 711
const FVector *FVector::create(const MHNSW_Context *ctx, void *mem, const void *src)
{
  const void *vdata= ctx->use_subdist ? alloca(ctx->byte_len) : src;
  Map<const VectorXf> in((const float*)src, ctx->vec_len);
  Map<VectorXf> v((float*)vdata, ctx->vec_len);
  if (ctx->use_subdist)
    v= ctx->randomizer * in;

  FVector *vec= align_ptr(mem);
  float scale= std::max(-v.minCoeff(), v.maxCoeff());
  vec->scale= scale ? scale/32767 : 1;
  for (size_t i= 0; i < ctx->vec_len; i++)
    vec->dims[i] = static_cast<int16_t>(std::round(v(i) / vec->scale));
  vec->postprocess(ctx->use_subdist, ctx->vec_len);
Sergei Golubchik's avatar
Sergei Golubchik committed
712 713 714 715 716 717
  if (ctx->metric == COSINE && vec->abs2)
  {
    vec->scale/= std::sqrt(vec->abs2);
    vec->subabs2/= vec->abs2;
    vec->abs2= 1.0f;
  }
Sergei Golubchik's avatar
Sergei Golubchik committed
718 719 720
  return vec;
}

721
/* build this node's vector from floats v, in the memory right after tref */
const FVector *FVectorNode::make_vec(const void *v)
{
  return FVector::create(ctx, tref() + tref_len(), v);
}
725

726
/* a stored node known only by its gref, the rest is loaded on demand */
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_)
  : ctx(ctx_), stored(true), deleted(false)
{
  memcpy(gref(), gref_, gref_len());
}
731

732 733
/*
  a new node on INSERT: tref and vector are known, gref is not known
  until the row is written, so it's filled with 0xff for now
*/
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
                         const void *vec_)
  : ctx(ctx_), stored(false), deleted(false)
{
  DBUG_ASSERT(tref_);
  memset(gref(), 0xff, gref_len()); // important: larger than any real gref
  memcpy(tref(), tref_, tref_len());
  vec= make_vec(vec_);

  alloc_neighborhood(layer);
}
743

744
float FVectorNode::distance_to(const FVector *other) const
745
{
746
  return vec->distance_to(other, ctx->vec_len);
747
}
748

Sergei Golubchik's avatar
Sergei Golubchik committed
749 750 751 752 753 754 755
/*
  Distance to other. With the subdistance heuristic it may return an
  estimate instead, when that estimate is clearly above than.
*/
float FVectorNode::distance_greater_than(const FVector *other, float than) const
{
  return ctx->use_subdist
           ? vec->distance_greater_than(other, ctx->vec_len, than)
           : distance_to(other);
}

756
int FVectorNode::alloc_neighborhood(uint8_t layer)
757
{
758 759
  if (neighbors)
    return 0;
760
  max_layer= layer;
761 762 763 764
  neighbors= (Neighborhood*)ctx->alloc_neighborhood(layer);
  auto ptr= (FVectorNode**)(neighbors + (layer+1));
  for (size_t i= 0; i <= layer; i++)
    ptr= neighbors[i].init(ptr, ctx->max_neighbors(i));
765 766 767
  return 0;
}

768
/*
  Make sure the node is loaded: if not, read its row from the graph
  table by gref and parse it with load_from_record().
*/
int FVectorNode::load(TABLE *graph)
{
  if (likely(vec))
    return 0;

  DBUG_ASSERT(stored);
  // trx: consider loading nodes from shared, when it makes sense
  // for ann_benchmarks it does not
  if (int err= graph->file->ha_rnd_pos(graph->record[0], gref()))
    return err;
  return load_from_record(graph);
}

781
/*
  Fill the node from graph->record[0].

  FIELD_TREF is NULL for deleted nodes, FIELD_VEC is FVector::data(),
  FIELD_NEIGHBORS is, for each layer 0..max_layer, a one-byte count N
  followed by N grefs.

  Serialized per node by ctx->lock_node(); vec is set last, so a non-null
  vec means the node is completely loaded.

  @return 0, HA_ERR_CRASHED on malformed data, or HA_ERR_OUT_OF_MEM
*/
int FVectorNode::load_from_record(TABLE *graph)
{
  DBUG_ASSERT(ctx->byte_len);

  uint ticket= ctx->lock_node(this);
  SCOPE_EXIT([this, ticket](){ ctx->unlock_node(ticket); });

  if (vec)
    return 0;

  String buf, *v= graph->field[FIELD_TREF]->val_str(&buf);
  deleted= graph->field[FIELD_TREF]->is_null();
  if (!deleted)
  {
    if (unlikely(!v || v->length() != tref_len()))
      return my_errno= HA_ERR_CRASHED;
    memcpy(tref(), v->ptr(), v->length());
  }

  v= graph->field[FIELD_VEC]->val_str(&buf);
  if (unlikely(!v))
    return my_errno= HA_ERR_CRASHED;

  if (v->length() != FVector::data_size(ctx->vec_len))
    return my_errno= HA_ERR_CRASHED;
  FVector *vec_ptr= FVector::align_ptr(tref() + tref_len());
  memcpy(vec_ptr->data(), v->ptr(), v->length());
  vec_ptr->postprocess(ctx->use_subdist, ctx->vec_len);

  longlong layer= graph->field[FIELD_LAYER]->val_int();
  // negative values would wrap in the uint8_t cast below
  if (layer < 0 || layer > 100) // 10e30 nodes at M=2, more at larger M's
    return my_errno= HA_ERR_CRASHED;

  if (int err= alloc_neighborhood(static_cast<uint8_t>(layer)))
    return err;

  v= graph->field[FIELD_NEIGHBORS]->val_str(&buf);
  if (unlikely(!v))
    return my_errno= HA_ERR_CRASHED;

  // <N> <gref> <gref> ... <N> ...etc...
  uchar *ptr= (uchar*)v->ptr(), *end= ptr + v->length();
  for (size_t i=0; i <= max_layer; i++)
  {
    if (unlikely(ptr >= end))
      return my_errno= HA_ERR_CRASHED;
    size_t grefs= *ptr++;
    // links[] has room for max_neighbors(i) entries only
    if (unlikely(grefs > ctx->max_neighbors(i) ||
                 ptr + grefs * gref_len() > end))
      return my_errno= HA_ERR_CRASHED;
    for (size_t j=0; j < grefs; j++, ptr+= gref_len())
      if (unlikely(!(neighbors[i].links[j]= ctx->get_node(ptr))))
        return my_errno= HA_ERR_OUT_OF_MEM;
    neighbors[i].num= grefs;
  }
  vec= vec_ptr; // must be done at the very end
  return 0;
}

838
/* append other to this node's neighbor list on the given layer */
void FVectorNode::push_neighbor(size_t layer, FVectorNode *other)
{
  DBUG_ASSERT(neighbors[layer].num < ctx->max_neighbors(layer));
  neighbors[layer].links[neighbors[layer].num++]= other;
}
843

844 845 846 847
/*
  Node memory layout (see MHNSW_Context::alloc_node_internal()):
  [FVectorNode][gref][tref][FVector, placed by FVector::align_ptr()]
*/
size_t FVectorNode::tref_len() const { return ctx->tref_len; }
size_t FVectorNode::gref_len() const { return ctx->gref_len; }
uchar *FVectorNode::gref() const { return (uchar*)(this+1); }
uchar *FVectorNode::tref() const { return gref() + gref_len(); }

/* Hash_set key extractor: nodes are keyed by gref */
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
  *key_len= elem->gref_len();
  return elem->gref();
}

855 856
/* one visited node during the search. caches the distance to target */
struct Visited : public Sql_alloc
857
{
858 859 860 861
  FVectorNode *node;
  const float distance_to_target;
  Visited(FVectorNode *n, float d) : node(n), distance_to_target(d) {}
  static int cmp(void *, const Visited* a, const Visited *b)
862
  {
863 864
    return a->distance_to_target < b->distance_to_target ? -1 :
           a->distance_to_target > b->distance_to_target ?  1 : 0;
865
  }
866 867 868 869
};

/*
  a factory to create Visited and keep track of already seen nodes
870

871 872 873 874 875 876
  note that PatternedSimdBloomFilter works in blocks of 8 elements,
  so on insert they're accumulated in nodes[], on search the caller
  provides 8 addresses at once. we record 0x0 as "seen" so that
  the caller could pad the input with nullptr's
*/
class VisitedSet
877
{
878 879 880 881 882 883
  MEM_ROOT *root;
  PatternedSimdBloomFilter<FVectorNode> map;
  const FVectorNode *nodes[8]= {0,0,0,0,0,0,0,0};
  size_t idx= 1; // to record 0 in the filter
  public:
  uint count= 0;
Sergei Golubchik's avatar
Sergei Golubchik committed
884 885 886
  VisitedSet(MEM_ROOT *root, uint size) :
    root(root), map(size, 0.01f) {}
  Visited *create(FVectorNode *node, float dist)
887
  {
Sergei Golubchik's avatar
Sergei Golubchik committed
888
    auto *v= new (root) Visited(node, dist);
889 890 891 892 893 894 895 896 897 898 899 900 901 902 903
    insert(node);
    count++;
    return v;
  }
  void insert(const FVectorNode *n)
  {
    nodes[idx++]= n;
    if (idx == 8) flush();
  }
  void flush() {
    if (idx) map.Insert(nodes);
    idx=0;
  }
  uint8_t seen(FVectorNode **nodes) { return map.Query(nodes); }
};
904 905


906 907 908 909 910 911 912 913 914
/*
  selects best neighbors from the list of candidates plus one extra candidate

  one extra candidate is specified separately to avoid appending it to
  the Neighborhood candidates, which might be already at its max size.
*/
static int select_neighbors(MHNSW_Context *ctx, TABLE *graph, size_t layer,
                            FVectorNode &target, const Neighborhood &candidates,
                            FVectorNode *extra_candidate,
915
                            size_t max_neighbor_connections)
916
{
917
  Queue<Visited> pq; // working queue
918

919 920
  if (pq.init(10000, false, Visited::cmp))
    return my_errno= HA_ERR_OUT_OF_MEM;
921

922 923 924 925
  MEM_ROOT * const root= graph->in_use->mem_root;
  auto discarded= (Visited**)my_safe_alloca(sizeof(Visited**)*max_neighbor_connections);
  size_t discarded_num= 0;
  Neighborhood &neighbors= target.neighbors[layer];
926

927
  for (size_t i=0; i < candidates.num; i++)
928
  {
929 930 931
    FVectorNode *node= candidates.links[i];
    if (int err= node->load(graph))
      return err;
932
    pq.push(new (root) Visited(node, node->distance_to(target.vec)));
933
  }
934
  if (extra_candidate)
935
    pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target.vec)));
936 937

  DBUG_ASSERT(pq.elements());
938
  neighbors.num= 0;
939

940
  while (pq.elements() && neighbors.num < max_neighbor_connections)
941
  {
942 943 944
    Visited *vec= pq.pop();
    FVectorNode * const node= vec->node;
    const float target_dista= vec->distance_to_target / alpha;
945
    bool discard= false;
946
    for (size_t i=0; i < neighbors.num; i++)
Sergei Golubchik's avatar
Sergei Golubchik committed
947
      if ((discard= node->distance_greater_than(neighbors.links[i]->vec, target_dista) < target_dista))
948
        break;
949
    if (!discard)
950 951 952
      target.push_neighbor(layer, node);
    else if (discarded_num + neighbors.num < max_neighbor_connections)
      discarded[discarded_num++]= vec;
953 954
  }

955 956
  for (size_t i=0; i < discarded_num && neighbors.num < max_neighbor_connections; i++)
    target.push_neighbor(layer, discarded[i]->node);
957

958
  my_safe_afree(discarded, sizeof(Visited**)*max_neighbor_connections);
Sergei Golubchik's avatar
Sergei Golubchik committed
959
  return 0;
960 961
}

Sergei Golubchik's avatar
Sergei Golubchik committed
962

963
/*
  Write this node to the graph table.

  Fills graph->record[0]:
    - layer: max_layer
    - tref: the row reference, or NULL when the node is deleted
    - vector data
    - serialized neighbor lists: for each layer 0..max_layer, one byte
      with the neighbor count followed by that many grefs

  If the node was stored before, its row is located by gref and updated.
  Otherwise a new row is written. Only when that write succeeds does its
  position become the node's gref, the node get marked as stored, and the
  node get added to the context cache.

  Returns 0 or a handler error code.
*/
int FVectorNode::save(TABLE *graph)
{
  DBUG_ASSERT(vec);
  DBUG_ASSERT(neighbors);

  graph->field[FIELD_LAYER]->store(max_layer, false);
  if (deleted)
    graph->field[FIELD_TREF]->set_null();
  else
  {
    graph->field[FIELD_TREF]->set_notnull();
    graph->field[FIELD_TREF]->store_binary(tref(), tref_len());
  }
  graph->field[FIELD_VEC]->store_binary(vec->data(), FVector::data_size(ctx->vec_len));

  size_t total_size= 0;
  for (size_t i=0; i <= max_layer; i++)
    total_size+= 1 + gref_len() * neighbors[i].num;

  uchar *neighbor_blob= static_cast<uchar *>(my_safe_alloca(total_size));
  if (unlikely(!neighbor_blob))
    return my_errno= HA_ERR_OUT_OF_MEM;
  uchar *ptr= neighbor_blob;
  for (size_t i= 0; i <= max_layer; i++)
  {
    // the per-layer count is serialized into a single byte
    DBUG_ASSERT(neighbors[i].num <= 255);
    *ptr++= (uchar)(neighbors[i].num);
    for (size_t j= 0; j < neighbors[i].num; j++, ptr+= gref_len())
      memcpy(ptr, neighbors[i].links[j]->gref(), gref_len());
  }
  graph->field[FIELD_NEIGHBORS]->store_binary(neighbor_blob, total_size);

  int err;
  if (stored)
  {
    if (!(err= graph->file->ha_rnd_pos(graph->record[1], gref())))
    {
      err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
      if (err == HA_ERR_RECORD_IS_THE_SAME)
        err= 0;
    }
  }
  else
  {
    err= graph->file->ha_write_row(graph->record[0]);
    if (!err)
    {
      // only a successfully written row has a valid position to use as gref
      graph->file->position(graph->record[0]);
      memcpy(gref(), graph->file->ref, gref_len());
      stored= true;
      ctx->cache_node(this);
    }
  }
  my_safe_afree(neighbor_blob, total_size);
  return err;
}

/*
  Make the links of 'node' on 'layer' bidirectional.

  Each neighbor of 'node' either gets 'node' appended to its own list, or,
  if that list is already full, has its neighbors re-selected with 'node'
  as an extra candidate. Every neighbor touched is then saved.
  Returns 0 or an error code.
*/
static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
                                          size_t layer, FVectorNode *node)
{
  const uint max_neighbors= ctx->max_neighbors(layer);
  const Neighborhood &own= node->neighbors[layer];
  // it seems that one could update nodes in the gref order
  // to avoid InnoDB deadlocks, but it produces no noticeable effect
  for (size_t n= 0; n < own.num; n++)
  {
    FVectorNode *neigh= own.links[n];
    Neighborhood &theirs= neigh->neighbors[layer];
    int err= 0;
    if (theirs.num >= max_neighbors)
      err= select_neighbors(ctx, graph, layer, *neigh, theirs, node,
                            max_neighbors);
    else
      neigh->push_neighbor(layer, node);
    if (err || (err= neigh->save(graph)))
      return err;
  }
  return 0;
}

static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
Sergei Golubchik's avatar
Sergei Golubchik committed
1037 1038
                        Neighborhood *start_nodes, uint result_size,
                        size_t layer, Neighborhood *result, bool construction)
1039
{
1040 1041 1042 1043
  DBUG_ASSERT(start_nodes->num > 0);
  result->num= 0;

  MEM_ROOT * const root= graph->in_use->mem_root;
Sergei Golubchik's avatar
Sergei Golubchik committed
1044 1045 1046
  Queue<Visited> candidates, best;
  bool skip_deleted;
  uint ef= result_size;
1047

Sergei Golubchik's avatar
Sergei Golubchik committed
1048 1049 1050 1051 1052 1053 1054 1055 1056 1057
  if (construction)
  {
    skip_deleted= false;
    if (ef > 1)
      ef= std::max(ef_construction, ef);
  }
  else
  {
    skip_deleted= layer == 0;
    if (ef > 1 || layer == 0)
1058
      ef= std::max(THDVAR(graph->in_use, min_limit), ef);
Sergei Golubchik's avatar
Sergei Golubchik committed
1059
  }
1060

1061 1062 1063
  // WARNING! heuristic here
  const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
  const uint est_size= static_cast<uint>(est_heuristic * std::pow(ef, ctx->ef_power));
Sergei Golubchik's avatar
Sergei Golubchik committed
1064
  VisitedSet visited(root, est_size);
1065

1066 1067 1068
  candidates.init(10000, false, Visited::cmp);
  best.init(ef, true, Visited::cmp);

Sergei Golubchik's avatar
Sergei Golubchik committed
1069
  DBUG_ASSERT(start_nodes->num <= result_size);
1070
  for (size_t i=0; i < start_nodes->num; i++)
1071
  {
Sergei Golubchik's avatar
Sergei Golubchik committed
1072 1073
    auto node= start_nodes->links[i];
    Visited *v= visited.create(node, node->distance_to(target));
1074
    candidates.push(v);
1075 1076
    if (skip_deleted && v->node->deleted)
      continue;
Sergei Golubchik's avatar
Sergei Golubchik committed
1077
    best.push(v);
1078 1079
  }

1080 1081
  float furthest_best= best.is_empty() ? FLT_MAX
                       : best.top()->distance_to_target * generosity;
1082 1083
  while (candidates.elements())
  {
1084
    const Visited &cur= *candidates.pop();
Sergei Golubchik's avatar
Sergei Golubchik committed
1085
    if (cur.distance_to_target > furthest_best && best.is_full())
1086 1087 1088
      break; // All possible candidates are worse than what we have

    visited.flush();
1089

1090 1091 1092
    Neighborhood &neighbors= cur.node->neighbors[layer];
    FVectorNode **links= neighbors.links, **end= links + neighbors.num;
    for (; links < end; links+= 8)
1093
    {
1094 1095
      uint8_t res= visited.seen(links);
      if (res == 0xff)
1096 1097
        continue;

1098
      for (size_t i= 0; i < 8; i++)
1099
      {
1100 1101 1102 1103
        if (res & (1 << i))
          continue;
        if (int err= links[i]->load(graph))
          return err;
Sergei Golubchik's avatar
Sergei Golubchik committed
1104
        if (!best.is_full())
1105
        {
Sergei Golubchik's avatar
Sergei Golubchik committed
1106
          Visited *v= visited.create(links[i], links[i]->distance_to(target));
1107
          candidates.push(v);
1108 1109
          if (skip_deleted && v->node->deleted)
            continue;
1110
          best.push(v);
1111
          furthest_best= best.top()->distance_to_target * generosity;
1112
        }
Sergei Golubchik's avatar
Sergei Golubchik committed
1113
        else
1114
        {
Sergei Golubchik's avatar
Sergei Golubchik committed
1115 1116
          Visited *v= visited.create(links[i], links[i]->distance_greater_than(target, furthest_best));
          if (v->distance_to_target < furthest_best)
1117
          {
Sergei Golubchik's avatar
Sergei Golubchik committed
1118 1119 1120 1121 1122 1123 1124 1125
            candidates.safe_push(v);
            if (skip_deleted && v->node->deleted)
              continue;
            if (v->distance_to_target < best.top()->distance_to_target)
            {
              best.replace_top(v);
              furthest_best= best.top()->distance_to_target * generosity;
            }
1126
          }
1127
        }
1128 1129 1130
      }
    }
  }
1131 1132 1133 1134 1135
  if (ef > 1 && visited.count*2 > est_size)
  {
    double ef_power= std::log(visited.count*2/est_heuristic) / std::log(ef);
    set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok
  }
1136

Sergei Golubchik's avatar
Sergei Golubchik committed
1137 1138 1139
  while (best.elements() > result_size)
    best.pop();

1140 1141 1142
  result->num= best.elements();
  for (FVectorNode **links= result->links + result->num; best.elements();)
    *--links= best.pop()->node;
1143

1144
  return 0;
1145 1146
}

1147

Sergei Golubchik's avatar
Sergei Golubchik committed
1148 1149 1150 1151 1152
/*
  Report a vector value that cannot be indexed as
  ER_TRUNCATED_WRONG_VALUE_FOR_FIELD for column 'f' at the current row.
  Sets my_errno and returns HA_ERR_GENERIC.
*/
static int bad_value_on_insert(Field *f)
{
  TABLE *tbl= f->table;
  TABLE_SHARE *share= tbl->s;
  my_error(ER_TRUNCATED_WRONG_VALUE_FOR_FIELD, MYF(0), "vector", "...",
           share->db.str, share->table_name.str, f->field_name.str,
           tbl->in_use->get_stmt_da()->current_row_for_warning());
  return my_errno= HA_ERR_GENERIC;
}

/*
  Add the vector of the current row (table->record[0]) to the graph.

  Validates the vector length. If the graph is empty, stores the first node
  and makes it the start node. Otherwise:
    - picks a random target layer for the new node;
    - descends greedily from ctx->start down to that layer;
    - on each layer from there down to 0, searches for candidates and
      selects the node's neighbors;
    - saves the node, and makes it the start node if it is the new top;
    - adds back-links from its neighbors.

  Returns 0 or an error code.
*/
int mhnsw_insert(TABLE *table, KEY *keyinfo)
{
  THD *thd= table->in_use;
  TABLE *graph= table->hlindex;
  MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
  // restore read_set on every return path, not only after a successful insert
  SCOPE_EXIT([table, old_map](){ dbug_tmp_restore_column_map(&table->read_set, old_map); });
  Field *vec_field= keyinfo->key_part->field;
  String buf, *res= vec_field->val_str(&buf);
  MHNSW_Context *ctx;

  /* metadata are checked on open */
  DBUG_ASSERT(graph);
  DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
  DBUG_ASSERT(keyinfo->usable_key_parts == 1);
  DBUG_ASSERT(vec_field->binary());
  DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT);
  DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL
  DBUG_ASSERT(table->file->ref_length <= graph->field[FIELD_TREF]->field_length);

  // XXX returning an error here will rollback the insert in InnoDB
  // but in MyISAM the row will stay inserted, making the index out of sync:
  // invalid vector values are present in the table but cannot be found
  // via an index. The easiest way to fix it is with a VECTOR(N) type
  if (res->length() == 0 || res->length() % 4)
    return bad_value_on_insert(vec_field);

  table->file->position(table->record[0]);

  int err= MHNSW_Context::acquire(&ctx, table, true);
  SCOPE_EXIT([ctx, table](){ ctx->release(table); });
  if (err)
  {
    if (err != HA_ERR_END_OF_FILE)
      return err;

    // First insert!
    ctx->set_lengths(res->length(), table->s->min_rows);
    FVectorNode *target= new (ctx->alloc_node())
                   FVectorNode(ctx, table->file->ref, 0, res->ptr());
    if (!((err= target->save(graph))))
      ctx->start= target;
    return err;
  }

  if (ctx->byte_len != res->length())
    return bad_value_on_insert(vec_field);

  const size_t max_found= ctx->max_neighbors(0);
  Neighborhood candidates, start_nodes;
  candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
  start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
  start_nodes.links[start_nodes.num++]= ctx->start;

  const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
  double log= -std::log(my_rnd(&thd->rand)) * NORMALIZATION_FACTOR;
  const uint8_t max_layer= start_nodes.links[0]->max_layer;
  /*
    Clamp while still in floating point: my_rnd() may return 0, which makes
    'log' +inf, and converting an out-of-range double to uint8_t is UB.
  */
  uint8_t target_layer= static_cast<uint8_t>(
    std::min<double>(std::floor(log), max_layer + 1));
  int cur_layer;

  FVectorNode *target= new (ctx->alloc_node())
                 FVectorNode(ctx, table->file->ref, target_layer, res->ptr());

  if (int err= graph->file->ha_rnd_init(0))
    return err;
  SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });

  // greedy descent through the layers above the target layer
  for (cur_layer= max_layer; cur_layer > target_layer; cur_layer--)
  {
    if (int err= search_layer(ctx, graph, target->vec, &start_nodes, 1,
                              cur_layer, &candidates, false))
      return err;
    std::swap(start_nodes, candidates);
  }

  // on the node's own layers: find candidates and pick its neighbors
  for (; cur_layer >= 0; cur_layer--)
  {
    uint max_neighbors= ctx->max_neighbors(cur_layer);
    if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
                              max_neighbors, cur_layer, &candidates, true))
      return err;

    if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
                                  0, max_neighbors))
      return err;
    std::swap(start_nodes, candidates);
  }

  if (int err= target->save(graph))
    return err;

  if (target_layer > max_layer)
    ctx->start= target;

  for (cur_layer= target_layer; cur_layer >= 0; cur_layer--)
  {
    if (int err= update_second_degree_neighbors(ctx, graph, cur_layer, target))
      return err;
  }

  return 0;
}

/*
  Start an index scan ordered by distance to the constant argument of
  'dist', returning at most 'limit' rows.

  Searches the graph from ctx->start down to layer 0 and stores the result
  in graph->context: a ulonglong row count followed by the trefs of the
  found rows, the nearest one last. The first row is then returned via
  mhnsw_next().

  If the query vector is NULL or has the wrong length, a fixed vector is
  used instead, because the ordering is meaningless then.
*/
int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
{
  THD *thd= table->in_use;
  TABLE *graph= table->hlindex;
  auto *fun= static_cast<Item_func_vec_distance_common*>(dist);
  DBUG_ASSERT(fun);

  String buf, *res= fun->get_const_arg()->val_str(&buf);
  MHNSW_Context *ctx;

  if (int err= table->file->ha_rnd_init(0))
    return err;

  if (int err= MHNSW_Context::acquire(&ctx, table, false))
    return err;
  SCOPE_EXIT([ctx, table](){ ctx->release(table); });

  FVectorNode **candidate_links= thd->alloc<FVectorNode*>(limit + 7);
  FVectorNode **start_links= thd->alloc<FVectorNode*>(limit + 7);
  if (unlikely(!candidate_links || !start_links))
    return my_errno= HA_ERR_OUT_OF_MEM;

  Neighborhood candidates, start_nodes;
  candidates.init(candidate_links, limit);
  start_nodes.init(start_links, limit);

  // one could put all max_layer nodes in start_nodes
  // but it has no effect of the recall or speed
  start_nodes.links[start_nodes.num++]= ctx->start;

  /*
    if the query vector is NULL or invalid, VEC_DISTANCE will return
    NULL, so the result is basically unsorted, we can return rows
    in any order. Let's use some hardcoded value here
  */
  if (!res || ctx->byte_len != res->length())
  {
    if (unlikely(buf.alloc(ctx->byte_len)))
      return my_errno= HA_ERR_OUT_OF_MEM;
    res= &buf;
    for (size_t i=0; i < ctx->vec_len; i++)
      ((float*)buf.ptr())[i]= i == 0;
  }

  const longlong max_layer= start_nodes.links[0]->max_layer;
  auto target_mem= thd->alloc(FVector::alloc_size(ctx->vec_len));
  if (unlikely(!target_mem))
    return my_errno= HA_ERR_OUT_OF_MEM;
  auto target= FVector::create(ctx, target_mem, res->ptr());

  if (int err= graph->file->ha_rnd_init(0))
    return err;
  SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });

  for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
  {
    if (int err= search_layer(ctx, graph, target, &start_nodes, 1, cur_layer,
                              &candidates, false))
      return err;
    std::swap(start_nodes, candidates);
  }

  if (int err= search_layer(ctx, graph, target, &start_nodes,
                            static_cast<uint>(limit), 0, &candidates, false))
    return err;

  if (limit > candidates.num)
    limit= candidates.num;
  size_t context_size=limit * ctx->tref_len + sizeof(ulonglong);
  char *context= thd->alloc(context_size);
  if (unlikely(!context))
    return my_errno= HA_ERR_OUT_OF_MEM;
  graph->context= context;

  *(ulonglong*)context= limit;
  context+= context_size;

  // write trefs from the end, so the nearest row ends up last
  for (size_t i=0; limit--; i++)
  {
    context-= ctx->tref_len;
    memcpy(context, candidates.links[i]->tref(), ctx->tref_len);
  }
  DBUG_ASSERT(context - sizeof(ulonglong) == graph->context);

  return mhnsw_next(table);
}

/*
  Return the next row of a scan started by mhnsw_first().

  hlindex->context holds a ulonglong count of remaining rows, followed by
  the trefs of those rows. Each call decrements the count and fetches the
  row whose tref sits at that index. Returns HA_ERR_END_OF_FILE when there
  is no context or no rows remain.
*/
int mhnsw_next(TABLE *table)
{
  uchar *ref= (uchar*)(table->hlindex->context);
  ulonglong *limit= (ulonglong*)ref;
  // a zero count must end the scan: decrementing it would wrap around
  // and index far outside the context buffer
  if (limit && *limit)
  {
    ref+= sizeof(ulonglong) + (--*limit) * table->file->ref_length;
    return table->file->ha_rnd_pos(table->record[0], ref);
  }
  return my_errno= HA_ERR_END_OF_FILE;
}

/*
  Destroy the MHNSW_Context attached to the index share, if there is one,
  and detach it from the share.
*/
void mhnsw_free(TABLE_SHARE *share)
{
  TABLE_SHARE *graph_share= share->hlindex;
  if (auto *ctx= static_cast<MHNSW_Context*>(graph_share->hlindex_data))
  {
    ctx->~MHNSW_Context();
    graph_share->hlindex_data= 0;
  }
}

/*
  Mark the graph node of row 'rec' as deleted.

  Finds the graph row via the row's tref (IDX_TREF) and sets its tref to
  NULL. If the context could be acquired, also flags the corresponding
  node object as deleted. Returns 0 or a handler error code.
*/
int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo)
{
  TABLE *graph= table->hlindex;
  handler *h= table->file;

  MHNSW_Context *ctx= nullptr;
  const bool use_ctx= !MHNSW_Context::acquire(&ctx, table, true);
  // release the context on every path, including the error returns below
  SCOPE_EXIT([use_ctx, ctx, table](){ if (use_ctx) ctx->release(table); });

  /* metadata are checked on open */
  DBUG_ASSERT(graph);
  DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
  DBUG_ASSERT(keyinfo->usable_key_parts == 1);
  DBUG_ASSERT(h->ref_length <= graph->field[FIELD_TREF]->field_length);

  // target record:
  h->position(rec);
  graph->field[FIELD_TREF]->set_notnull();
  graph->field[FIELD_TREF]->store_binary(h->ref, h->ref_length);

  uchar *key= (uchar*)alloca(graph->key_info[IDX_TREF].key_length);
  key_copy(key, graph->record[0], &graph->key_info[IDX_TREF],
           graph->key_info[IDX_TREF].key_length);

  if (int err= graph->file->ha_index_read_idx_map(graph->record[1], IDX_TREF,
                                        key, HA_WHOLE_KEY, HA_READ_KEY_EXACT))
   return err;

  restore_record(graph, record[1]);
  graph->field[FIELD_TREF]->set_null();
  if (int err= graph->file->ha_update_row(graph->record[1], graph->record[0]))
    return err;

  if (use_ctx)
  {
    graph->file->position(graph->record[0]);
    FVectorNode *node= ctx->get_node(graph->file->ref);
    node->deleted= true;
  }

  return 0;
}

/*
  Remove all rows from the graph table. If the context can be acquired,
  also reset it for the table share (ctx->reset) and release it.
  Returns 0 or the error from ha_delete_all_rows().
*/
int mhnsw_delete_all(TABLE *table, KEY *keyinfo)
{
  TABLE *graph= table->hlindex;

  /* metadata are checked on open */
  DBUG_ASSERT(graph);
  DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR);
  DBUG_ASSERT(keyinfo->usable_key_parts == 1);

  int err= graph->file->ha_delete_all_rows();
  if (err)
    return err;

  MHNSW_Context *ctx;
  if (MHNSW_Context::acquire(&ctx, table, true) == 0)
  {
    ctx->reset(table->s);
    ctx->release(table);
  }
  return 0;
}

/*
  Return the CREATE TABLE statement for the graph (hlindex) table.
  The tref column is a varbinary of ref_length bytes. The text is
  formatted into a buffer obtained from thd->alloc().
*/
const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
{
  const char templ[]= "CREATE TABLE i (                   "
                      "  layer tinyint not null,          "
                      "  tref varbinary(%u),              "
                      "  vec blob not null,               "
                      "  neighbors blob not null,         "
                      "  unique (tref),                   "
                      "  key (layer))                     ";
  // 32 extra bytes leave ample room for the formatted %u
  const size_t buf_size= sizeof(templ) + 32;
  char *sql= thd->alloc(buf_size);
  const size_t sql_len= my_snprintf(sql, buf_size, templ, ref_length);
  return {sql, sql_len};
}

/*
  True if 'dist' is the distance function this index was built for:
  the euclidean distance item for the EUCLIDEAN metric, otherwise the
  cosine distance item.
*/
bool mhnsw_uses_distance(const TABLE *table, KEY *keyinfo, const Item *dist)
{
  const bool euclidean= keyinfo->option_struct->metric == EUCLIDEAN;
  if (!euclidean)
    return dynamic_cast<const Item_func_vec_distance_cosine*>(dist) != NULL;
  return dynamic_cast<const Item_func_vec_distance_euclidean*>(dist) != NULL;
}

/*
  Index options of the vector index. Each entry names the option, the
  option-struct member it fills (e.g. option_struct->metric) and the
  system variable it is tied to (HA_IOPTION_SYSVAR; presumably the sysvar
  supplies the default — confirm).
*/
ha_create_table_option mhnsw_index_options[]=
{
  HA_IOPTION_SYSVAR("max_edges_per_node", M, max_edges_per_node),
  HA_IOPTION_SYSVAR("distance_function", metric, distance_function),
  HA_IOPTION_END
};

st_plugin_int *mhnsw_plugin;

static int mhnsw_init(void *p)
{
  mhnsw_plugin= (st_plugin_int *)p;
  mhnsw_plugin->data= &MHNSW_Trx::tp;
  if (setup_transaction_participant(mhnsw_plugin))
    return 1;

  return resolve_sysvar_table_options(mhnsw_index_options);
}

/* Plugin deinit: free what mhnsw_init() resolved for the index options. */
static int mhnsw_deinit(void *)
{
  free_sysvar_table_options(mhnsw_index_options);
  return 0;
}

static struct st_mysql_storage_engine mhnsw_daemon=
{ MYSQL_DAEMON_INTERFACE_VERSION };

static struct st_mysql_sys_var *mhnsw_sys_vars[]=
{
  MYSQL_SYSVAR(cache_size),
  MYSQL_SYSVAR(max_edges_per_node),
Sergei Golubchik's avatar
Sergei Golubchik committed
1481
  MYSQL_SYSVAR(distance_function),
1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494
  MYSQL_SYSVAR(min_limit),
  NULL
};

/* plugin declaration for the mhnsw daemon plugin */
maria_declare_plugin(mhnsw)
{
  MYSQL_DAEMON_PLUGIN,                          // plugin type
  &mhnsw_daemon,                                // type-specific descriptor
  "mhnsw",                                      // name
  "MariaDB plc",                                // author
  "A plugin for mhnsw vector index algorithm",  // description
  PLUGIN_LICENSE_GPL,                           // license
  mhnsw_init,                                   // init function
  mhnsw_deinit,                                 // deinit function
  0x0100,                                       // numeric version
  NULL,                                         // status variables
  mhnsw_sys_vars,                               // system variables
  "1.0",                                        // version string
  MariaDB_PLUGIN_MATURITY_STABLE                // maturity
}
maria_declare_plugin_end;