faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
@@ -49,10 +49,10 @@ struct IndexBinary {
49
49
  /** Perform training on a representative set of vectors.
50
50
  *
51
51
  * @param n nb of training vectors
52
- * @param x training vecors, size n * d / 8
52
+ * @param x training vectors, size n * d / 8
53
53
  */
54
54
  virtual void train(idx_t n, const uint8_t* x);
55
- virtual void trainEx(idx_t n, const void* x, NumericType numeric_type) {
55
+ virtual void train_ex(idx_t n, const void* x, NumericType numeric_type) {
56
56
  if (numeric_type == NumericType::UInt8) {
57
57
  train(n, static_cast<const uint8_t*>(x));
58
58
  } else {
@@ -66,7 +66,7 @@ struct IndexBinary {
66
66
  * @param x input matrix, size n * d / 8
67
67
  */
68
68
  virtual void add(idx_t n, const uint8_t* x) = 0;
69
- virtual void addEx(idx_t n, const void* x, NumericType numeric_type) {
69
+ virtual void add_ex(idx_t n, const void* x, NumericType numeric_type) {
70
70
  if (numeric_type == NumericType::UInt8) {
71
71
  add(n, static_cast<const uint8_t*>(x));
72
72
  } else {
@@ -82,7 +82,7 @@ struct IndexBinary {
82
82
  * @param xids if non-null, ids to store for the vectors (size n)
83
83
  */
84
84
  virtual void add_with_ids(idx_t n, const uint8_t* x, const idx_t* xids);
85
- virtual void add_with_idsEx(
85
+ virtual void add_with_ids_ex(
86
86
  idx_t n,
87
87
  const void* x,
88
88
  NumericType numeric_type,
@@ -111,7 +111,7 @@ struct IndexBinary {
111
111
  int32_t* distances,
112
112
  idx_t* labels,
113
113
  const SearchParameters* params = nullptr) const = 0;
114
- virtual void searchEx(
114
+ virtual void search_ex(
115
115
  idx_t n,
116
116
  const void* x,
117
117
  NumericType numeric_type,
@@ -172,14 +172,14 @@ struct IndexBinary {
172
172
  *
173
173
  * This function may not be defined for some indexes.
174
174
  * @param key id of the vector to reconstruct
175
- * @param recons reconstucted vector (size d / 8)
175
+ * @param recons reconstructed vector (size d / 8)
176
176
  */
177
177
  virtual void reconstruct(idx_t key, uint8_t* recons) const;
178
178
 
179
179
  /** Reconstruct vectors i0 to i0 + ni - 1.
180
180
  *
181
181
  * This function may not be defined for some indexes.
182
- * @param recons reconstucted vectors (size ni * d / 8)
182
+ * @param recons reconstructed vectors (size ni * d / 8)
183
183
  */
184
184
  virtual void reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const;
185
185
 
@@ -32,7 +32,7 @@ struct IndexBinaryFromFloat : IndexBinary {
32
32
 
33
33
  explicit IndexBinaryFromFloat(Index* index);
34
34
 
35
- ~IndexBinaryFromFloat();
35
+ ~IndexBinaryFromFloat() override;
36
36
 
37
37
  void add(idx_t n, const uint8_t* x) override;
38
38
 
@@ -227,7 +227,11 @@ void IndexBinaryHNSW::search(
227
227
  for (idx_t i = 0; i < n; i++) {
228
228
  res.begin(i);
229
229
  dis->set_query((float*)(x + i * code_size));
230
- hnsw.search(*dis, res, vt);
230
+ // Given that IndexBinaryHNSW is not an IndexHNSW, we pass nullptr
231
+ // as the index parameter. This state does not get used in the
232
+ // search function, as it is merely there to to enable Panorama
233
+ // execution for IndexHNSWFlatPanorama.
234
+ hnsw.search(*dis, nullptr, res, vt);
231
235
  res.end();
232
236
  }
233
237
  }
@@ -290,7 +294,9 @@ struct FlatHammingDis : DistanceComputer {
290
294
 
291
295
  ~FlatHammingDis() override {
292
296
  #pragma omp critical
293
- { hnsw_stats.ndis += ndis; }
297
+ {
298
+ hnsw_stats.ndis += ndis;
299
+ }
294
300
  }
295
301
  };
296
302
 
@@ -36,7 +36,7 @@ struct IndexBinaryHNSW : IndexBinary {
36
36
 
37
37
  // When set to true, all neighbors in level 0 are filled up
38
38
  // to the maximum size allowed (2 * M). This option is used by
39
- // IndexBinaryHHNSW to create a full base layer graph that is
39
+ // IndexBinaryHNSW to create a full base layer graph that is
40
40
  // used when GpuIndexBinaryCagra::copyFrom(IndexBinaryHNSW*) is called.
41
41
  bool keep_max_size_level0 = false;
42
42
 
@@ -177,8 +177,8 @@ void search_single_query_template(
177
177
  struct Run_search_single_query {
178
178
  using T = void;
179
179
  template <class HammingComputer, class... Types>
180
- T f(Types... args) {
181
- search_single_query_template<HammingComputer>(args...);
180
+ T f(Types*... args) {
181
+ search_single_query_template<HammingComputer>(*args...);
182
182
  }
183
183
  };
184
184
 
@@ -192,7 +192,7 @@ void search_single_query(
192
192
  size_t& ndis) {
193
193
  Run_search_single_query r;
194
194
  dispatch_HammingComputer(
195
- index.code_size, r, index, q, res, n0, nlist, ndis);
195
+ index.code_size, r, &index, &q, &res, &n0, &nlist, &ndis);
196
196
  }
197
197
 
198
198
  } // anonymous namespace
@@ -66,10 +66,10 @@ struct IndexBinaryHash : IndexBinary {
66
66
  };
67
67
 
68
68
  struct IndexBinaryHashStats {
69
- size_t nq; // nb of queries run
70
- size_t n0; // nb of empty lists
71
- size_t nlist; // nb of non-empty inverted lists scanned
72
- size_t ndis; // nb of distancs computed
69
+ size_t nq; // nb of queries run
70
+ size_t n0; // nb of empty lists
71
+ size_t nlist; // nb of non-empty inverted lists scanned
72
+ size_t ndis{}; // nb of distances computed
73
73
 
74
74
  IndexBinaryHashStats() {
75
75
  reset();
@@ -99,7 +99,7 @@ struct IndexBinaryMultiHash : IndexBinary {
99
99
 
100
100
  IndexBinaryMultiHash();
101
101
 
102
- ~IndexBinaryMultiHash();
102
+ ~IndexBinaryMultiHash() override;
103
103
 
104
104
  void reset() override;
105
105
 
@@ -492,12 +492,13 @@ void search_knn_hamming_count(
492
492
 
493
493
  std::vector<HCounterState<HammingComputer>> cs;
494
494
  for (size_t i = 0; i < nx; ++i) {
495
- cs.push_back(HCounterState<HammingComputer>(
496
- all_counters.data() + i * nBuckets,
497
- all_ids_per_dis.get() + i * nBuckets * k,
498
- x + i * ivf->code_size,
499
- ivf->d,
500
- k));
495
+ cs.push_back(
496
+ HCounterState<HammingComputer>(
497
+ all_counters.data() + i * nBuckets,
498
+ all_ids_per_dis.get() + i * nBuckets * k,
499
+ x + i * ivf->code_size,
500
+ ivf->d,
501
+ k));
501
502
  }
502
503
 
503
504
  size_t nlistv = 0, ndis = 0;
@@ -7,17 +7,20 @@
7
7
 
8
8
  #include <faiss/IndexFastScan.h>
9
9
 
10
- #include <cassert>
11
- #include <climits>
12
- #include <memory>
13
-
14
10
  #include <omp.h>
11
+ #include <cstring>
12
+ #include <memory>
15
13
 
14
+ #include <faiss/impl/CodePacker.h>
16
15
  #include <faiss/impl/FaissAssert.h>
16
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
17
17
  #include <faiss/impl/IDSelector.h>
18
18
  #include <faiss/impl/LookupTableScaler.h>
19
- #include <faiss/impl/ResultHandler.h>
19
+ #include <faiss/impl/RaBitQUtils.h>
20
+ #include <faiss/impl/pq4_fast_scan.h>
21
+ #include <faiss/impl/simd_result_handlers.h>
20
22
  #include <faiss/utils/hamming.h>
23
+ #include <faiss/utils/utils.h>
21
24
 
22
25
  #include <faiss/impl/pq4_fast_scan.h>
23
26
  #include <faiss/impl/simd_result_handlers.h>
@@ -163,14 +166,14 @@ void estimators_from_tables_generic(
163
166
  size_t k,
164
167
  typename C::T* heap_dis,
165
168
  int64_t* heap_ids,
166
- const NormTableScaler* scaler) {
169
+ const FastScanDistancePostProcessing& context) {
167
170
  using accu_t = typename C::T;
168
171
 
169
172
  for (size_t j = 0; j < ncodes; ++j) {
170
173
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
171
174
  accu_t dis = 0;
172
175
  const dis_t* dt = dis_table;
173
- int nscale = scaler ? scaler->nscale : 0;
176
+ int nscale = context.norm_scaler ? context.norm_scaler->nscale : 0;
174
177
 
175
178
  for (size_t m = 0; m < index.M - nscale; m++) {
176
179
  uint64_t c = bsr.read(index.nbits);
@@ -178,10 +181,10 @@ void estimators_from_tables_generic(
178
181
  dt += index.ksub;
179
182
  }
180
183
 
181
- if (nscale) {
184
+ if (nscale && context.norm_scaler) {
182
185
  for (size_t m = 0; m < nscale; m++) {
183
186
  uint64_t c = bsr.read(index.nbits);
184
- dis += scaler->scale_one(dt[c]);
187
+ dis += context.norm_scaler->scale_one(dt[c]);
185
188
  dt += index.ksub;
186
189
  }
187
190
  }
@@ -193,40 +196,58 @@ void estimators_from_tables_generic(
193
196
  }
194
197
  }
195
198
 
196
- template <class C>
197
- ResultHandlerCompare<C, false>* make_knn_handler(
199
+ } // anonymous namespace
200
+
201
+ // Default implementation of make_knn_handler with centralized fallback logic
202
+ SIMDResultHandlerToFloat* IndexFastScan::make_knn_handler(
203
+ bool is_max,
198
204
  int impl,
199
205
  idx_t n,
200
206
  idx_t k,
201
207
  size_t ntotal,
202
208
  float* distances,
203
209
  idx_t* labels,
204
- const IDSelector* sel = nullptr) {
205
- using HeapHC = HeapHandler<C, false>;
206
- using ReservoirHC = ReservoirHandler<C, false>;
207
- using SingleResultHC = SingleResultHandler<C, false>;
208
-
209
- if (k == 1) {
210
- return new SingleResultHC(n, ntotal, distances, labels, sel);
211
- } else if (impl % 2 == 0) {
212
- return new HeapHC(n, ntotal, k, distances, labels, sel);
213
- } else /* if (impl % 2 == 1) */ {
214
- return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
210
+ const IDSelector* sel,
211
+ const FastScanDistancePostProcessing&) const {
212
+ // Create default handlers based on k and impl
213
+ if (is_max) {
214
+ using HeapHC = HeapHandler<CMax<uint16_t, int>, false>;
215
+ using ReservoirHC = ReservoirHandler<CMax<uint16_t, int>, false>;
216
+ using SingleResultHC = SingleResultHandler<CMax<uint16_t, int>, false>;
217
+
218
+ if (k == 1) {
219
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
220
+ } else if (impl % 2 == 0) {
221
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
222
+ } else {
223
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
224
+ }
225
+ } else {
226
+ using HeapHC = HeapHandler<CMin<uint16_t, int>, false>;
227
+ using ReservoirHC = ReservoirHandler<CMin<uint16_t, int>, false>;
228
+ using SingleResultHC = SingleResultHandler<CMin<uint16_t, int>, false>;
229
+
230
+ if (k == 1) {
231
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
232
+ } else if (impl % 2 == 0) {
233
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
234
+ } else {
235
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
236
+ }
215
237
  }
216
238
  }
217
239
 
218
- } // anonymous namespace
219
-
220
240
  using namespace quantize_lut;
221
241
 
222
242
  void IndexFastScan::compute_quantized_LUT(
223
243
  idx_t n,
224
244
  const float* x,
225
245
  uint8_t* lut,
226
- float* normalizers) const {
246
+ float* normalizers,
247
+ const FastScanDistancePostProcessing& context) const {
227
248
  size_t dim12 = ksub * M;
228
249
  std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
229
- compute_float_LUT(dis_tables.get(), n, x);
250
+ compute_float_LUT(dis_tables.get(), n, x, context);
230
251
 
231
252
  for (uint64_t i = 0; i < n; i++) {
232
253
  round_uint8_per_column(
@@ -263,10 +284,12 @@ void IndexFastScan::search(
263
284
  !params, "search params not supported for this index");
264
285
  FAISS_THROW_IF_NOT(k > 0);
265
286
 
287
+ FastScanDistancePostProcessing empty_context{};
266
288
  if (metric_type == METRIC_L2) {
267
- search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
289
+ search_dispatch_implem<true>(n, x, k, distances, labels, empty_context);
268
290
  } else {
269
- search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
291
+ search_dispatch_implem<false>(
292
+ n, x, k, distances, labels, empty_context);
270
293
  }
271
294
  }
272
295
 
@@ -277,7 +300,7 @@ void IndexFastScan::search_dispatch_implem(
277
300
  idx_t k,
278
301
  float* distances,
279
302
  idx_t* labels,
280
- const NormTableScaler* scaler) const {
303
+ const FastScanDistancePostProcessing& context) const {
281
304
  using Cfloat = typename std::conditional<
282
305
  is_max,
283
306
  CMax<float, int64_t>,
@@ -308,15 +331,20 @@ void IndexFastScan::search_dispatch_implem(
308
331
  FAISS_THROW_MSG("not implemented");
309
332
  } else if (implem == 2 || implem == 3 || implem == 4) {
310
333
  FAISS_THROW_IF_NOT(orig_codes != nullptr);
311
- search_implem_234<Cfloat>(n, x, k, distances, labels, scaler);
334
+ search_implem_234<Cfloat>(n, x, k, distances, labels, context);
312
335
  } else if (impl >= 12 && impl <= 15) {
313
336
  FAISS_THROW_IF_NOT(ntotal < INT_MAX);
314
337
  int nt = std::min(omp_get_max_threads(), int(n));
338
+ // Fall back to single-threaded implementations when parallelization not
339
+ // beneficial:
340
+ // - Single-core system (omp_get_max_threads() = 1)
341
+ // - Single query (n = 1)
342
+ // - OpenMP disabled (omp_get_max_threads() = 1)
315
343
  if (nt < 2) {
316
344
  if (impl == 12 || impl == 13) {
317
- search_implem_12<C>(n, x, k, distances, labels, impl, scaler);
345
+ search_implem_12<C>(n, x, k, distances, labels, impl, context);
318
346
  } else {
319
- search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
347
+ search_implem_14<C>(n, x, k, distances, labels, impl, context);
320
348
  }
321
349
  } else {
322
350
  // explicitly slice over threads
@@ -324,14 +352,33 @@ void IndexFastScan::search_dispatch_implem(
324
352
  for (int slice = 0; slice < nt; slice++) {
325
353
  idx_t i0 = n * slice / nt;
326
354
  idx_t i1 = n * (slice + 1) / nt;
355
+
356
+ // Create per-thread context with adjusted query_factors pointer
357
+ FastScanDistancePostProcessing thread_context = context;
358
+ if (thread_context.query_factors != nullptr) {
359
+ thread_context.query_factors += i0;
360
+ }
361
+
327
362
  float* dis_i = distances + i0 * k;
328
363
  idx_t* lab_i = labels + i0 * k;
329
364
  if (impl == 12 || impl == 13) {
330
365
  search_implem_12<C>(
331
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
366
+ i1 - i0,
367
+ x + i0 * d,
368
+ k,
369
+ dis_i,
370
+ lab_i,
371
+ impl,
372
+ thread_context);
332
373
  } else {
333
374
  search_implem_14<C>(
334
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
375
+ i1 - i0,
376
+ x + i0 * d,
377
+ k,
378
+ dis_i,
379
+ lab_i,
380
+ impl,
381
+ thread_context);
335
382
  }
336
383
  }
337
384
  }
@@ -347,12 +394,12 @@ void IndexFastScan::search_implem_234(
347
394
  idx_t k,
348
395
  float* distances,
349
396
  idx_t* labels,
350
- const NormTableScaler* scaler) const {
397
+ const FastScanDistancePostProcessing& context) const {
351
398
  FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
352
399
 
353
400
  const size_t dim12 = ksub * M;
354
401
  std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
355
- compute_float_LUT(dis_tables.get(), n, x);
402
+ compute_float_LUT(dis_tables.get(), n, x, context);
356
403
 
357
404
  std::vector<float> normalizers(n * 2);
358
405
 
@@ -384,7 +431,7 @@ void IndexFastScan::search_implem_234(
384
431
  k,
385
432
  heap_dis,
386
433
  heap_ids,
387
- scaler);
434
+ context);
388
435
 
389
436
  heap_reorder<Cfloat>(k, heap_dis, heap_ids);
390
437
 
@@ -407,7 +454,7 @@ void IndexFastScan::search_implem_12(
407
454
  float* distances,
408
455
  idx_t* labels,
409
456
  int impl,
410
- const NormTableScaler* scaler) const {
457
+ const FastScanDistancePostProcessing& context) const {
411
458
  using RH = ResultHandlerCompare<C, false>;
412
459
  FAISS_THROW_IF_NOT(bbs == 32);
413
460
 
@@ -416,6 +463,11 @@ void IndexFastScan::search_implem_12(
416
463
  if (n > qbs2) {
417
464
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
418
465
  int64_t i1 = std::min(i0 + qbs2, n);
466
+ // Create sub-context with adjusted query_factors pointer
467
+ FastScanDistancePostProcessing sub_context = context;
468
+ if (sub_context.query_factors != nullptr) {
469
+ sub_context.query_factors += i0;
470
+ }
419
471
  search_implem_12<C>(
420
472
  i1 - i0,
421
473
  x + d * i0,
@@ -423,7 +475,7 @@ void IndexFastScan::search_implem_12(
423
475
  distances + i0 * k,
424
476
  labels + i0 * k,
425
477
  impl,
426
- scaler);
478
+ sub_context);
427
479
  }
428
480
  return;
429
481
  }
@@ -436,7 +488,7 @@ void IndexFastScan::search_implem_12(
436
488
  quantized_dis_tables.clear();
437
489
  } else {
438
490
  compute_quantized_LUT(
439
- n, x, quantized_dis_tables.get(), normalizers.get());
491
+ n, x, quantized_dis_tables.get(), normalizers.get(), context);
440
492
  }
441
493
 
442
494
  AlignedTable<uint8_t> LUT(n * dim12);
@@ -455,7 +507,17 @@ void IndexFastScan::search_implem_12(
455
507
  FAISS_THROW_IF_NOT(LUT_nq == n);
456
508
 
457
509
  std::unique_ptr<RH> handler(
458
- make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
510
+ static_cast<RH*>(make_knn_handler(
511
+ C::is_max,
512
+ impl,
513
+ n,
514
+ k,
515
+ ntotal,
516
+ distances,
517
+ labels,
518
+ nullptr,
519
+ context)));
520
+
459
521
  handler->disable = bool(skip & 2);
460
522
  handler->normalizers = normalizers.get();
461
523
 
@@ -469,7 +531,7 @@ void IndexFastScan::search_implem_12(
469
531
  codes.get(),
470
532
  LUT.get(),
471
533
  *handler.get(),
472
- scaler);
534
+ context.norm_scaler);
473
535
  }
474
536
  if (!(skip & 8)) {
475
537
  handler->end();
@@ -486,7 +548,7 @@ void IndexFastScan::search_implem_14(
486
548
  float* distances,
487
549
  idx_t* labels,
488
550
  int impl,
489
- const NormTableScaler* scaler) const {
551
+ const FastScanDistancePostProcessing& context) const {
490
552
  using RH = ResultHandlerCompare<C, false>;
491
553
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
492
554
 
@@ -496,6 +558,11 @@ void IndexFastScan::search_implem_14(
496
558
  if (n > qbs2) {
497
559
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
498
560
  int64_t i1 = std::min(i0 + qbs2, n);
561
+ // Create sub-context with adjusted query_factors pointer
562
+ FastScanDistancePostProcessing sub_context = context;
563
+ if (sub_context.query_factors != nullptr) {
564
+ sub_context.query_factors += i0;
565
+ }
499
566
  search_implem_14<C>(
500
567
  i1 - i0,
501
568
  x + d * i0,
@@ -503,7 +570,7 @@ void IndexFastScan::search_implem_14(
503
570
  distances + i0 * k,
504
571
  labels + i0 * k,
505
572
  impl,
506
- scaler);
573
+ sub_context);
507
574
  }
508
575
  return;
509
576
  }
@@ -516,14 +583,23 @@ void IndexFastScan::search_implem_14(
516
583
  quantized_dis_tables.clear();
517
584
  } else {
518
585
  compute_quantized_LUT(
519
- n, x, quantized_dis_tables.get(), normalizers.get());
586
+ n, x, quantized_dis_tables.get(), normalizers.get(), context);
520
587
  }
521
588
 
522
589
  AlignedTable<uint8_t> LUT(n * dim12);
523
590
  pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
524
591
 
525
592
  std::unique_ptr<RH> handler(
526
- make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
593
+ static_cast<RH*>(make_knn_handler(
594
+ C::is_max,
595
+ impl,
596
+ n,
597
+ k,
598
+ ntotal,
599
+ distances,
600
+ labels,
601
+ nullptr,
602
+ context)));
527
603
  handler->disable = bool(skip & 2);
528
604
  handler->normalizers = normalizers.get();
529
605
 
@@ -538,7 +614,7 @@ void IndexFastScan::search_implem_14(
538
614
  codes.get(),
539
615
  LUT.get(),
540
616
  *handler.get(),
541
- scaler);
617
+ context.norm_scaler);
542
618
  }
543
619
  if (!(skip & 8)) {
544
620
  handler->end();
@@ -551,7 +627,7 @@ template void IndexFastScan::search_dispatch_implem<true>(
551
627
  idx_t k,
552
628
  float* distances,
553
629
  idx_t* labels,
554
- const NormTableScaler* scaler) const;
630
+ const FastScanDistancePostProcessing& context) const;
555
631
 
556
632
  template void IndexFastScan::search_dispatch_implem<false>(
557
633
  idx_t n,
@@ -559,7 +635,7 @@ template void IndexFastScan::search_dispatch_implem<false>(
559
635
  idx_t k,
560
636
  float* distances,
561
637
  idx_t* labels,
562
- const NormTableScaler* scaler) const;
638
+ const FastScanDistancePostProcessing& context) const;
563
639
 
564
640
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
565
641
  std::vector<uint8_t> code(code_size, 0);