faiss 0.3.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -7,6 +7,7 @@
7
7
 
8
8
  #include <faiss/impl/HNSW.h>
9
9
 
10
+ #include <cstddef>
10
11
  #include <string>
11
12
 
12
13
  #include <faiss/impl/AuxIndexStructures.h>
@@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
110
111
  level,
111
112
  nb_neighbors(level));
112
113
  size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
113
- #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
114
- reduction(+: tot_reciprocal) reduction(+: n_node)
114
+ #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
115
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
115
116
  for (int i = 0; i < levels.size(); i++) {
116
117
  if (levels[i] > level) {
117
118
  n_node++;
@@ -215,8 +216,8 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
215
216
  if (pt_level > max_level)
216
217
  max_level = pt_level;
217
218
  offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
218
- neighbors.resize(offsets.back(), -1);
219
219
  }
220
+ neighbors.resize(offsets.back(), -1);
220
221
 
221
222
  return max_level;
222
223
  }
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
229
230
  DistanceComputer& qdis,
230
231
  std::priority_queue<NodeDistFarther>& input,
231
232
  std::vector<NodeDistFarther>& output,
232
- int max_size) {
233
+ int max_size,
234
+ bool keep_max_size_level0) {
235
+ // This prevents number of neighbors at
236
+ // level 0 from being shrunk to less than 2 * M.
237
+ // This is essential in making sure
238
+ // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
239
+ std::vector<NodeDistFarther> outsiders;
240
+
233
241
  while (input.size() > 0) {
234
242
  NodeDistFarther v1 = input.top();
235
243
  input.pop();
@@ -250,8 +258,15 @@ void HNSW::shrink_neighbor_list(
250
258
  if (output.size() >= max_size) {
251
259
  return;
252
260
  }
261
+ } else if (keep_max_size_level0) {
262
+ outsiders.push_back(v1);
253
263
  }
254
264
  }
265
+ size_t idx = 0;
266
+ while (keep_max_size_level0 && (output.size() < max_size) &&
267
+ (idx < outsiders.size())) {
268
+ output.push_back(outsiders[idx++]);
269
+ }
255
270
  }
256
271
 
