faiss 0.5.2 → 0.6.0

This diff is generated from the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -27,6 +27,7 @@
27
27
  #include <faiss/impl/AuxIndexStructures.h>
28
28
  #include <faiss/impl/FaissAssert.h>
29
29
  #include <faiss/impl/ResultHandler.h>
30
+ #include <faiss/impl/VisitedTable.h>
30
31
  #include <faiss/utils/random.h>
31
32
  #include <faiss/utils/sorting.h>
32
33
 
@@ -142,7 +143,7 @@ void hnsw_add_vertices(
142
143
 
143
144
  #pragma omp parallel if (i1 > i0 + 100)
144
145
  {
145
- VisitedTable vt(ntotal);
146
+ VisitedTable vt(ntotal, hnsw.use_visited_hashset);
146
147
 
147
148
  std::unique_ptr<DistanceComputer> dis(
148
149
  storage_distance_computer(index_hnsw.storage));
@@ -265,7 +266,7 @@ void hnsw_search(
265
266
 
266
267
  #pragma omp parallel if (i1 - i0 > 1)
267
268
  {
268
- VisitedTable vt(index->ntotal);
269
+ VisitedTable vt(index->ntotal, hnsw.use_visited_hashset);
269
270
  typename BlockResultHandler::SingleResultHandler res(bres);
270
271
 
271
272
  std::unique_ptr<DistanceComputer> dis(
@@ -333,6 +334,14 @@ void IndexHNSW::range_search(
333
334
  }
334
335
  }
335
336
 
337
+ void IndexHNSW::search1(
338
+ const float* x,
339
+ ResultHandler& handler,
340
+ SearchParameters* params) const {
341
+ SingleQueryBlockResultHandler<HNSW::C, false> bres(handler);
342
+ hnsw_search(this, 1, x, bres, params);
343
+ }
344
+
336
345
  void IndexHNSW::add(idx_t n, const float* x) {
337
346
  FAISS_THROW_IF_NOT_MSG(
338
347
  storage,
@@ -428,7 +437,7 @@ void IndexHNSW::search_level_0(
428
437
  std::unique_ptr<DistanceComputer> qdis(
429
438
  storage_distance_computer(storage));
430
439
  HNSWStats search_stats;
431
- VisitedTable vt(ntotal);
440
+ VisitedTable vt(ntotal, hnsw.use_visited_hashset);
432
441
  RH::SingleResultHandler res(bres);
433
442
 
434
443
  #pragma omp for
@@ -516,7 +525,7 @@ void IndexHNSW::init_level_0_from_entry_points(
516
525
 
517
526
  #pragma omp parallel
518
527
  {
519
- VisitedTable vt(ntotal);
528
+ VisitedTable vt(ntotal, hnsw.use_visited_hashset);
520
529
 
521
530
  std::unique_ptr<DistanceComputer> dis(
522
531
  storage_distance_computer(storage));
@@ -616,7 +625,7 @@ void IndexHNSW::link_singletons() {
616
625
 
617
626
  std::vector<float> recons(singletons.size() * d);
618
627
  for (int i = 0; i < singletons.size(); i++) {
619
- FAISS_ASSERT(!"not implemented");
628
+ FAISS_ASSERT(false); // not implemented
620
629
  }
621
630
  }
622
631
 
@@ -653,33 +662,8 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
653
662
  * IndexHNSWFlatPanorama implementation
654
663
  **************************************************************/
655
664
 
656
- void IndexHNSWFlatPanorama::compute_cum_sums(
657
- const float* x,
658
- float* dst_cum_sums,
659
- int d,
660
- int num_panorama_levels,
661
- int panorama_level_width) {
662
- // Iterate backwards through levels, accumulating sum as we go.
663
- // This avoids computing the suffix sum for each vector, which takes
664
- // extra memory.
665
-
666
- float sum = 0.0f;
667
- dst_cum_sums[num_panorama_levels] = 0.0f;
668
- for (int level = num_panorama_levels - 1; level >= 0; level--) {
669
- int start_idx = level * panorama_level_width;
670
- int end_idx = std::min(start_idx + panorama_level_width, d);
671
- for (int j = start_idx; j < end_idx; j++) {
672
- sum += x[j] * x[j];
673
- }
674
- dst_cum_sums[level] = std::sqrt(sum);
675
- }
676
- }
677
-
678
665
  IndexHNSWFlatPanorama::IndexHNSWFlatPanorama()
679
- : IndexHNSWFlat(),
680
- cum_sums(),
681
- panorama_level_width(0),
682
- num_panorama_levels(0) {}
666
+ : IndexHNSWFlat(), cum_sums(), pano(0, 1, 1), num_panorama_levels(0) {}
683
667
 
684
668
  IndexHNSWFlatPanorama::IndexHNSWFlatPanorama(
685
669
  int d,
@@ -688,8 +672,7 @@ IndexHNSWFlatPanorama::IndexHNSWFlatPanorama(
688
672
  MetricType metric)
689
673
  : IndexHNSWFlat(d, M, metric),
690
674
  cum_sums(),
691
- panorama_level_width(
692
- (d + num_panorama_levels - 1) / num_panorama_levels),
675
+ pano(d * sizeof(float), num_panorama_levels, 1),
693
676
  num_panorama_levels(num_panorama_levels) {
694
677
  // For now, we only support L2 distance.
695
678
  // Supporting dot product and cosine distance is a trivial addition
@@ -704,18 +687,8 @@ IndexHNSWFlatPanorama::IndexHNSWFlatPanorama(
704
687
 
705
688
  void IndexHNSWFlatPanorama::add(idx_t n, const float* x) {
706
689
  idx_t n0 = ntotal;
707
- cum_sums.resize((ntotal + n) * (num_panorama_levels + 1));
708
-
709
- for (size_t idx = 0; idx < n; idx++) {
710
- const float* vector = x + idx * d;
711
- compute_cum_sums(
712
- vector,
713
- &cum_sums[(n0 + idx) * (num_panorama_levels + 1)],
714
- d,
715
- num_panorama_levels,
716
- panorama_level_width);
717
- }
718
-
690
+ cum_sums.resize((ntotal + n) * (pano.n_levels + 1));
691
+ pano.compute_cumulative_sums(cum_sums.data(), n0, n, x);
719
692
  IndexHNSWFlat::add(n, x);
720
693
  }
721
694
 
@@ -725,13 +698,13 @@ void IndexHNSWFlatPanorama::reset() {
725
698
  }
726
699
 
727
700
  void IndexHNSWFlatPanorama::permute_entries(const idx_t* perm) {
728
- std::vector<float> new_cum_sums(ntotal * (num_panorama_levels + 1));
701
+ std::vector<float> new_cum_sums(ntotal * (pano.n_levels + 1));
729
702
 
730
703
  for (idx_t i = 0; i < ntotal; i++) {
731
704
  idx_t src = perm[i];
732
- memcpy(&new_cum_sums[i * (num_panorama_levels + 1)],
733
- &cum_sums[src * (num_panorama_levels + 1)],
734
- (num_panorama_levels + 1) * sizeof(float));
705
+ memcpy(&new_cum_sums[i * (pano.n_levels + 1)],
706
+ &cum_sums[src * (pano.n_levels + 1)],
707
+ (pano.n_levels + 1) * sizeof(float));
735
708
  }
736
709
 
737
710
  std::swap(cum_sums, new_cum_sums);
@@ -903,7 +876,8 @@ void IndexHNSW2Level::search(
903
876
 
904
877
  #pragma omp parallel
905
878
  {
906
- VisitedTable vt(ntotal);
879
+ // visited table (not hash set) for tri-state flags.
880
+ VisitedTable vt(ntotal, /*use_hashset=*/false);
907
881
  std::unique_ptr<DistanceComputer> dis(
908
882
  storage_distance_computer(storage));
909
883
 
@@ -9,13 +9,15 @@
9
9
 
10
10
  #pragma once
11
11
 
12
+ #include <optional>
12
13
  #include <vector>
13
- #include "faiss/Index.h"
14
14
 
15
+ #include <faiss/Index.h>
15
16
  #include <faiss/IndexFlat.h>
16
17
  #include <faiss/IndexPQ.h>
17
18
  #include <faiss/IndexScalarQuantizer.h>
18
19
  #include <faiss/impl/HNSW.h>
20
+ #include <faiss/impl/Panorama.h>
19
21
  #include <faiss/utils/utils.h>
20
22
 
21
23
  namespace faiss {
@@ -47,6 +49,9 @@ struct IndexHNSW : Index {
47
49
  // used when GpuIndexCagra::copyFrom(IndexHNSWCagra*) is invoked.
48
50
  bool keep_max_size_level0 = false;
49
51
 
52
+ // See impl/VisitedTable.h.
53
+ std::optional<bool> use_visited_hashset;
54
+
50
55
  explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
51
56
  explicit IndexHNSW(Index* storage, int M = 32);
52
57
 
@@ -73,6 +78,12 @@ struct IndexHNSW : Index {
73
78
  RangeSearchResult* result,
74
79
  const SearchParameters* params = nullptr) const override;
75
80
 
81
+ /** search one vector with a custom result handler */
82
+ void search1(
83
+ const float* x,
84
+ ResultHandler& handler,
85
+ SearchParameters* params = nullptr) const override;
86
+
76
87
  void reconstruct(idx_t key, float* recons) const override;
77
88
 
78
89
  void reset() override;
@@ -164,20 +175,11 @@ struct IndexHNSWFlatPanorama : IndexHNSWFlat {
164
175
 
165
176
  /// Inline for performance - called frequently in search hot path.
166
177
  const float* get_cum_sum(idx_t i) const {
167
- return cum_sums.data() + i * (num_panorama_levels + 1);
178
+ return cum_sums.data() + i * (pano.n_levels + 1);
168
179
  }
169
180
 
170
- /// Compute cumulative sums for a vector (used both for database points and
171
- /// queries).
172
- static void compute_cum_sums(
173
- const float* x,
174
- float* dst_cum_sums,
175
- int d,
176
- int num_panorama_levels,
177
- int panorama_level_width);
178
-
179
181
  std::vector<float> cum_sums;
180
- const size_t panorama_level_width;
182
+ Panorama pano;
181
183
  const size_t num_panorama_levels;
182
184
  };
183
185
 
@@ -12,7 +12,7 @@
12
12
  #include <cinttypes>
13
13
  #include <cstdint>
14
14
  #include <cstdio>
15
- #include "faiss/Index.h"
15
+ #include <stdexcept>
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexIVF.h>
11
9
 
12
10
  #include <omp.h>
@@ -27,6 +25,8 @@
27
25
  #include <faiss/impl/CodePacker.h>
28
26
  #include <faiss/impl/FaissAssert.h>
29
27
  #include <faiss/impl/IDSelector.h>
28
+ #include <faiss/impl/ResultHandler.h>
29
+ #include <faiss/impl/expanded_scanners.h>
30
30
 
31
31
  namespace faiss {
32
32
 
@@ -920,6 +920,52 @@ void IndexIVF::range_search_preassigned(
920
920
  stats->ndis += ndis;
921
921
  }
922
922
 
923
+ void IndexIVF::search1(
924
+ const float* x,
925
+ ResultHandler& handler,
926
+ SearchParameters* params_in) const {
927
+ const IVFSearchParameters* params = nullptr;
928
+ const SearchParameters* quantizer_params = nullptr;
929
+ if (params_in) {
930
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
931
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
932
+ quantizer_params = params->quantizer_params;
933
+ }
934
+ const size_t nprobe =
935
+ std::min(nlist, params ? params->nprobe : this->nprobe);
936
+ size_t nx = 1;
937
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
938
+ std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
939
+
940
+ double t0 = getmillisecs();
941
+ quantizer->search(
942
+ nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
943
+ indexIVF_stats.quantization_time += getmillisecs() - t0;
944
+
945
+ t0 = getmillisecs();
946
+ invlists->prefetch_lists(keys.get(), nx * nprobe);
947
+
948
+ std::unique_ptr<InvertedListScanner> scanner(
949
+ get_InvertedListScanner(false, nullptr, params));
950
+ scanner->set_query(x);
951
+
952
+ for (idx_t i = 0; i < nprobe; i++) {
953
+ idx_t key = keys[i];
954
+ if (key < 0 || invlists->is_empty(key)) {
955
+ continue;
956
+ }
957
+
958
+ scanner->set_list(key, coarse_dis[i]);
959
+ InvertedLists::ScopedCodes scodes(invlists, key);
960
+ InvertedLists::ScopedIds ids(invlists, key);
961
+ size_t list_size = invlists->list_size(key);
962
+
963
+ scanner->scan_codes(list_size, scodes.get(), ids.get(), handler);
964
+ }
965
+
966
+ indexIVF_stats.search_time += getmillisecs() - t0;
967
+ }
968
+
923
969
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
924
970
  bool /*store_pairs*/,
925
971
  const IDSelector* /* sel */,
@@ -1298,6 +1344,20 @@ IndexIVFStats indexIVF_stats;
1298
1344
  * InvertedListScanner
1299
1345
  *************************************************************************/
1300
1346
 
1347
+ // this gets expanded in expanded_scanners
1348
+
1349
+ size_t InvertedListScanner::scan_codes(
1350
+ size_t list_size,
1351
+ const uint8_t* codes,
1352
+ const idx_t* ids,
1353
+ ResultHandler& handler) const {
1354
+ return run_scan_codes(*this, list_size, codes, ids, handler);
1355
+ }
1356
+
1357
+ void InvertedListScanner::set_list(idx_t list_no_in, float /* coarse_dis */) {
1358
+ this->list_no = list_no_in;
1359
+ }
1360
+
1301
1361
  size_t InvertedListScanner::scan_codes(
1302
1362
  size_t list_size,
1303
1363
  const uint8_t* codes,
@@ -1305,46 +1365,15 @@ size_t InvertedListScanner::scan_codes(
1305
1365
  float* simi,
1306
1366
  idx_t* idxi,
1307
1367
  size_t k) const {
1308
- size_t nup = 0;
1309
-
1310
1368
  if (!keep_max) {
1311
- for (size_t j = 0; j < list_size; j++) {
1312
- if (sel != nullptr) {
1313
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1314
- if (!sel->is_member(id)) {
1315
- codes += code_size;
1316
- continue;
1317
- }
1318
- }
1319
-
1320
- float dis = distance_to_code(codes);
1321
- if (dis < simi[0]) {
1322
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1323
- maxheap_replace_top(k, simi, idxi, dis, id);
1324
- nup++;
1325
- }
1326
- codes += code_size;
1327
- }
1369
+ using C = CMax<float, idx_t>;
1370
+ HeapResultHandler<C, false> handler(k, simi, idxi);
1371
+ return scan_codes(list_size, codes, ids, handler);
1328
1372
  } else {
1329
- for (size_t j = 0; j < list_size; j++) {
1330
- if (sel != nullptr) {
1331
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1332
- if (!sel->is_member(id)) {
1333
- codes += code_size;
1334
- continue;
1335
- }
1336
- }
1337
-
1338
- float dis = distance_to_code(codes);
1339
- if (dis > simi[0]) {
1340
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1341
- minheap_replace_top(k, simi, idxi, dis, id);
1342
- nup++;
1343
- }
1344
- codes += code_size;
1345
- }
1373
+ using C = CMin<float, idx_t>;
1374
+ HeapResultHandler<C, false> handler(k, simi, idxi);
1375
+ return scan_codes(list_size, codes, ids, handler);
1346
1376
  }
1347
- return nup;
1348
1377
  }
1349
1378
 
1350
1379
  size_t InvertedListScanner::iterate_codes(
@@ -1386,16 +1415,14 @@ void InvertedListScanner::scan_codes_range(
1386
1415
  const idx_t* ids,
1387
1416
  float radius,
1388
1417
  RangeQueryResult& res) const {
1389
- for (size_t j = 0; j < list_size; j++) {
1390
- float dis = distance_to_code(codes);
1391
- bool keep = !keep_max
1392
- ? dis < radius
1393
- : dis > radius; // TODO templatize to remove this test
1394
- if (keep) {
1395
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1396
- res.add(dis, id);
1397
- }
1398
- codes += code_size;
1418
+ if (!keep_max) {
1419
+ using C = CMax<float, idx_t>;
1420
+ RangeResultHandler<C, false> handler(&res, radius);
1421
+ scan_codes(list_size, codes, ids, handler);
1422
+ } else {
1423
+ using C = CMin<float, idx_t>;
1424
+ RangeResultHandler<C, false> handler(&res, radius);
1425
+ scan_codes(list_size, codes, ids, handler);
1399
1426
  }
1400
1427
  }
1401
1428
 
@@ -11,9 +11,6 @@
11
11
  #define FAISS_INDEX_IVF_H
12
12
 
13
13
  #include <stdint.h>
14
- #include <memory>
15
- #include <unordered_map>
16
- #include <vector>
17
14
 
18
15
  #include <faiss/Clustering.h>
19
16
  #include <faiss/Index.h>
@@ -325,6 +322,12 @@ struct IndexIVF : Index, IndexIVFInterface {
325
322
  RangeSearchResult* result,
326
323
  const SearchParameters* params = nullptr) const override;
327
324
 
325
+ /** search one vector with a custom result handler */
326
+ void search1(
327
+ const float* x,
328
+ ResultHandler& handler,
329
+ SearchParameters* params = nullptr) const override;
330
+
328
331
  /** Get a scanner for this index (store_pairs means ignore labels)
329
332
  *
330
333
  * The default search implementation uses this to compute the distances.
@@ -492,7 +495,7 @@ struct InvertedListScanner {
492
495
  virtual void set_query(const float* query_vector) = 0;
493
496
 
494
497
  /// following codes come from this inverted list
495
- virtual void set_list(idx_t list_no, float coarse_dis) = 0;
498
+ virtual void set_list(idx_t list_no, float coarse_dis);
496
499
 
497
500
  /// compute a single query-to-code distance
498
501
  virtual float distance_to_code(const uint8_t* code) const = 0;
@@ -543,6 +546,13 @@ struct InvertedListScanner {
543
546
  RangeQueryResult& result,
544
547
  size_t& list_size) const;
545
548
 
549
+ // accumulate results with a ResultHandler
550
+ virtual size_t scan_codes(
551
+ size_t n,
552
+ const uint8_t* codes,
553
+ const idx_t* ids,
554
+ ResultHandler& handler) const;
555
+
546
556
  virtual ~InvertedListScanner() {}
547
557
  };
548
558
 
@@ -17,6 +17,7 @@
17
17
  #include <faiss/impl/FastScanDistancePostProcessing.h>
18
18
  #include <faiss/impl/LookupTableScaler.h>
19
19
  #include <faiss/impl/pq4_fast_scan.h>
20
+ #include <faiss/impl/simd_dispatch.h>
20
21
  #include <faiss/invlists/BlockInvertedLists.h>
21
22
  #include <faiss/utils/distances.h>
22
23
  #include <faiss/utils/quantize_lut.h>
@@ -405,18 +406,20 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
405
406
  // bias = coef * <q, c>
406
407
  // NOTE: q^2 is not added to `biases`
407
408
  biases.resize(n * nprobe);
409
+ with_simd_level([&]<SIMDLevel SL>() {
408
410
  #pragma omp parallel
409
- {
410
- std::vector<float> centroid(d);
411
- float* c = centroid.data();
411
+ {
412
+ std::vector<float> centroid(d);
413
+ float* c = centroid.data();
412
414
 
413
415
  #pragma omp for
414
- for (idx_t ij = 0; ij < n * nprobe; ij++) {
415
- int i = ij / nprobe;
416
- quantizer->reconstruct(cq.ids[ij], c);
417
- biases[ij] = coef * fvec_inner_product(c, x + i * d, d);
416
+ for (idx_t ij = 0; ij < n * nprobe; ij++) {
417
+ int i = ij / nprobe;
418
+ quantizer->reconstruct(cq.ids[ij], c);
419
+ biases[ij] = coef * fvec_inner_product<SL>(c, x + i * d, d);
420
+ }
418
421
  }
419
- }
422
+ });
420
423
  }
421
424
 
422
425
  if (metric_type == METRIC_L2) {
@@ -37,13 +37,13 @@ namespace faiss {
37
37
  struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
38
38
  using Search_type_t = AdditiveQuantizer::Search_type_t;
39
39
 
40
- AdditiveQuantizer* aq;
40
+ AdditiveQuantizer* aq{};
41
41
 
42
42
  bool rescale_norm = false;
43
43
  int norm_scale = 1;
44
44
 
45
45
  // max number of training vectors
46
- size_t max_train_points;
46
+ size_t max_train_points{};
47
47
 
48
48
  IndexIVFAdditiveQuantizerFastScan(
49
49
  Index* quantizer,
@@ -95,18 +95,19 @@ IndexIVFFastScan::~IndexIVFFastScan() = default;
95
95
  * Code management functions
96
96
  *********************************************************/
97
97
 
98
- void IndexIVFFastScan::preprocess_code_metadata(
99
- idx_t /* n */,
100
- const uint8_t* /* flat_codes */,
101
- idx_t /* start_global_idx */) {
102
- // Default: no-op
103
- }
104
-
105
98
  size_t IndexIVFFastScan::code_packing_stride() const {
106
99
  // Default: use standard M-byte stride
107
100
  return 0;
108
101
  }
109
102
 
103
+ size_t IndexIVFFastScan::get_block_stride() const {
104
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
105
+ FAISS_THROW_IF_NOT_MSG(
106
+ packer->nvec == static_cast<size_t>(bbs),
107
+ "CodePacker must pack bbs vectors per block for fast-scan");
108
+ return packer->block_size;
109
+ }
110
+
110
111
  void IndexIVFFastScan::add_with_ids(
111
112
  idx_t n,
112
113
  const float* x,
@@ -148,9 +149,6 @@ void IndexIVFFastScan::add_with_ids(
148
149
  AlignedTable<uint8_t> flat_codes(n * code_size);
149
150
  encode_vectors(n, x, idx.get(), flat_codes.get());
150
151
 
151
- // Allow subclasses to preprocess metadata before packing
152
- preprocess_code_metadata(n, flat_codes.get(), ntotal);
153
-
154
152
  DirectMapAdd dm_adder(direct_map, n, xids);
155
153
  BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
156
154
  FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
@@ -206,7 +204,11 @@ void IndexIVFFastScan::add_with_ids(
206
204
  bbs,
207
205
  M2,
208
206
  bil->codes[list_no].data(),
209
- pack_stride);
207
+ pack_stride,
208
+ get_block_stride());
209
+
210
+ postprocess_packed_codes(
211
+ list_no, list_size, i1 - i0, list_codes.data());
210
212
 
211
213
  i0 = i1;
212
214
  }
@@ -1029,7 +1031,8 @@ void IndexIVFFastScan::search_implem_10(
1029
1031
  codes.get(),
1030
1032
  LUT,
1031
1033
  handler,
1032
- context.norm_scaler);
1034
+ context.norm_scaler,
1035
+ get_block_stride());
1033
1036
 
1034
1037
  ndis += ls;
1035
1038
  nlist_visited++;
@@ -1180,7 +1183,8 @@ void IndexIVFFastScan::search_implem_12(
1180
1183
  codes.get(),
1181
1184
  LUT.get(),
1182
1185
  handler,
1183
- context.norm_scaler);
1186
+ context.norm_scaler,
1187
+ get_block_stride());
1184
1188
  // prepare for next loop
1185
1189
  i0 = i1;
1186
1190
  }
@@ -1403,7 +1407,8 @@ void IndexIVFFastScan::search_implem_14(
1403
1407
  codes.get(),
1404
1408
  LUT.get(),
1405
1409
  *handler.get(),
1406
- context.norm_scaler);
1410
+ context.norm_scaler,
1411
+ get_block_stride());
1407
1412
  }
1408
1413
 
1409
1414
  // labels is in-place for HeapHC
@@ -1519,6 +1524,12 @@ void IndexIVFFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
1519
1524
  }
1520
1525
  }
1521
1526
 
1527
+ void IndexIVFFastScan::postprocess_packed_codes(
1528
+ idx_t /*list_no*/,
1529
+ size_t /*list_offset*/,
1530
+ size_t /*n_added*/,
1531
+ const uint8_t* /*flat_codes*/) {}
1532
+
1522
1533
  IVFFastScanStats IVFFastScan_stats;
1523
1534
 
1524
1535
  } // namespace faiss
@@ -360,28 +360,6 @@ struct IndexIVFFastScan : IndexIVF {
360
360
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
361
361
 
362
362
  protected:
363
- /** Preprocess metadata from encoded vectors before packing.
364
- *
365
- * Called during add_with_ids after encode_vectors but before codes
366
- * are packed into SIMD-friendly blocks. Subclasses can override to
367
- * extract and store metadata embedded in codes or perform other
368
- * pre-packing operations.
369
- *
370
- * Default implementation: no-op
371
- *
372
- * Example use case:
373
- * - IndexIVFRaBitQFastScan extracts factor data from codes for use
374
- * during search-time distance corrections
375
- *
376
- * @param n number of vectors encoded
377
- * @param flat_codes encoded vectors (n * code_size bytes)
378
- * @param start_global_idx starting global index (ntotal before add)
379
- */
380
- virtual void preprocess_code_metadata(
381
- idx_t n,
382
- const uint8_t* flat_codes,
383
- idx_t start_global_idx);
384
-
385
363
  /** Get stride for interpreting codes during SIMD packing.
386
364
  *
387
365
  * The stride determines how to read codes when packing them into
@@ -399,6 +377,32 @@ struct IndexIVFFastScan : IndexIVF {
399
377
  * - >0: use custom stride (e.g., code_size for embedded metadata)
400
378
  */
401
379
  virtual size_t code_packing_stride() const;
380
+
381
+ public:
382
+ /** Get stride in bytes between consecutive SIMD blocks.
383
+ *
384
+ * Derived from get_CodePacker()->block_size so that there is a
385
+ * single source of truth for the block layout.
386
+ *
387
+ * @return stride in bytes
388
+ */
389
+ size_t get_block_stride() const;
390
+
391
+ /** Post-process packed codes after pq4_pack_codes_range.
392
+ *
393
+ * Called during add_with_ids after codes have been packed into
394
+ * SIMD-friendly blocks.
395
+ *
396
+ * @param list_no inverted list number
397
+ * @param list_offset starting offset within the list (pre-existing size)
398
+ * @param n_added number of vectors added in this batch
399
+ * @param flat_codes encoded vectors for this batch (n_added * code_size)
400
+ */
401
+ virtual void postprocess_packed_codes(
402
+ idx_t list_no,
403
+ size_t list_offset,
404
+ size_t n_added,
405
+ const uint8_t* flat_codes);
402
406
  };
403
407
 
404
408
  struct IVFFastScanStats {