faiss 0.6.1 → 0.6.2

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -10,6 +10,7 @@
10
10
  #include <cinttypes>
11
11
  #include <cstddef>
12
12
  #include <cstdlib>
13
+ #include <type_traits>
13
14
 
14
15
  #include <faiss/IndexHNSW.h>
15
16
 
@@ -233,28 +234,32 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
233
234
  * neighbor only if there is no previous neighbor that is closer to
234
235
  * that vertex than the query.
235
236
  */
237
+ template <class Comp>
236
238
  void HNSW::shrink_neighbor_list(
237
239
  DistanceComputer& qdis,
238
- std::priority_queue<NodeDistFarther>& input,
239
- std::vector<NodeDistFarther>& output,
240
+ std::priority_queue<NodeDistFartherT<Comp>>& input,
241
+ std::vector<NodeDistFartherT<Comp>>& output,
240
242
  size_t max_size,
241
243
  bool keep_max_size_level0) {
242
244
  // This prevents number of neighbors at
243
245
  // level 0 from being shrunk to less than 2 * M.
244
246
  // This is essential in making sure
245
247
  // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
246
- std::vector<NodeDistFarther> outsiders;
248
+ std::vector<NodeDistFartherT<Comp>> outsiders;
247
249
 
248
250
  while (input.size() > 0) {
249
- NodeDistFarther v1 = input.top();
251
+ NodeDistFartherT<Comp> v1 = input.top();
250
252
  input.pop();
251
253
  float dist_v1_q = v1.d;
252
254
 
253
255
  bool good = true;
254
- for (NodeDistFarther v2 : output) {
256
+ for (NodeDistFartherT<Comp> v2 : output) {
255
257
  float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
256
258
 
257
- if (dist_v1_v2 < dist_v1_q) {
259
+ // "v1 is bad" if some previously-kept neighbor v2 is closer
260
+ // (more similar, under CMin) to v1 than the query is. Encoded
261
+ // generically as: v1v2 is "better than" v1q under Comp.
262
+ if (Comp::cmp(dist_v1_q, dist_v1_v2)) {
258
263
  good = false;
259
264
  break;
260
265
  }
@@ -277,44 +282,88 @@ void HNSW::shrink_neighbor_list(
277
282
  }
278
283
  }
279
284
 
285
+ // Explicit instantiations for the two supported comparators.
286
+ template void HNSW::shrink_neighbor_list<HNSW::C_distance>(
287
+ DistanceComputer&,
288
+ std::priority_queue<HNSW::NodeDistFartherT<HNSW::C_distance>>&,
289
+ std::vector<HNSW::NodeDistFartherT<HNSW::C_distance>>&,
290
+ size_t,
291
+ bool);
292
+ template void HNSW::shrink_neighbor_list<HNSW::C_similarity>(
293
+ DistanceComputer&,
294
+ std::priority_queue<HNSW::NodeDistFartherT<HNSW::C_similarity>>&,
295
+ std::vector<HNSW::NodeDistFartherT<HNSW::C_similarity>>&,
296
+ size_t,
297
+ bool);
298
+
280
299
  namespace {
281
300
 
282
301
  using storage_idx_t = HNSW::storage_idx_t;
283
- using NodeDistCloser = HNSW::NodeDistCloser;
284
- using NodeDistFarther = HNSW::NodeDistFarther;
302
+
303
+ // Map a (high-level) HNSW comparator C — which uses int64_t IDs — to the
304
+ // (low-level) MinimaxHeap comparator HC, which uses int32_t IDs.
305
+ template <class C>
306
+ using HC_for = std::
307
+ conditional_t<C::is_max, CMax<float, int32_t>, CMin<float, int32_t>>;
308
+
309
+ // Priority queue types used by the unbounded search variant. For CMax
310
+ // (distance) "top_candidates" is a max-heap of the kept-so-far results
311
+ // (top is the farthest) and "candidates" is a min-heap of the next nodes
312
+ // to explore (top is the closest). For CMin (similarity) the orderings are
313
+ // swapped: top_candidates is a min-heap (top is the least similar) and
314
+ // candidates is a max-heap (top is the most similar).
315
+ template <class C>
316
+ using TopCandidatesQueue = std::conditional_t<
317
+ C::is_max,
318
+ std::priority_queue<HNSW::Node>,
319
+ std::priority_queue<
320
+ HNSW::Node,
321
+ std::vector<HNSW::Node>,
322
+ std::greater<HNSW::Node>>>;
323
+
324
+ template <class C>
325
+ using CandidatesQueue = std::conditional_t<
326
+ C::is_max,
327
+ std::priority_queue<
328
+ HNSW::Node,
329
+ std::vector<HNSW::Node>,
330
+ std::greater<HNSW::Node>>,
331
+ std::priority_queue<HNSW::Node>>;
285
332
 
286
333
  /**************************************************************
287
334
  * Addition subroutines
288
335
  **************************************************************/
289
336
 
290
337
  /// remove neighbors from the list to make it smaller than max_size
291
- void shrink_neighbor_list(
338
+ template <class C>
339
+ void shrink_neighbor_list_inner(
292
340
  DistanceComputer& qdis,
293
- std::priority_queue<NodeDistCloser>& resultSet1,
341
+ std::priority_queue<HNSW::NodeDistCloserT<C>>& resultSet1,
294
342
  size_t max_size,
295
343
  bool keep_max_size_level0 = false) {
296
344
  if (resultSet1.size() < static_cast<size_t>(max_size)) {
297
345
  return;
298
346
  }
299
- std::priority_queue<NodeDistFarther> resultSet;
300
- std::vector<NodeDistFarther> returnlist;
347
+ std::priority_queue<HNSW::NodeDistFartherT<C>> resultSet;
348
+ std::vector<HNSW::NodeDistFartherT<C>> returnlist;
301
349
 
302
350
  while (resultSet1.size() > 0) {
303
351
  resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
304
352
  resultSet1.pop();
305
353
  }
306
354
 
307
- HNSW::shrink_neighbor_list(
355
+ HNSW::shrink_neighbor_list<C>(
308
356
  qdis, resultSet, returnlist, max_size, keep_max_size_level0);
309
357
 
310
- for (NodeDistFarther curen2 : returnlist) {
358
+ for (HNSW::NodeDistFartherT<C> curen2 : returnlist) {
311
359
  resultSet1.emplace(curen2.d, curen2.id);
312
360
  }
313
361
  }
314
362
 
315
363
  /// add a link between two elements, possibly shrinking the list
316
364
  /// of links to make room for it.
317
- void add_link(
365
+ template <class C>
366
+ void add_link_tpl(
318
367
  HNSW& hnsw,
319
368
  DistanceComputer& qdis,
320
369
  storage_idx_t src,
@@ -339,16 +388,17 @@ void add_link(
339
388
  // otherwise we let them fight out which to keep
340
389
 
341
390
  // copy to resultSet...
342
- std::priority_queue<NodeDistCloser> resultSet;
391
+ std::priority_queue<HNSW::NodeDistCloserT<C>> resultSet;
343
392
  resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
344
- for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
393
+ for (size_t i = begin; i < end; i++) {
345
394
  storage_idx_t neigh = hnsw.neighbors[i];
346
395
  resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
347
396
  }
348
397
 
349
398
  size_t max_size = end - begin;
350
399
  max_size -= max_size * std::clamp(hnsw.prune_headroom, 0.0f, 0.5f);
351
- shrink_neighbor_list(qdis, resultSet, max_size, keep_max_size_level0);
400
+ shrink_neighbor_list_inner<C>(
401
+ qdis, resultSet, max_size, keep_max_size_level0);
352
402
 
353
403
  // ...and back
354
404
  size_t i = begin;
@@ -362,31 +412,33 @@ void add_link(
362
412
  }
363
413
  }
364
414
 
365
- } // namespace
366
-
367
- /// search neighbors on a single level, starting from an entry point
368
- void search_neighbors_to_add(
415
+ /** Templated body of `search_neighbors_to_add` — instantiated once per final
416
+ * VisitedTable subclass × comparator so that `vt.set/advance` are inlined
417
+ * and the cost of virtual dispatch is paid only once at the top of the call.
418
+ */
419
+ template <typename VTType, class C>
420
+ static void search_neighbors_to_add_fixVT(
369
421
  HNSW& hnsw,
370
422
  DistanceComputer& qdis,
371
- std::priority_queue<NodeDistCloser>& results,
423
+ std::priority_queue<HNSW::NodeDistCloserT<C>>& results,
372
424
  int entry_point,
373
425
  float d_entry_point,
374
426
  int level,
375
- VisitedTable& vt,
427
+ VTType& vt,
376
428
  bool reference_version) {
377
429
  // top is nearest candidate
378
- std::priority_queue<NodeDistFarther> candidates;
430
+ std::priority_queue<HNSW::NodeDistFartherT<C>> candidates;
379
431
 
380
- NodeDistFarther ev(d_entry_point, entry_point);
432
+ HNSW::NodeDistFartherT<C> ev(d_entry_point, entry_point);
381
433
  candidates.push(ev);
382
434
  results.emplace(d_entry_point, entry_point);
383
435
  vt.set(entry_point);
384
436
 
385
437
  while (!candidates.empty()) {
386
438
  // get nearest
387
- const NodeDistFarther& currEv = candidates.top();
439
+ const HNSW::NodeDistFartherT<C>& currEv = candidates.top();
388
440
 
389
- if (currEv.d > results.top().d) {
441
+ if (C::cmp(currEv.d, results.top().d)) {
390
442
  break;
391
443
  }
392
444
  int currNode = currEv.id;
@@ -407,7 +459,7 @@ void search_neighbors_to_add(
407
459
  if (reference_version) {
408
460
  // a reference version
409
461
  for (size_t i = begin; i < end; i++) {
410
- storage_idx_t nodeId = hnsw.neighbors[i];
462
+ HNSW::storage_idx_t nodeId = hnsw.neighbors[i];
411
463
  if (nodeId < 0) {
412
464
  break;
413
465
  }
@@ -416,10 +468,10 @@ void search_neighbors_to_add(
416
468
  }
417
469
 
418
470
  float dis = qdis(nodeId);
419
- NodeDistFarther evE1(dis, nodeId);
471
+ HNSW::NodeDistFartherT<C> evE1(dis, nodeId);
420
472
 
421
473
  if (results.size() < static_cast<size_t>(hnsw.efConstruction) ||
422
- results.top().d > dis) {
474
+ C::cmp(results.top().d, dis)) {
423
475
  results.emplace(dis, nodeId);
424
476
  candidates.emplace(dis, nodeId);
425
477
  if (results.size() >
@@ -432,10 +484,10 @@ void search_neighbors_to_add(
432
484
  // a faster version
433
485
 
434
486
  // the following version processes 4 neighbors at a time
435
- auto update_with_candidate = [&](const storage_idx_t idx,
487
+ auto update_with_candidate = [&](const HNSW::storage_idx_t idx,
436
488
  const float dis) {
437
489
  if (results.size() < static_cast<size_t>(hnsw.efConstruction) ||
438
- results.top().d > dis) {
490
+ C::cmp(results.top().d, dis)) {
439
491
  results.emplace(dis, idx);
440
492
  candidates.emplace(dis, idx);
441
493
  if (results.size() >
@@ -446,10 +498,10 @@ void search_neighbors_to_add(
446
498
  };
447
499
 
448
500
  int n_buffered = 0;
449
- storage_idx_t buffered_ids[4];
501
+ HNSW::storage_idx_t buffered_ids[4];
450
502
 
451
503
  for (size_t j = begin; j < end; j++) {
452
- storage_idx_t nodeId = hnsw.neighbors[j];
504
+ HNSW::storage_idx_t nodeId = hnsw.neighbors[j];
453
505
  if (nodeId < 0) {
454
506
  break;
455
507
  }
@@ -491,9 +543,41 @@ void search_neighbors_to_add(
491
543
  vt.advance();
492
544
  }
493
545
 
494
- /// Finds neighbors and builds links with them, starting from an entry
495
- /// point. The own neighbor list is assumed to be locked.
496
- void HNSW::add_links_starting_from(
546
+ /// Dispatches the VisitedTable concrete type for a given C, then calls
547
+ /// the templated `search_neighbors_to_add_fixVT<VTType, C>`.
548
+ template <class C>
549
+ void search_neighbors_to_add_dispatch(
550
+ HNSW& hnsw,
551
+ DistanceComputer& qdis,
552
+ std::priority_queue<HNSW::NodeDistCloserT<C>>& results,
553
+ int entry_point,
554
+ float d_entry_point,
555
+ int level,
556
+ VisitedTable& vt,
557
+ bool reference_version) {
558
+ auto call = [&]<typename VTType>(VTType& vt_concrete) {
559
+ search_neighbors_to_add_fixVT<VTType, C>(
560
+ hnsw,
561
+ qdis,
562
+ results,
563
+ entry_point,
564
+ d_entry_point,
565
+ level,
566
+ vt_concrete,
567
+ reference_version);
568
+ };
569
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(&vt)) {
570
+ call(*vtv);
571
+ return;
572
+ }
573
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(vt);
574
+ call(vts);
575
+ }
576
+
577
+ /// Templated implementation of `HNSW::add_links_starting_from`.
578
+ template <class C>
579
+ void add_links_starting_from_impl(
580
+ HNSW& hnsw,
497
581
  DistanceComputer& ptdis,
498
582
  storage_idx_t pt_id,
499
583
  storage_idx_t nearest,
@@ -502,21 +586,22 @@ void HNSW::add_links_starting_from(
502
586
  LockVector& locks,
503
587
  VisitedTable& vt,
504
588
  bool keep_max_size_level0) {
505
- std::priority_queue<NodeDistCloser> link_targets;
589
+ std::priority_queue<HNSW::NodeDistCloserT<C>> link_targets;
506
590
 
507
- search_neighbors_to_add(
508
- *this, ptdis, link_targets, nearest, d_nearest, level, vt);
591
+ search_neighbors_to_add_dispatch<C>(
592
+ hnsw, ptdis, link_targets, nearest, d_nearest, level, vt, false);
509
593
 
510
594
  // but we can afford only this many neighbors
511
- int M = nb_neighbors(level);
595
+ int M = hnsw.nb_neighbors(level);
512
596
 
513
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
597
+ shrink_neighbor_list_inner<C>(ptdis, link_targets, M, keep_max_size_level0);
514
598
 
515
599
  std::vector<storage_idx_t> neighbors_to_add;
516
600
  neighbors_to_add.reserve(link_targets.size());
517
601
  while (!link_targets.empty()) {
518
602
  storage_idx_t other_id = link_targets.top().id;
519
- add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
603
+ add_link_tpl<C>(
604
+ hnsw, ptdis, pt_id, other_id, level, keep_max_size_level0);
520
605
  neighbors_to_add.push_back(other_id);
521
606
  link_targets.pop();
522
607
  }
@@ -524,33 +609,197 @@ void HNSW::add_links_starting_from(
524
609
  locks.unlock(pt_id);
525
610
  for (storage_idx_t other_id : neighbors_to_add) {
526
611
  locks.lock(other_id);
527
- add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
612
+ add_link_tpl<C>(
613
+ hnsw, ptdis, other_id, pt_id, level, keep_max_size_level0);
528
614
  locks.unlock(other_id);
529
615
  }
530
616
  locks.lock(pt_id);
531
617
  }
532
618
 
619
+ } // namespace
620
+
621
+ /// Finds neighbors and builds links with them, starting from an entry
622
+ /// point. The own neighbor list is assumed to be locked.
623
+ void HNSW::add_links_starting_from(
624
+ DistanceComputer& ptdis,
625
+ storage_idx_t pt_id,
626
+ storage_idx_t nearest,
627
+ float d_nearest,
628
+ int level,
629
+ LockVector& locks,
630
+ VisitedTable& vt,
631
+ bool keep_max_size_level0) {
632
+ if (is_similarity) {
633
+ add_links_starting_from_impl<C_similarity>(
634
+ *this,
635
+ ptdis,
636
+ pt_id,
637
+ nearest,
638
+ d_nearest,
639
+ level,
640
+ locks,
641
+ vt,
642
+ keep_max_size_level0);
643
+ } else {
644
+ add_links_starting_from_impl<C_distance>(
645
+ *this,
646
+ ptdis,
647
+ pt_id,
648
+ nearest,
649
+ d_nearest,
650
+ level,
651
+ locks,
652
+ vt,
653
+ keep_max_size_level0);
654
+ }
655
+ }
656
+
657
+ /// search neighbors on a single level, starting from an entry point.
658
+ /// Public dispatcher: always operates in distance (CMax) mode because its
659
+ /// `priority_queue<HNSW::NodeDistCloser>` signature is the back-compat
660
+ /// distance flavor. Internal callers that need similarity mode reach the
661
+ /// templated implementation directly via `search_neighbors_to_add_dispatch`.
662
+ void hnsw_detail::search_neighbors_to_add(
663
+ HNSW& hnsw,
664
+ DistanceComputer& qdis,
665
+ std::priority_queue<HNSW::NodeDistCloser>& results,
666
+ int entry_point,
667
+ float d_entry_point,
668
+ int level,
669
+ VisitedTable& vt,
670
+ bool reference_version) {
671
+ search_neighbors_to_add_dispatch<HNSW::C_distance>(
672
+ hnsw,
673
+ qdis,
674
+ results,
675
+ entry_point,
676
+ d_entry_point,
677
+ level,
678
+ vt,
679
+ reference_version);
680
+ }
681
+
533
682
  /**************************************************************
534
683
  * Building, parallel
535
684
  **************************************************************/
536
685
 
537
- void HNSW::add_with_locks(
686
+ namespace {
687
+
688
+ /// Greedy update of the nearest entry point at a given level.
689
+ template <class C>
690
+ HNSWStats greedy_update_nearest_impl(
691
+ const HNSW& hnsw,
692
+ DistanceComputer& qdis,
693
+ int level,
694
+ storage_idx_t& nearest,
695
+ float& d_nearest) {
696
+ HNSWStats stats;
697
+
698
+ for (;;) {
699
+ storage_idx_t prev_nearest = nearest;
700
+
701
+ size_t begin, end;
702
+ hnsw.neighbor_range(nearest, level, &begin, &end);
703
+
704
+ size_t ndis = 0;
705
+
706
+ // a faster version: reference version in unit test test_hnsw.cpp
707
+ // the following version processes 4 neighbors at a time
708
+ auto update_with_candidate = [&](const storage_idx_t idx,
709
+ const float dis) {
710
+ if (C::cmp(d_nearest, dis)) {
711
+ nearest = idx;
712
+ d_nearest = dis;
713
+ }
714
+ };
715
+
716
+ int n_buffered = 0;
717
+ storage_idx_t buffered_ids[4];
718
+
719
+ for (size_t j = begin; j < end; j++) {
720
+ storage_idx_t v = hnsw.neighbors[j];
721
+ if (v < 0) {
722
+ break;
723
+ }
724
+ ndis += 1;
725
+
726
+ buffered_ids[n_buffered] = v;
727
+ n_buffered += 1;
728
+
729
+ if (n_buffered == 4) {
730
+ float dis[4];
731
+ qdis.distances_batch_4(
732
+ buffered_ids[0],
733
+ buffered_ids[1],
734
+ buffered_ids[2],
735
+ buffered_ids[3],
736
+ dis[0],
737
+ dis[1],
738
+ dis[2],
739
+ dis[3]);
740
+
741
+ for (size_t id4 = 0; id4 < 4; id4++) {
742
+ update_with_candidate(buffered_ids[id4], dis[id4]);
743
+ }
744
+
745
+ n_buffered = 0;
746
+ }
747
+ }
748
+
749
+ // process leftovers
750
+ for (int icnt = 0; icnt < n_buffered; icnt++) {
751
+ float dis = qdis(buffered_ids[icnt]);
752
+ update_with_candidate(buffered_ids[icnt], dis);
753
+ }
754
+
755
+ // update stats
756
+ stats.ndis += ndis;
757
+ stats.nhops += 1;
758
+
759
+ if (nearest == prev_nearest) {
760
+ return stats;
761
+ }
762
+ }
763
+ }
764
+
765
+ } // namespace
766
+
767
+ /// greedily update a nearest vector at a given level
768
+ HNSWStats hnsw_detail::greedy_update_nearest(
769
+ const HNSW& hnsw,
770
+ DistanceComputer& qdis,
771
+ int level,
772
+ storage_idx_t& nearest,
773
+ float& d_nearest) {
774
+ if (hnsw.is_similarity) {
775
+ return greedy_update_nearest_impl<HNSW::C_similarity>(
776
+ hnsw, qdis, level, nearest, d_nearest);
777
+ }
778
+ return greedy_update_nearest_impl<HNSW::C_distance>(
779
+ hnsw, qdis, level, nearest, d_nearest);
780
+ }
781
+
782
+ namespace {
783
+
784
+ template <class C>
785
+ void add_with_locks_impl(
786
+ HNSW& hnsw,
538
787
  DistanceComputer& ptdis,
539
788
  int pt_level,
540
789
  int pt_id,
541
790
  LockVector& locks,
542
791
  VisitedTable& vt,
543
792
  bool keep_max_size_level0) {
544
- storage_idx_t nearest = entry_point;
793
+ storage_idx_t nearest = hnsw.entry_point;
545
794
  if (nearest == -1) { // avoid locking after the first point.
546
795
  #pragma omp critical
547
- if (entry_point == -1) { // double-check under lock.
548
- max_level = pt_level;
549
- entry_point = pt_id;
796
+ if (hnsw.entry_point == -1) { // double-check under lock.
797
+ hnsw.max_level = pt_level;
798
+ hnsw.entry_point = pt_id;
550
799
  // leave nearest = -1 to trigger early exit after critical block.
551
800
  } else {
552
801
  // else: Another thread set the entry point.
553
- nearest = entry_point;
802
+ nearest = hnsw.entry_point;
554
803
  }
555
804
  }
556
805
 
@@ -560,16 +809,17 @@ void HNSW::add_with_locks(
560
809
 
561
810
  locks.lock(pt_id);
562
811
 
563
- int level = max_level; // level at which we start adding neighbors
812
+ int level = hnsw.max_level; // level at which we start adding neighbors
564
813
  float d_nearest = ptdis(nearest);
565
814
 
566
815
  // greedy search on upper levels
567
816
  for (; level > pt_level; level--) {
568
- greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
817
+ greedy_update_nearest_impl<C>(hnsw, ptdis, level, nearest, d_nearest);
569
818
  }
570
819
 
571
820
  for (; level >= 0; level--) {
572
- add_links_starting_from(
821
+ add_links_starting_from_impl<C>(
822
+ hnsw,
573
823
  ptdis,
574
824
  pt_id,
575
825
  nearest,
@@ -584,22 +834,39 @@ void HNSW::add_with_locks(
584
834
 
585
835
  #pragma omp critical
586
836
  {
587
- if (pt_level > max_level) {
588
- max_level = pt_level;
589
- entry_point = pt_id;
837
+ if (pt_level > hnsw.max_level) {
838
+ hnsw.max_level = pt_level;
839
+ hnsw.entry_point = pt_id;
590
840
  }
591
841
  }
592
842
  }
593
843
 
844
+ } // namespace
845
+
846
+ void HNSW::add_with_locks(
847
+ DistanceComputer& ptdis,
848
+ int pt_level,
849
+ int pt_id,
850
+ LockVector& locks,
851
+ VisitedTable& vt,
852
+ bool keep_max_size_level0) {
853
+ if (is_similarity) {
854
+ add_with_locks_impl<C_similarity>(
855
+ *this, ptdis, pt_level, pt_id, locks, vt, keep_max_size_level0);
856
+ } else {
857
+ add_with_locks_impl<C_distance>(
858
+ *this, ptdis, pt_level, pt_id, locks, vt, keep_max_size_level0);
859
+ }
860
+ }
861
+
594
862
  /**************************************************************
595
863
  * Searching
596
864
  **************************************************************/
597
865
 
598
- using Node = HNSW::Node;
599
- using C = HNSW::C;
866
+ namespace {
600
867
 
601
868
  /** Helper to extract search parameters from HNSW and SearchParameters */
602
- static inline void extract_search_params(
869
+ inline void extract_search_params(
603
870
  const HNSW& hnsw,
604
871
  const SearchParameters* params,
605
872
  bool& do_dis_check,
@@ -619,13 +886,16 @@ static inline void extract_search_params(
619
886
  }
620
887
  }
621
888
 
622
- /** Do a BFS on the candidates list */
623
- int search_from_candidates(
889
+ /** Templated body of `search_from_candidates` instantiated once per
890
+ * VisitedTable subclass × comparator.
891
+ */
892
+ template <typename VTType, class C>
893
+ int search_from_candidates_fixVT(
624
894
  const HNSW& hnsw,
625
895
  DistanceComputer& qdis,
626
896
  ResultHandler& res,
627
- MinimaxHeap& candidates,
628
- VisitedTable& vt,
897
+ MinimaxHeapT<HC_for<C>>& candidates,
898
+ VTType& vt,
629
899
  HNSWStats& stats,
630
900
  int level,
631
901
  int nres_in,
@@ -638,13 +908,15 @@ int search_from_candidates(
638
908
  const IDSelector* sel;
639
909
  extract_search_params(hnsw, params, do_dis_check, efSearch, sel);
640
910
 
641
- C::T threshold = res.threshold;
911
+ vt.reserve(efSearch);
912
+
913
+ typename C::T threshold = res.threshold;
642
914
  for (int i = 0; i < candidates.size(); i++) {
643
915
  idx_t v1 = candidates.ids[i];
644
916
  float d = candidates.dis[i];
645
917
  FAISS_ASSERT(v1 >= 0);
646
918
  if (!sel || sel->is_member(v1)) {
647
- if (d < threshold) {
919
+ if (C::cmp(threshold, d)) {
648
920
  if (res.add_result(d, v1)) {
649
921
  threshold = res.threshold;
650
922
  }
@@ -693,7 +965,7 @@ int search_from_candidates(
693
965
 
694
966
  auto add_to_heap = [&](const size_t idx, const float dis) {
695
967
  if (!sel || sel->is_member(idx)) {
696
- if (dis < threshold) {
968
+ if (C::cmp(threshold, dis)) {
697
969
  if (res.add_result(dis, idx)) {
698
970
  threshold = res.threshold;
699
971
  nres += 1;
@@ -756,7 +1028,58 @@ int search_from_candidates(
756
1028
  return nres;
757
1029
  }
758
1030
 
759
- int search_from_candidates_panorama(
1031
+ /// Dispatches the VisitedTable concrete type for a given C, then calls
1032
+ /// the templated `search_from_candidates_fixVT<VTType, C>`.
1033
+ template <class C>
1034
+ int search_from_candidates_dispatch(
1035
+ const HNSW& hnsw,
1036
+ DistanceComputer& qdis,
1037
+ ResultHandler& res,
1038
+ MinimaxHeapT<HC_for<C>>& candidates,
1039
+ VisitedTable& vt,
1040
+ HNSWStats& stats,
1041
+ int level,
1042
+ int nres_in,
1043
+ const SearchParameters* params) {
1044
+ auto call = [&]<typename VTType>(VTType& vt_concrete) -> int {
1045
+ return search_from_candidates_fixVT<VTType, C>(
1046
+ hnsw,
1047
+ qdis,
1048
+ res,
1049
+ candidates,
1050
+ vt_concrete,
1051
+ stats,
1052
+ level,
1053
+ nres_in,
1054
+ params);
1055
+ };
1056
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(&vt)) {
1057
+ return call(*vtv);
1058
+ }
1059
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(vt);
1060
+ return call(vts);
1061
+ }
1062
+
1063
+ } // namespace
1064
+
1065
+ /** Do a BFS on the candidates list. Public dispatcher: only handles the
1066
+ * distance (CMax) flavor because its `MinimaxHeap` parameter is the
1067
+ * CMax instantiation. */
1068
+ int hnsw_detail::search_from_candidates(
1069
+ const HNSW& hnsw,
1070
+ DistanceComputer& qdis,
1071
+ ResultHandler& res,
1072
+ MinimaxHeap& candidates,
1073
+ VisitedTable& vt,
1074
+ HNSWStats& stats,
1075
+ int level,
1076
+ int nres_in,
1077
+ const SearchParameters* params) {
1078
+ return search_from_candidates_dispatch<HNSW::C_distance>(
1079
+ hnsw, qdis, res, candidates, vt, stats, level, nres_in, params);
1080
+ }
1081
+
1082
+ int hnsw_detail::search_from_candidates_panorama(
760
1083
  const HNSW& hnsw,
761
1084
  const IndexHNSW* index,
762
1085
  DistanceComputer& qdis,
@@ -767,6 +1090,14 @@ int search_from_candidates_panorama(
767
1090
  int level,
768
1091
  int nres_in,
769
1092
  const SearchParameters* params) {
1093
+ // Panorama's progressive-bound math is L2-specific: refuse to run in
1094
+ // similarity mode.
1095
+ FAISS_THROW_IF_NOT_MSG(
1096
+ !hnsw.is_similarity,
1097
+ "search_from_candidates_panorama does not support is_similarity=true");
1098
+
1099
+ using C = HNSW::C_distance;
1100
+
770
1101
  int nres = nres_in;
771
1102
  int ndis = 0;
772
1103
 
@@ -781,7 +1112,7 @@ int search_from_candidates_panorama(
781
1112
  float d = candidates.dis[i];
782
1113
  FAISS_ASSERT(v1 >= 0);
783
1114
  if (!sel || sel->is_member(v1)) {
784
- if (d < threshold) {
1115
+ if (C::cmp(threshold, d)) {
785
1116
  if (res.add_result(d, v1)) {
786
1117
  threshold = res.threshold;
787
1118
  }
@@ -917,28 +1248,28 @@ int search_from_candidates_panorama(
917
1248
  // the maintenance of the candidate heap), but micro-benchmarks
918
1249
  // have shown that it is not worth it to write horrible code to
919
1250
  // squeeze out those cycles.
920
- if (lower_bound_0 <= threshold) {
1251
+ if (!C::cmp(lower_bound_0, threshold)) {
921
1252
  exact_distances[next_batch_size] = new_exact_0;
922
1253
  index_array[next_batch_size] = idx_0;
923
1254
  next_batch_size += 1;
924
1255
  } else {
925
1256
  candidates.push(idx_0, new_exact_0);
926
1257
  }
927
- if (lower_bound_1 <= threshold) {
1258
+ if (!C::cmp(lower_bound_1, threshold)) {
928
1259
  exact_distances[next_batch_size] = new_exact_1;
929
1260
  index_array[next_batch_size] = idx_1;
930
1261
  next_batch_size += 1;
931
1262
  } else {
932
1263
  candidates.push(idx_1, new_exact_1);
933
1264
  }
934
- if (lower_bound_2 <= threshold) {
1265
+ if (!C::cmp(lower_bound_2, threshold)) {
935
1266
  exact_distances[next_batch_size] = new_exact_2;
936
1267
  index_array[next_batch_size] = idx_2;
937
1268
  next_batch_size += 1;
938
1269
  } else {
939
1270
  candidates.push(idx_2, new_exact_2);
940
1271
  }
941
- if (lower_bound_3 <= threshold) {
1272
+ if (!C::cmp(lower_bound_3, threshold)) {
942
1273
  exact_distances[next_batch_size] = new_exact_3;
943
1274
  index_array[next_batch_size] = idx_3;
944
1275
  next_batch_size += 1;
@@ -961,7 +1292,7 @@ int search_from_candidates_panorama(
961
1292
  float cs_bound = 2.0f * cum_sum * query_cum_norm;
962
1293
  float lower_bound = new_exact - cs_bound;
963
1294
 
964
- if (lower_bound <= threshold) {
1295
+ if (!C::cmp(lower_bound, threshold)) {
965
1296
  exact_distances[next_batch_size] = new_exact;
966
1297
  index_array[next_batch_size] = idx;
967
1298
  next_batch_size += 1;
@@ -1005,6 +1336,8 @@ int search_from_candidates_panorama(
1005
1336
  return nres;
1006
1337
  }
1007
1338
 
1339
+ namespace {
1340
+
1008
1341
  template <typename T, typename Container, typename Compare>
1009
1342
  void reservePriorityQueue(
1010
1343
  std::priority_queue<T, Container, Compare>& q,
@@ -1017,31 +1350,35 @@ void reservePriorityQueue(
1017
1350
  q = std::move(access);
1018
1351
  }
1019
1352
 
1020
- std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1353
+ /// Templated body of `search_from_candidate_unbounded`. The choice of
1354
+ /// max-heap vs min-heap for both `top_candidates` and `candidates` is
1355
+ /// derived from C via `TopCandidatesQueue` / `CandidatesQueue`.
1356
+ template <typename VTType, class C>
1357
+ TopCandidatesQueue<C> search_from_candidate_unbounded_fixVT(
1021
1358
  const HNSW& hnsw,
1022
- const Node& node,
1359
+ const HNSW::Node& node,
1023
1360
  DistanceComputer& qdis,
1024
1361
  int ef,
1025
- VisitedTable* vt,
1362
+ VTType& vt,
1026
1363
  HNSWStats& stats) {
1027
1364
  int ndis = 0;
1028
- std::priority_queue<Node> top_candidates;
1365
+ TopCandidatesQueue<C> top_candidates;
1029
1366
  reservePriorityQueue(top_candidates, ef);
1030
1367
 
1031
- std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
1368
+ CandidatesQueue<C> candidates;
1032
1369
  reservePriorityQueue(candidates, ef);
1033
1370
 
1034
1371
  top_candidates.push(node);
1035
1372
  candidates.push(node);
1036
1373
 
1037
- vt->set(node.second);
1374
+ vt.set(node.second);
1038
1375
 
1039
1376
  while (!candidates.empty()) {
1040
1377
  float d0;
1041
1378
  storage_idx_t v0;
1042
1379
  std::tie(d0, v0) = candidates.top();
1043
1380
 
1044
- if (d0 > top_candidates.top().first) {
1381
+ if (C::cmp(d0, top_candidates.top().first)) {
1045
1382
  break;
1046
1383
  }
1047
1384
 
@@ -1059,7 +1396,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1059
1396
  break;
1060
1397
  }
1061
1398
 
1062
- vt->prefetch(v1);
1399
+ vt.prefetch(v1);
1063
1400
  jmax += 1;
1064
1401
  }
1065
1402
 
@@ -1067,12 +1404,12 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1067
1404
  size_t saved_j[4];
1068
1405
 
1069
1406
  auto add_to_heap = [&](const size_t idx, const float dis) {
1070
- if (top_candidates.top().first > dis ||
1071
- top_candidates.size() < ef) {
1407
+ if (C::cmp(top_candidates.top().first, dis) ||
1408
+ top_candidates.size() < static_cast<size_t>(ef)) {
1072
1409
  candidates.emplace(dis, idx);
1073
1410
  top_candidates.emplace(dis, idx);
1074
1411
 
1075
- if (top_candidates.size() > ef) {
1412
+ if (top_candidates.size() > static_cast<size_t>(ef)) {
1076
1413
  top_candidates.pop();
1077
1414
  }
1078
1415
  }
@@ -1082,7 +1419,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1082
1419
  int v1 = hnsw.neighbors[j];
1083
1420
 
1084
1421
  saved_j[counter] = v1;
1085
- counter += vt->set(v1) ? 1 : 0;
1422
+ counter += vt.set(v1) ? 1 : 0;
1086
1423
 
1087
1424
  if (counter == 4) {
1088
1425
  float dis[4];
@@ -1125,111 +1462,59 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1125
1462
  return top_candidates;
1126
1463
  }
1127
1464
 
1128
- /// greedily update a nearest vector at a given level
1129
- HNSWStats greedy_update_nearest(
1465
+ } // namespace
1466
+
1467
+ /// Public dispatcher: only the distance (CMax) flavor is exposed because
1468
+ /// its return type — `std::priority_queue<HNSW::Node>` — is the CMax
1469
+ /// max-heap. Internal callers that need similarity mode use the same
1470
+ /// dispatch pattern inline.
1471
+ std::priority_queue<HNSW::Node> hnsw_detail::search_from_candidate_unbounded(
1130
1472
  const HNSW& hnsw,
1473
+ const HNSW::Node& node,
1131
1474
  DistanceComputer& qdis,
1132
- int level,
1133
- storage_idx_t& nearest,
1134
- float& d_nearest) {
1135
- HNSWStats stats;
1136
-
1137
- for (;;) {
1138
- storage_idx_t prev_nearest = nearest;
1139
-
1140
- size_t begin, end;
1141
- hnsw.neighbor_range(nearest, level, &begin, &end);
1142
-
1143
- size_t ndis = 0;
1144
-
1145
- // a faster version: reference version in unit test test_hnsw.cpp
1146
- // the following version processes 4 neighbors at a time
1147
- auto update_with_candidate = [&](const storage_idx_t idx,
1148
- const float dis) {
1149
- if (dis < d_nearest) {
1150
- nearest = idx;
1151
- d_nearest = dis;
1152
- }
1153
- };
1154
-
1155
- int n_buffered = 0;
1156
- storage_idx_t buffered_ids[4];
1157
-
1158
- for (size_t j = begin; j < end; j++) {
1159
- storage_idx_t v = hnsw.neighbors[j];
1160
- if (v < 0) {
1161
- break;
1162
- }
1163
- ndis += 1;
1164
-
1165
- buffered_ids[n_buffered] = v;
1166
- n_buffered += 1;
1167
-
1168
- if (n_buffered == 4) {
1169
- float dis[4];
1170
- qdis.distances_batch_4(
1171
- buffered_ids[0],
1172
- buffered_ids[1],
1173
- buffered_ids[2],
1174
- buffered_ids[3],
1175
- dis[0],
1176
- dis[1],
1177
- dis[2],
1178
- dis[3]);
1179
-
1180
- for (size_t id4 = 0; id4 < 4; id4++) {
1181
- update_with_candidate(buffered_ids[id4], dis[id4]);
1182
- }
1183
-
1184
- n_buffered = 0;
1185
- }
1186
- }
1187
-
1188
- // process leftovers
1189
- for (int icnt = 0; icnt < n_buffered; icnt++) {
1190
- float dis = qdis(buffered_ids[icnt]);
1191
- update_with_candidate(buffered_ids[icnt], dis);
1192
- }
1193
-
1194
- // update stats
1195
- stats.ndis += ndis;
1196
- stats.nhops += 1;
1197
-
1198
- if (nearest == prev_nearest) {
1199
- return stats;
1200
- }
1475
+ int ef,
1476
+ VisitedTable* vt,
1477
+ HNSWStats& stats) {
1478
+ using C = HNSW::C_distance;
1479
+ auto call = [&]<typename VTType>(VTType& vt_concrete) {
1480
+ return search_from_candidate_unbounded_fixVT<VTType, C>(
1481
+ hnsw, node, qdis, ef, vt_concrete, stats);
1482
+ };
1483
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(vt)) {
1484
+ return call(*vtv);
1201
1485
  }
1486
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(*vt);
1487
+ return call(vts);
1202
1488
  }
1203
1489
 
1204
1490
  namespace {
1205
- using Node = HNSW::Node;
1206
- using C = HNSW::C;
1207
1491
 
1208
1492
  // just used as a lower bound for the minmaxheap, but it is set for heap search
1493
+ template <class C>
1209
1494
  int extract_k_from_ResultHandler(ResultHandler& res) {
1210
1495
  using RH = HeapBlockResultHandler<C>;
1211
- if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
1496
+ if (auto hres = dynamic_cast<typename RH::SingleResultHandler*>(&res)) {
1212
1497
  return hres->k;
1213
1498
  }
1214
1499
  return 1;
1215
1500
  }
1216
1501
 
1217
- } // namespace
1218
-
1219
- HNSWStats HNSW::search(
1502
+ template <class C>
1503
+ HNSWStats search_impl(
1504
+ const HNSW& hnsw,
1220
1505
  DistanceComputer& qdis,
1221
1506
  const IndexHNSW* index,
1222
1507
  ResultHandler& res,
1223
1508
  VisitedTable& vt,
1224
- const SearchParameters* params) const {
1509
+ const SearchParameters* params) {
1225
1510
  HNSWStats stats;
1226
- if (entry_point == -1) {
1511
+ if (hnsw.entry_point == -1) {
1227
1512
  return stats;
1228
1513
  }
1229
- int k = extract_k_from_ResultHandler(res);
1514
+ int k = extract_k_from_ResultHandler<C>(res);
1230
1515
 
1231
- bool bounded_queue = this->search_bounded_queue;
1232
- int cur_efSearch = this->efSearch;
1516
+ bool bounded_queue = hnsw.search_bounded_queue;
1517
+ int cur_efSearch = hnsw.efSearch;
1233
1518
  if (params) {
1234
1519
  if (const SearchParametersHNSW* hnsw_params =
1235
1520
  dynamic_cast<const SearchParametersHNSW*>(params)) {
@@ -1239,42 +1524,63 @@ HNSWStats HNSW::search(
1239
1524
  }
1240
1525
 
1241
1526
  // greedy search on upper levels
1242
- storage_idx_t nearest = entry_point;
1527
+ storage_idx_t nearest = hnsw.entry_point;
1243
1528
  float d_nearest = qdis(nearest);
1244
1529
 
1245
- for (int level = max_level; level >= 1; level--) {
1246
- HNSWStats local_stats =
1247
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
1530
+ for (int level = hnsw.max_level; level >= 1; level--) {
1531
+ HNSWStats local_stats = greedy_update_nearest_impl<C>(
1532
+ hnsw, qdis, level, nearest, d_nearest);
1248
1533
  stats.combine(local_stats);
1249
1534
  }
1250
1535
 
1251
1536
  int ef = std::max(cur_efSearch, k);
1252
1537
  if (bounded_queue) { // this is the most common branch, for now we only
1253
1538
  // support Panorama search in this branch
1254
- MinimaxHeap candidates(ef);
1539
+ MinimaxHeapT<HC_for<C>> candidates(ef);
1255
1540
 
1256
1541
  candidates.push(nearest, d_nearest);
1257
1542
 
1258
- if (!is_panorama) {
1259
- search_from_candidates(
1260
- *this, qdis, res, candidates, vt, stats, 0, 0, params);
1543
+ if (!hnsw.is_panorama) {
1544
+ search_from_candidates_dispatch<C>(
1545
+ hnsw, qdis, res, candidates, vt, stats, 0, 0, params);
1261
1546
  } else {
1262
- search_from_candidates_panorama(
1263
- *this,
1264
- index,
1265
- qdis,
1266
- res,
1267
- candidates,
1268
- vt,
1269
- stats,
1270
- 0,
1271
- 0,
1272
- params);
1547
+ // Panorama is L2-specific and is only valid for C_distance.
1548
+ // The public dispatch ensures we never reach this code path
1549
+ // with C != C_distance, but assert in debug builds.
1550
+ if constexpr (std::is_same_v<C, HNSW::C_distance>) {
1551
+ hnsw_detail::search_from_candidates_panorama(
1552
+ hnsw,
1553
+ index,
1554
+ qdis,
1555
+ res,
1556
+ candidates,
1557
+ vt,
1558
+ stats,
1559
+ 0,
1560
+ 0,
1561
+ params);
1562
+ } else {
1563
+ FAISS_THROW_MSG(
1564
+ "Panorama search is not supported with is_similarity=true");
1565
+ }
1273
1566
  }
1274
1567
  } else {
1275
- std::priority_queue<Node> top_candidates =
1276
- search_from_candidate_unbounded(
1277
- *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
1568
+ auto call = [&]<typename VTType>(VTType& vt_concrete) {
1569
+ return search_from_candidate_unbounded_fixVT<VTType, C>(
1570
+ hnsw,
1571
+ HNSW::Node(d_nearest, nearest),
1572
+ qdis,
1573
+ ef,
1574
+ vt_concrete,
1575
+ stats);
1576
+ };
1577
+ TopCandidatesQueue<C> top_candidates;
1578
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(&vt)) {
1579
+ top_candidates = call(*vtv);
1580
+ } else {
1581
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(vt);
1582
+ top_candidates = call(vts);
1583
+ }
1278
1584
 
1279
1585
  while (top_candidates.size() > static_cast<size_t>(k)) {
1280
1586
  top_candidates.pop();
@@ -1294,7 +1600,9 @@ HNSWStats HNSW::search(
1294
1600
  return stats;
1295
1601
  }
1296
1602
 
1297
- void HNSW::search_level_0(
1603
+ template <class C>
1604
+ void search_level_0_impl(
1605
+ const HNSW& hnsw,
1298
1606
  DistanceComputer& qdis,
1299
1607
  ResultHandler& res,
1300
1608
  idx_t nprobe,
@@ -1303,9 +1611,7 @@ void HNSW::search_level_0(
1303
1611
  int search_type,
1304
1612
  HNSWStats& search_stats,
1305
1613
  VisitedTable& vt,
1306
- const SearchParameters* params) const {
1307
- const HNSW& hnsw = *this;
1308
-
1614
+ const SearchParameters* params) {
1309
1615
  auto cur_efSearch = hnsw.efSearch;
1310
1616
  if (params) {
1311
1617
  if (const SearchParametersHNSW* hnsw_params =
@@ -1314,7 +1620,7 @@ void HNSW::search_level_0(
1314
1620
  }
1315
1621
  }
1316
1622
 
1317
- int k = extract_k_from_ResultHandler(res);
1623
+ int k = extract_k_from_ResultHandler<C>(res);
1318
1624
 
1319
1625
  if (search_type == 1) {
1320
1626
  int nres = 0;
@@ -1331,11 +1637,11 @@ void HNSW::search_level_0(
1331
1637
  }
1332
1638
 
1333
1639
  int candidates_size = std::max(cur_efSearch, k);
1334
- MinimaxHeap candidates(candidates_size);
1640
+ MinimaxHeapT<HC_for<C>> candidates(candidates_size);
1335
1641
 
1336
1642
  candidates.push(cj, nearest_d[j]);
1337
1643
 
1338
- nres = search_from_candidates(
1644
+ nres = search_from_candidates_dispatch<C>(
1339
1645
  hnsw,
1340
1646
  qdis,
1341
1647
  res,
@@ -1351,7 +1657,7 @@ void HNSW::search_level_0(
1351
1657
  int candidates_size = std::max(cur_efSearch, int(k));
1352
1658
  candidates_size = std::max(candidates_size, int(nprobe));
1353
1659
 
1354
- MinimaxHeap candidates(candidates_size);
1660
+ MinimaxHeapT<HC_for<C>> candidates(candidates_size);
1355
1661
  for (idx_t j = 0; j < nprobe; j++) {
1356
1662
  storage_idx_t cj = nearest_i[j];
1357
1663
 
@@ -1361,11 +1667,62 @@ void HNSW::search_level_0(
1361
1667
  candidates.push(cj, nearest_d[j]);
1362
1668
  }
1363
1669
 
1364
- search_from_candidates(
1670
+ search_from_candidates_dispatch<C>(
1365
1671
  hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
1366
1672
  }
1367
1673
  }
1368
1674
 
1675
+ } // namespace
1676
+
1677
+ HNSWStats HNSW::search(
1678
+ DistanceComputer& qdis,
1679
+ const IndexHNSW* index,
1680
+ ResultHandler& res,
1681
+ VisitedTable& vt,
1682
+ const SearchParameters* params) const {
1683
+ if (is_similarity) {
1684
+ return search_impl<C_similarity>(*this, qdis, index, res, vt, params);
1685
+ }
1686
+ return search_impl<C_distance>(*this, qdis, index, res, vt, params);
1687
+ }
1688
+
1689
+ void HNSW::search_level_0(
1690
+ DistanceComputer& qdis,
1691
+ ResultHandler& res,
1692
+ idx_t nprobe,
1693
+ const storage_idx_t* nearest_i,
1694
+ const float* nearest_d,
1695
+ int search_type,
1696
+ HNSWStats& search_stats,
1697
+ VisitedTable& vt,
1698
+ const SearchParameters* params) const {
1699
+ if (is_similarity) {
1700
+ search_level_0_impl<C_similarity>(
1701
+ *this,
1702
+ qdis,
1703
+ res,
1704
+ nprobe,
1705
+ nearest_i,
1706
+ nearest_d,
1707
+ search_type,
1708
+ search_stats,
1709
+ vt,
1710
+ params);
1711
+ } else {
1712
+ search_level_0_impl<C_distance>(
1713
+ *this,
1714
+ qdis,
1715
+ res,
1716
+ nprobe,
1717
+ nearest_i,
1718
+ nearest_d,
1719
+ search_type,
1720
+ search_stats,
1721
+ vt,
1722
+ params);
1723
+ }
1724
+ }
1725
+
1369
1726
  void HNSW::permute_entries(const idx_t* map) {
1370
1727
  // remap levels
1371
1728
  storage_idx_t ntotal = levels.size();