257
272
  namespace {
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
268
283
  void shrink_neighbor_list(
269
284
  DistanceComputer& qdis,
270
285
  std::priority_queue<NodeDistCloser>& resultSet1,
271
- int max_size) {
286
+ int max_size,
287
+ bool keep_max_size_level0 = false) {
272
288
  if (resultSet1.size() < max_size) {
273
289
  return;
274
290
  }
@@ -280,7 +296,8 @@ void shrink_neighbor_list(
280
296
  resultSet1.pop();
281
297
  }
282
298
 
283
- HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
299
+ HNSW::shrink_neighbor_list(
300
+ qdis, resultSet, returnlist, max_size, keep_max_size_level0);
284
301
 
285
302
  for (NodeDistFarther curen2 : returnlist) {
286
303
  resultSet1.emplace(curen2.d, curen2.id);
@@ -294,7 +311,8 @@ void add_link(
294
311
  DistanceComputer& qdis,
295
312
  storage_idx_t src,
296
313
  storage_idx_t dest,
297
- int level) {
314
+ int level,
315
+ bool keep_max_size_level0 = false) {
298
316
  size_t begin, end;
299
317
  hnsw.neighbor_range(src, level, &begin, &end);
300
318
  if (hnsw.neighbors[end - 1] == -1) {
@@ -319,7 +337,7 @@ void add_link(
319
337
  resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
320
338
  }
321
339
 
322
- shrink_neighbor_list(qdis, resultSet, end - begin);
340
+ shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
323
341
 
324
342
  // ...and back
325
343
  size_t i = begin;
@@ -342,6 +360,9 @@ void search_neighbors_to_add(
342
360
  float d_entry_point,
343
361
  int level,
344
362
  VisitedTable& vt) {
363
+ // selects a version
364
+ const bool reference_version = false;
365
+
345
366
  // top is nearest candidate
346
367
  std::priority_queue<NodeDistFarther> candidates;
347
368
 
@@ -363,59 +384,90 @@ void search_neighbors_to_add(
363
384
  // loop over neighbors
364
385
  size_t begin, end;
365
386
  hnsw.neighbor_range(currNode, level, &begin, &end);
366
- for (size_t i = begin; i < end; i++) {
367
- storage_idx_t nodeId = hnsw.neighbors[i];
368
- if (nodeId < 0)
369
- break;
370
- if (vt.get(nodeId))
371
- continue;
372
- vt.set(nodeId);
373
387
 
374
- float dis = qdis(nodeId);
375
- NodeDistFarther evE1(dis, nodeId);
376
-
377
- if (results.size() < hnsw.efConstruction || results.top().d > dis) {
378
- results.emplace(dis, nodeId);
379
- candidates.emplace(dis, nodeId);
380
- if (results.size() > hnsw.efConstruction) {
381
- results.pop();
388
+ // select a version, based on a flag
389
+ if (reference_version) {
390
+ // a reference version
391
+ for (size_t i = begin; i < end; i++) {
392
+ storage_idx_t nodeId = hnsw.neighbors[i];
393
+ if (nodeId < 0)
394
+ break;
395
+ if (vt.get(nodeId))
396
+ continue;
397
+ vt.set(nodeId);
398
+
399
+ float dis = qdis(nodeId);
400
+ NodeDistFarther evE1(dis, nodeId);
401
+
402
+ if (results.size() < hnsw.efConstruction ||
403
+ results.top().d > dis) {
404
+ results.emplace(dis, nodeId);
405
+ candidates.emplace(dis, nodeId);
406
+ if (results.size() > hnsw.efConstruction) {
407
+ results.pop();
408
+ }
382
409
  }
383
410
  }
384
- }
385
- }
386
- vt.advance();
387
- }
411
+ } else {
412
+ // a faster version
413
+
414
+ // the following version processes 4 neighbors at a time
415
+ auto update_with_candidate = [&](const storage_idx_t idx,
416
+ const float dis) {
417
+ if (results.size() < hnsw.efConstruction ||
418
+ results.top().d > dis) {
419
+ results.emplace(dis, idx);
420
+ candidates.emplace(dis, idx);
421
+ if (results.size() > hnsw.efConstruction) {
422
+ results.pop();
423
+ }
424
+ }
425
+ };
388
426
 
389
- /**************************************************************
390
- * Searching subroutines
391
- **************************************************************/
427
+ int n_buffered = 0;
428
+ storage_idx_t buffered_ids[4];
392
429
 
393
- /// greedily update a nearest vector at a given level
394
- void greedy_update_nearest(
395
- const HNSW& hnsw,
396
- DistanceComputer& qdis,
397
- int level,
398
- storage_idx_t& nearest,
399
- float& d_nearest) {
400
- for (;;) {
401
- storage_idx_t prev_nearest = nearest;
430
+ for (size_t j = begin; j < end; j++) {
431
+ storage_idx_t nodeId = hnsw.neighbors[j];
432
+ if (nodeId < 0)
433
+ break;
434
+ if (vt.get(nodeId)) {
435
+ continue;
436
+ }
437
+ vt.set(nodeId);
438
+
439
+ buffered_ids[n_buffered] = nodeId;
440
+ n_buffered += 1;
441
+
442
+ if (n_buffered == 4) {
443
+ float dis[4];
444
+ qdis.distances_batch_4(
445
+ buffered_ids[0],
446
+ buffered_ids[1],
447
+ buffered_ids[2],
448
+ buffered_ids[3],
449
+ dis[0],
450
+ dis[1],
451
+ dis[2],
452
+ dis[3]);
453
+
454
+ for (size_t id4 = 0; id4 < 4; id4++) {
455
+ update_with_candidate(buffered_ids[id4], dis[id4]);
456
+ }
402
457
 
403
- size_t begin, end;
404
- hnsw.neighbor_range(nearest, level, &begin, &end);
405
- for (size_t i = begin; i < end; i++) {
406
- storage_idx_t v = hnsw.neighbors[i];
407
- if (v < 0)
408
- break;
409
- float dis = qdis(v);
410
- if (dis < d_nearest) {
411
- nearest = v;
412
- d_nearest = dis;
458
+ n_buffered = 0;
459
+ }
460
+ }
461
+
462
+ // process leftovers
463
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
464
+ float dis = qdis(buffered_ids[icnt]);
465
+ update_with_candidate(buffered_ids[icnt], dis);
413
466
  }
414
- }
415
- if (nearest == prev_nearest) {
416
- return;
417
467
  }
418
468
  }
469
+
470
+ vt.advance();
419
471
  }
420
472
 
421
473
  } // namespace
@@ -429,7 +481,8 @@ void HNSW::add_links_starting_from(
429
481
  float d_nearest,
430
482
  int level,
431
483
  omp_lock_t* locks,
432
- VisitedTable& vt) {
484
+ VisitedTable& vt,
485
+ bool keep_max_size_level0) {
433
486
  std::priority_queue<NodeDistCloser> link_targets;
434
487
 
435
488
  search_neighbors_to_add(
@@ -438,13 +491,13 @@ void HNSW::add_links_starting_from(
438
491
  // but we can afford only this many neighbors
439
492
  int M = nb_neighbors(level);
440
493
 
441
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
494
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
442
495
 
443
496
  std::vector<storage_idx_t> neighbors;
444
497
  neighbors.reserve(link_targets.size());
445
498
  while (!link_targets.empty()) {
446
499
  storage_idx_t other_id = link_targets.top().id;
447
- add_link(*this, ptdis, pt_id, other_id, level);
500
+ add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
448
501
  neighbors.push_back(other_id);
449
502
  link_targets.pop();
450
503
  }
@@ -452,7 +505,7 @@ void HNSW::add_links_starting_from(
452
505
  omp_unset_lock(&locks[pt_id]);
453
506
  for (storage_idx_t other_id : neighbors) {
454
507
  omp_set_lock(&locks[other_id]);
455
- add_link(*this, ptdis, other_id, pt_id, level);
508
+ add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
456
509
  omp_unset_lock(&locks[other_id]);
457
510
  }
458
511
  omp_set_lock(&locks[pt_id]);
@@ -467,7 +520,8 @@ void HNSW::add_with_locks(
467
520
  int pt_level,
468
521
  int pt_id,
469
522
  std::vector<omp_lock_t>& locks,
470
- VisitedTable& vt) {
523
+ VisitedTable& vt,
524
+ bool keep_max_size_level0) {
471
525
  // greedy search on upper levels
472
526
 
473
527
  storage_idx_t nearest;
@@ -496,7 +550,14 @@ void HNSW::add_with_locks(
496
550
 
497
551
  for (; level >= 0; level--) {
498
552
  add_links_starting_from(
499
- ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
553
+ ptdis,
554
+ pt_id,
555
+ nearest,
556
+ d_nearest,
557
+ level,
558
+ locks.data(),
559
+ vt,
560
+ keep_max_size_level0);
500
561
  }
501
562
 
502
563
  omp_unset_lock(&locks[pt_id]);
@@ -511,12 +572,10 @@ void HNSW::add_with_locks(
511
572
  * Searching
512
573
  **************************************************************/
513
574
 
514
- namespace {
515
575
  using MinimaxHeap = HNSW::MinimaxHeap;
516
576
  using Node = HNSW::Node;
517
577
  using C = HNSW::C;
518
578
  /** Do a BFS on the candidates list */
519
-
520
579
  int search_from_candidates(
521
580
  const HNSW& hnsw,
522
581
  DistanceComputer& qdis,
@@ -525,8 +584,8 @@ int search_from_candidates(
525
584
  VisitedTable& vt,
526
585
  HNSWStats& stats,
527
586
  int level,
528
- int nres_in = 0,
529
- const SearchParametersHNSW* params = nullptr) {
587
+ int nres_in,
588
+ const SearchParametersHNSW* params) {
530
589
  int nres = nres_in;
531
590
  int ndis = 0;
532
591
 
@@ -571,27 +630,7 @@ int search_from_candidates(
571
630
  size_t begin, end;
572
631
  hnsw.neighbor_range(v0, level, &begin, &end);
573
632
 
574
- // // baseline version
575
- // for (size_t j = begin; j < end; j++) {
576
- // int v1 = hnsw.neighbors[j];
577
- // if (v1 < 0)
578
- // break;
579
- // if (vt.get(v1)) {
580
- // continue;
581
- // }
582
- // vt.set(v1);
583
- // ndis++;
584
- // float d = qdis(v1);
585
- // if (!sel || sel->is_member(v1)) {
586
- // if (nres < k) {
587
- // faiss::maxheap_push(++nres, D, I, d, v1);
588
- // } else if (d < D[0]) {
589
- // faiss::maxheap_replace_top(nres, D, I, d, v1);
590
- // }
591
- // }
592
- // candidates.push(v1, d);
593
- // }
594
-
633
+ // a faster version: reference version in unit test test_hnsw.cpp
595
634
  // the following version processes 4 neighbors at a time
596
635
  size_t jmax = begin;
597
636
  for (size_t j = begin; j < end; j++) {
@@ -606,7 +645,6 @@ int search_from_candidates(
606
645
  int counter = 0;
607
646
  size_t saved_j[4];
608
647
 
609
- ndis += jmax - begin;
610
648
  threshold = res.threshold;
611
649
 
612
650
  auto add_to_heap = [&](const size_t idx, const float dis) {
@@ -614,6 +652,7 @@ int search_from_candidates(
614
652
  if (dis < threshold) {
615
653
  if (res.add_result(dis, idx)) {
616
654
  threshold = res.threshold;
655
+ nres += 1;
617
656
  }
618
657
  }
619
658
  }
@@ -644,6 +683,8 @@ int search_from_candidates(
644
683
  add_to_heap(saved_j[id4], dis[id4]);
645
684
  }
646
685
 
686
+ ndis += 4;
687
+
647
688
  counter = 0;
648
689
  }
649
690
  }
@@ -651,6 +692,8 @@ int search_from_candidates(
651
692
  for (size_t icnt = 0; icnt < counter; icnt++) {
652
693
  float dis = qdis(saved_j[icnt]);
653
694
  add_to_heap(saved_j[icnt], dis);
695
+
696
+ ndis += 1;
654
697
  }
655
698
 
656
699
  nstep++;
@@ -664,7 +707,8 @@ int search_from_candidates(
664
707
  if (candidates.size() == 0) {
665
708
  stats.n2++;
666
709
  }
667
- stats.n3 += ndis;
710
+ stats.ndis += ndis;
711
+ stats.nhops += nstep;
668
712
  }
669
713
 
670
714
  return nres;
@@ -700,33 +744,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
700
744
  size_t begin, end;
701
745
  hnsw.neighbor_range(v0, 0, &begin, &end);
702
746
 
703
- // // baseline version
704
- // for (size_t j = begin; j < end; ++j) {
705
- // int v1 = hnsw.neighbors[j];
706
- //
707
- // if (v1 < 0) {
708
- // break;
709
- // }
710
- // if (vt->get(v1)) {
711
- // continue;
712
- // }
713
- //
714
- // vt->set(v1);
715
- //
716
- // float d1 = qdis(v1);
717
- // ++ndis;
718
- //
719
- // if (top_candidates.top().first > d1 ||
720
- // top_candidates.size() < ef) {
721
- // candidates.emplace(d1, v1);
722
- // top_candidates.emplace(d1, v1);
723
- //
724
- // if (top_candidates.size() > ef) {
725
- // top_candidates.pop();
726
- // }
727
- // }
728
- // }
729
-
747
+ // a faster version: reference version in unit test test_hnsw.cpp
730
748
  // the following version processes 4 neighbors at a time
731
749
  size_t jmax = begin;
732
750
  for (size_t j = begin; j < end; j++) {
@@ -741,8 +759,6 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
741
759
  int counter = 0;
742
760
  size_t saved_j[4];
743
761
 
744
- ndis += jmax - begin;
745
-
746
762
  auto add_to_heap = [&](const size_t idx, const float dis) {
747
763
  if (top_candidates.top().first > dis ||
748
764
  top_candidates.size() < ef) {
@@ -779,6 +795,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
779
795
  add_to_heap(saved_j[id4], dis[id4]);
780
796
  }
781
797
 
798
+ ndis += 4;
799
+
782
800
  counter = 0;
783
801
  }
784
802
  }
@@ -786,18 +804,102 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
786
804
  for (size_t icnt = 0; icnt < counter; icnt++) {
787
805
  float dis = qdis(saved_j[icnt]);
788
806
  add_to_heap(saved_j[icnt], dis);
807
+
808
+ ndis += 1;
789
809
  }
810
+
811
+ stats.nhops += 1;
790
812
  }
791
813
 
792
814
  ++stats.n1;
793
815
  if (candidates.size() == 0) {
794
816
  ++stats.n2;
795
817
  }
796
- stats.n3 += ndis;
818
+ stats.ndis += ndis;
797
819
 
798
820
  return top_candidates;
799
821
  }
800
822
 
823
+ /// greedily update a nearest vector at a given level
824
+ HNSWStats greedy_update_nearest(
825
+ const HNSW& hnsw,
826
+ DistanceComputer& qdis,
827
+ int level,
828
+ storage_idx_t& nearest,
829
+ float& d_nearest) {
830
+ HNSWStats stats;
831
+
832
+ for (;;) {
833
+ storage_idx_t prev_nearest = nearest;
834
+
835
+ size_t begin, end;
836
+ hnsw.neighbor_range(nearest, level, &begin, &end);
837
+
838
+ size_t ndis = 0;
839
+
840
+ // a faster version: reference version in unit test test_hnsw.cpp
841
+ // the following version processes 4 neighbors at a time
842
+ auto update_with_candidate = [&](const storage_idx_t idx,
843
+ const float dis) {
844
+ if (dis < d_nearest) {
845
+ nearest = idx;
846
+ d_nearest = dis;
847
+ }
848
+ };
849
+
850
+ int n_buffered = 0;
851
+ storage_idx_t buffered_ids[4];
852
+
853
+ for (size_t j = begin; j < end; j++) {
854
+ storage_idx_t v = hnsw.neighbors[j];
855
+ if (v < 0)
856
+ break;
857
+ ndis += 1;
858
+
859
+ buffered_ids[n_buffered] = v;
860
+ n_buffered += 1;
861
+
862
+ if (n_buffered == 4) {
863
+ float dis[4];
864
+ qdis.distances_batch_4(
865
+ buffered_ids[0],
866
+ buffered_ids[1],
867
+ buffered_ids[2],
868
+ buffered_ids[3],
869
+ dis[0],
870
+ dis[1],
871
+ dis[2],
872
+ dis[3]);
873
+
874
+ for (size_t id4 = 0; id4 < 4; id4++) {
875
+ update_with_candidate(buffered_ids[id4], dis[id4]);
876
+ }
877
+
878
+ n_buffered = 0;
879
+ }
880
+ }
881
+
882
+ // process leftovers
883
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
884
+ float dis = qdis(buffered_ids[icnt]);
885
+ update_with_candidate(buffered_ids[icnt], dis);
886
+ }
887
+
888
+ // update stats
889
+ stats.ndis += ndis;
890
+ stats.nhops += 1;
891
+
892
+ if (nearest == prev_nearest) {
893
+ return stats;
894
+ }
895
+ }
896
+ }
897
+
898
+ namespace {
899
+ using MinimaxHeap = HNSW::MinimaxHeap;
900
+ using Node = HNSW::Node;
901
+ using C = HNSW::C;
902
+
801
903
  // just used as a lower bound for the minmaxheap, but it is set for heap search
802
904
  int extract_k_from_ResultHandler(ResultHandler<C>& res) {
803
905
  using RH = HeapBlockResultHandler<C>;
@@ -807,7 +909,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
807
909
  return 1;
808
910
  }
809
911
 
810
- } // anonymous namespace
912
+ } // namespace
811
913
 
812
914
  HNSWStats HNSW::search(
813
915
  DistanceComputer& qdis,
@@ -820,85 +922,47 @@ HNSWStats HNSW::search(
820
922
  }
821
923
  int k = extract_k_from_ResultHandler(res);
822
924
 
823
- if (upper_beam == 1) {
824
- // greedy search on upper levels
825
- storage_idx_t nearest = entry_point;
826
- float d_nearest = qdis(nearest);
827
-
828
- for (int level = max_level; level >= 1; level--) {
829
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
830
- }
925
+ bool bounded_queue =
926
+ params ? params->bounded_queue : this->search_bounded_queue;
831
927
 
832
- int ef = std::max(params ? params->efSearch : efSearch, k);
833
- if (search_bounded_queue) { // this is the most common branch
834
- MinimaxHeap candidates(ef);
835
-
836
- candidates.push(nearest, d_nearest);
928
+ // greedy search on upper levels
929
+ storage_idx_t nearest = entry_point;
930
+ float d_nearest = qdis(nearest);
837
931
 
838
- search_from_candidates(
839
- *this, qdis, res, candidates, vt, stats, 0, 0, params);
840
- } else {
841
- std::priority_queue<Node> top_candidates =
842
- search_from_candidate_unbounded(
843
- *this,
844
- Node(d_nearest, nearest),
845
- qdis,
846
- ef,
847
- &vt,
848
- stats);
849
-
850
- while (top_candidates.size() > k) {
851
- top_candidates.pop();
852
- }
932
+ for (int level = max_level; level >= 1; level--) {
933
+ HNSWStats local_stats =
934
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
935
+ stats.combine(local_stats);
936
+ }
853
937
 
854
- while (!top_candidates.empty()) {
855
- float d;
856
- storage_idx_t label;
857
- std::tie(d, label) = top_candidates.top();
858
- res.add_result(d, label);
859
- top_candidates.pop();
860
- }
861
- }
938
+ int ef = std::max(params ? params->efSearch : efSearch, k);
939
+ if (bounded_queue) { // this is the most common branch
940
+ MinimaxHeap candidates(ef);
862
941
 
863
- vt.advance();
942
+ candidates.push(nearest, d_nearest);
864
943
 
944
+ search_from_candidates(
945
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
865
946
  } else {
866
- int candidates_size = upper_beam;
867
- MinimaxHeap candidates(candidates_size);
868
-
869
- std::vector<idx_t> I_to_next(candidates_size);
870
- std::vector<float> D_to_next(candidates_size);
871
-
872
- HeapBlockResultHandler<C> block_resh(
873
- 1, D_to_next.data(), I_to_next.data(), candidates_size);
874
- HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
875
-
876
- int nres = 1;
877
- I_to_next[0] = entry_point;
878
- D_to_next[0] = qdis(entry_point);
879
-
880
- for (int level = max_level; level >= 0; level--) {
881
- // copy I, D -> candidates
882
-
883
- candidates.clear();
947
+ std::priority_queue<Node> top_candidates =
948
+ search_from_candidate_unbounded(
949
+ *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
884
950
 
885
- for (int i = 0; i < nres; i++) {
886
- candidates.push(I_to_next[i], D_to_next[i]);
887
- }
951
+ while (top_candidates.size() > k) {
952
+ top_candidates.pop();
953
+ }
888
954
 
889
- if (level == 0) {
890
- nres = search_from_candidates(
891
- *this, qdis, res, candidates, vt, stats, 0);
892
- } else {
893
- resh.begin(0);
894
- nres = search_from_candidates(
895
- *this, qdis, resh, candidates, vt, stats, level);
896
- resh.end();
897
- }
898
- vt.advance();
955
+ while (!top_candidates.empty()) {
956
+ float d;
957
+ storage_idx_t label;
958
+ std::tie(d, label) = top_candidates.top();
959
+ res.add_result(d, label);
960
+ top_candidates.pop();
899
961
  }
900
962
  }
901
963
 
964
+ vt.advance();
965
+
902
966
  return stats;
903
967
  }
904
968
 
@@ -910,9 +974,12 @@ void HNSW::search_level_0(
910
974
  const float* nearest_d,
911
975
  int search_type,
912
976
  HNSWStats& search_stats,
913
- VisitedTable& vt) const {
977
+ VisitedTable& vt,
978
+ const SearchParametersHNSW* params) const {
914
979
  const HNSW& hnsw = *this;
980
+ auto efSearch = params ? params->efSearch : hnsw.efSearch;
915
981
  int k = extract_k_from_ResultHandler(res);
982
+
916
983
  if (search_type == 1) {
917
984
  int nres = 0;
918
985
 
@@ -925,16 +992,25 @@ void HNSW::search_level_0(
925
992
  if (vt.get(cj))
926
993
  continue;
927
994
 
928
- int candidates_size = std::max(hnsw.efSearch, k);
995
+ int candidates_size = std::max(efSearch, k);
929
996
  MinimaxHeap candidates(candidates_size);
930
997
 
931
998
  candidates.push(cj, nearest_d[j]);
932
999
 
933
1000
  nres = search_from_candidates(
934
- hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
1001
+ hnsw,
1002
+ qdis,
1003
+ res,
1004
+ candidates,
1005
+ vt,
1006
+ search_stats,
1007
+ 0,
1008
+ nres,
1009
+ params);
1010
+ nres = std::min(nres, candidates_size);
935
1011
  }
936
1012
  } else if (search_type == 2) {
937
- int candidates_size = std::max(hnsw.efSearch, int(k));
1013
+ int candidates_size = std::max(efSearch, int(k));
938
1014
  candidates_size = std::max(candidates_size, int(nprobe));
939
1015
 
940
1016
  MinimaxHeap candidates(candidates_size);
@@ -947,7 +1023,7 @@ void HNSW::search_level_0(
947
1023
  }
948
1024
 
949
1025
  search_from_candidates(
950
- hnsw, qdis, res, candidates, vt, search_stats, 0);
1026
+ hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
951
1027
  }
952
1028
  }
953
1029
 
@@ -1013,7 +1089,99 @@ void HNSW::MinimaxHeap::clear() {
1013
1089
  nvalid = k = 0;
1014
1090
  }
1015
1091
 
1016
- #ifdef __AVX2__
1092
+ #ifdef __AVX512F__
1093
+
1094
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1095
+ assert(k > 0);
1096
+ static_assert(
1097
+ std::is_same<storage_idx_t, int32_t>::value,
1098
+ "This code expects storage_idx_t to be int32_t");
1099
+
1100
+ int32_t min_idx = -1;
1101
+ float min_dis = std::numeric_limits<float>::infinity();
1102
+
1103
+ __m512i min_indices = _mm512_set1_epi32(-1);
1104
+ __m512 min_distances =
1105
+ _mm512_set1_ps(std::numeric_limits<float>::infinity());
1106
+ __m512i current_indices = _mm512_setr_epi32(
1107
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1108
+ __m512i offset = _mm512_set1_epi32(16);
1109
+
1110
+ // The following loop tracks the rightmost index with the min distance.
1111
+ // -1 index values are ignored.
1112
+ const int k16 = (k / 16) * 16;
1113
+ for (size_t iii = 0; iii < k16; iii += 16) {
1114
+ __m512i indices =
1115
+ _mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1116
+ __m512 distances = _mm512_loadu_ps(dis.data() + iii);
1117
+
1118
+ // This mask filters out -1 values among indices.
1119
+ __mmask16 m1mask =
1120
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1121
+
1122
+ __mmask16 dmask =
1123
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1124
+ __mmask16 finalmask = m1mask | dmask;
1125
+
1126
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1127
+ finalmask, current_indices, min_indices);
1128
+ const __m512 min_distances_new =
1129
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1130
+
1131
+ min_indices = min_indices_new;
1132
+ min_distances = min_distances_new;
1133
+
1134
+ current_indices = _mm512_add_epi32(current_indices, offset);
1135
+ }
1136
+
1137
+ // leftovers
1138
+ if (k16 != k) {
1139
+ const __mmask16 kmask = (1 << (k - k16)) - 1;
1140
+
1141
+ __m512i indices = _mm512_mask_loadu_epi32(
1142
+ _mm512_set1_epi32(-1), kmask, ids.data() + k16);
1143
+ __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1144
+
1145
+ // This mask filters out -1 values among indices.
1146
+ __mmask16 m1mask =
1147
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1148
+
1149
+ __mmask16 dmask =
1150
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1151
+ __mmask16 finalmask = m1mask | dmask;
1152
+
1153
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1154
+ finalmask, current_indices, min_indices);
1155
+ const __m512 min_distances_new =
1156
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1157
+
1158
+ min_indices = min_indices_new;
1159
+ min_distances = min_distances_new;
1160
+ }
1161
+
1162
+ // grab min distance
1163
+ min_dis = _mm512_reduce_min_ps(min_distances);
1164
+ // blend
1165
+ __mmask16 mindmask =
1166
+ _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1167
+ // pick the max one
1168
+ min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1169
+
1170
+ if (min_idx == -1) {
1171
+ return -1;
1172
+ }
1173
+
1174
+ if (vmin_out) {
1175
+ *vmin_out = min_dis;
1176
+ }
1177
+ int ret = ids[min_idx];
1178
+ ids[min_idx] = -1;
1179
+ --nvalid;
1180
+ return ret;
1181
+ }
1182
+
1183
+ #elif __AVX2__
1184
+
1017
1185
  int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1018
1186
  assert(k > 0);
1019
1187
  static_assert(