faiss 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
data/vendor/faiss/faiss/IndexFastScan.cpp

@@ -7,17 +7,20 @@

  #include <faiss/IndexFastScan.h>

- #include <cassert>
- #include <climits>
- #include <memory>
-
  #include <omp.h>
+ #include <cstring>
+ #include <memory>

+ #include <faiss/impl/CodePacker.h>
  #include <faiss/impl/FaissAssert.h>
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
  #include <faiss/impl/IDSelector.h>
  #include <faiss/impl/LookupTableScaler.h>
- #include <faiss/impl/ResultHandler.h>
+ #include <faiss/impl/RaBitQUtils.h>
+ #include <faiss/impl/pq4_fast_scan.h>
+ #include <faiss/impl/simd_result_handlers.h>
  #include <faiss/utils/hamming.h>
+ #include <faiss/utils/utils.h>

  #include <faiss/impl/pq4_fast_scan.h>
  #include <faiss/impl/simd_result_handlers.h>

@@ -163,14 +166,14 @@ void estimators_from_tables_generic(
  size_t k,
  typename C::T* heap_dis,
  int64_t* heap_ids,
- const NormTableScaler* scaler) {
+ const FastScanDistancePostProcessing& context) {
  using accu_t = typename C::T;

  for (size_t j = 0; j < ncodes; ++j) {
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
  accu_t dis = 0;
  const dis_t* dt = dis_table;
- int nscale = scaler ? scaler->nscale : 0;
+ int nscale = context.norm_scaler ? context.norm_scaler->nscale : 0;

  for (size_t m = 0; m < index.M - nscale; m++) {
  uint64_t c = bsr.read(index.nbits);

@@ -178,10 +181,10 @@ void estimators_from_tables_generic(
  dt += index.ksub;
  }

- if (nscale) {
+ if (nscale && context.norm_scaler) {
  for (size_t m = 0; m < nscale; m++) {
  uint64_t c = bsr.read(index.nbits);
- dis += scaler->scale_one(dt[c]);
+ dis += context.norm_scaler->scale_one(dt[c]);
  dt += index.ksub;
  }
  }

@@ -193,40 +196,58 @@ void estimators_from_tables_generic(
  }
  }

- template <class C>
- ResultHandlerCompare<C, false>* make_knn_handler(
+ } // anonymous namespace
+
+ // Default implementation of make_knn_handler with centralized fallback logic
+ void* IndexFastScan::make_knn_handler(
+ bool is_max,
  int impl,
  idx_t n,
  idx_t k,
  size_t ntotal,
  float* distances,
  idx_t* labels,
- const IDSelector* sel = nullptr) {
- using HeapHC = HeapHandler<C, false>;
- using ReservoirHC = ReservoirHandler<C, false>;
- using SingleResultHC = SingleResultHandler<C, false>;
-
- if (k == 1) {
- return new SingleResultHC(n, ntotal, distances, labels, sel);
- } else if (impl % 2 == 0) {
- return new HeapHC(n, ntotal, k, distances, labels, sel);
- } else /* if (impl % 2 == 1) */ {
- return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
+ const IDSelector* sel,
+ const FastScanDistancePostProcessing&) const {
+ // Create default handlers based on k and impl
+ if (is_max) {
+ using HeapHC = HeapHandler<CMax<uint16_t, int>, false>;
+ using ReservoirHC = ReservoirHandler<CMax<uint16_t, int>, false>;
+ using SingleResultHC = SingleResultHandler<CMax<uint16_t, int>, false>;
+
+ if (k == 1) {
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
+ } else if (impl % 2 == 0) {
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
+ } else {
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
+ }
+ } else {
+ using HeapHC = HeapHandler<CMin<uint16_t, int>, false>;
+ using ReservoirHC = ReservoirHandler<CMin<uint16_t, int>, false>;
+ using SingleResultHC = SingleResultHandler<CMin<uint16_t, int>, false>;
+
+ if (k == 1) {
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
+ } else if (impl % 2 == 0) {
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
+ } else {
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
+ }
  }
  }

- } // anonymous namespace
-
  using namespace quantize_lut;

  void IndexFastScan::compute_quantized_LUT(
  idx_t n,
  const float* x,
  uint8_t* lut,
- float* normalizers) const {
+ float* normalizers,
+ const FastScanDistancePostProcessing& context) const {
  size_t dim12 = ksub * M;
  std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
- compute_float_LUT(dis_tables.get(), n, x);
+ compute_float_LUT(dis_tables.get(), n, x, context);

  for (uint64_t i = 0; i < n; i++) {
  round_uint8_per_column(

@@ -263,10 +284,12 @@ void IndexFastScan::search(
  !params, "search params not supported for this index");
  FAISS_THROW_IF_NOT(k > 0);

+ FastScanDistancePostProcessing empty_context{};
  if (metric_type == METRIC_L2) {
- search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
+ search_dispatch_implem<true>(n, x, k, distances, labels, empty_context);
  } else {
- search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
+ search_dispatch_implem<false>(
+ n, x, k, distances, labels, empty_context);
  }
  }

@@ -277,7 +300,7 @@ void IndexFastScan::search_dispatch_implem(
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const {
+ const FastScanDistancePostProcessing& context) const {
  using Cfloat = typename std::conditional<
  is_max,
  CMax<float, int64_t>,

@@ -308,15 +331,20 @@ void IndexFastScan::search_dispatch_implem(
  FAISS_THROW_MSG("not implemented");
  } else if (implem == 2 || implem == 3 || implem == 4) {
  FAISS_THROW_IF_NOT(orig_codes != nullptr);
- search_implem_234<Cfloat>(n, x, k, distances, labels, scaler);
+ search_implem_234<Cfloat>(n, x, k, distances, labels, context);
  } else if (impl >= 12 && impl <= 15) {
  FAISS_THROW_IF_NOT(ntotal < INT_MAX);
  int nt = std::min(omp_get_max_threads(), int(n));
+ // Fall back to single-threaded implementations when parallelization not
+ // beneficial:
+ // - Single-core system (omp_get_max_threads() = 1)
+ // - Single query (n = 1)
+ // - OpenMP disabled (omp_get_max_threads() = 1)
  if (nt < 2) {
  if (impl == 12 || impl == 13) {
- search_implem_12<C>(n, x, k, distances, labels, impl, scaler);
+ search_implem_12<C>(n, x, k, distances, labels, impl, context);
  } else {
- search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
+ search_implem_14<C>(n, x, k, distances, labels, impl, context);
  }
  } else {
  // explicitly slice over threads

@@ -324,14 +352,33 @@ void IndexFastScan::search_dispatch_implem(
  for (int slice = 0; slice < nt; slice++) {
  idx_t i0 = n * slice / nt;
  idx_t i1 = n * (slice + 1) / nt;
+
+ // Create per-thread context with adjusted query_factors pointer
+ FastScanDistancePostProcessing thread_context = context;
+ if (thread_context.query_factors != nullptr) {
+ thread_context.query_factors += i0;
+ }
+
  float* dis_i = distances + i0 * k;
  idx_t* lab_i = labels + i0 * k;
  if (impl == 12 || impl == 13) {
  search_implem_12<C>(
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
+ i1 - i0,
+ x + i0 * d,
+ k,
+ dis_i,
+ lab_i,
+ impl,
+ thread_context);
  } else {
  search_implem_14<C>(
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
+ i1 - i0,
+ x + i0 * d,
+ k,
+ dis_i,
+ lab_i,
+ impl,
+ thread_context);
  }
  }
  }

@@ -347,12 +394,12 @@ void IndexFastScan::search_implem_234(
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const {
+ const FastScanDistancePostProcessing& context) const {
  FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);

  const size_t dim12 = ksub * M;
  std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
- compute_float_LUT(dis_tables.get(), n, x);
+ compute_float_LUT(dis_tables.get(), n, x, context);

  std::vector<float> normalizers(n * 2);

@@ -384,7 +431,7 @@ void IndexFastScan::search_implem_234(
  k,
  heap_dis,
  heap_ids,
- scaler);
+ context);

  heap_reorder<Cfloat>(k, heap_dis, heap_ids);

@@ -407,7 +454,7 @@ void IndexFastScan::search_implem_12(
  float* distances,
  idx_t* labels,
  int impl,
- const NormTableScaler* scaler) const {
+ const FastScanDistancePostProcessing& context) const {
  using RH = ResultHandlerCompare<C, false>;
  FAISS_THROW_IF_NOT(bbs == 32);

@@ -416,6 +463,11 @@ void IndexFastScan::search_implem_12(
  if (n > qbs2) {
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
  int64_t i1 = std::min(i0 + qbs2, n);
+ // Create sub-context with adjusted query_factors pointer
+ FastScanDistancePostProcessing sub_context = context;
+ if (sub_context.query_factors != nullptr) {
+ sub_context.query_factors += i0;
+ }
  search_implem_12<C>(
  i1 - i0,
  x + d * i0,

@@ -423,7 +475,7 @@ void IndexFastScan::search_implem_12(
  distances + i0 * k,
  labels + i0 * k,
  impl,
- scaler);
+ sub_context);
  }
  return;
  }

@@ -436,7 +488,7 @@ void IndexFastScan::search_implem_12(
  quantized_dis_tables.clear();
  } else {
  compute_quantized_LUT(
- n, x, quantized_dis_tables.get(), normalizers.get());
+ n, x, quantized_dis_tables.get(), normalizers.get(), context);
  }

  AlignedTable<uint8_t> LUT(n * dim12);

@@ -455,7 +507,17 @@ void IndexFastScan::search_implem_12(
  FAISS_THROW_IF_NOT(LUT_nq == n);

  std::unique_ptr<RH> handler(
- make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
+ static_cast<RH*>(make_knn_handler(
+ C::is_max,
+ impl,
+ n,
+ k,
+ ntotal,
+ distances,
+ labels,
+ nullptr,
+ context)));
+
  handler->disable = bool(skip & 2);
  handler->normalizers = normalizers.get();

@@ -469,7 +531,7 @@ void IndexFastScan::search_implem_12(
  codes.get(),
  LUT.get(),
  *handler.get(),
- scaler);
+ context.norm_scaler);
  }
  if (!(skip & 8)) {
  handler->end();

@@ -486,7 +548,7 @@ void IndexFastScan::search_implem_14(
  float* distances,
  idx_t* labels,
  int impl,
- const NormTableScaler* scaler) const {
+ const FastScanDistancePostProcessing& context) const {
  using RH = ResultHandlerCompare<C, false>;
  FAISS_THROW_IF_NOT(bbs % 32 == 0);

@@ -496,6 +558,11 @@ void IndexFastScan::search_implem_14(
  if (n > qbs2) {
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
  int64_t i1 = std::min(i0 + qbs2, n);
+ // Create sub-context with adjusted query_factors pointer
+ FastScanDistancePostProcessing sub_context = context;
+ if (sub_context.query_factors != nullptr) {
+ sub_context.query_factors += i0;
+ }
  search_implem_14<C>(
  i1 - i0,
  x + d * i0,

@@ -503,7 +570,7 @@ void IndexFastScan::search_implem_14(
  distances + i0 * k,
  labels + i0 * k,
  impl,
- scaler);
+ sub_context);
  }
  return;
  }

@@ -516,14 +583,23 @@ void IndexFastScan::search_implem_14(
  quantized_dis_tables.clear();
  } else {
  compute_quantized_LUT(
- n, x, quantized_dis_tables.get(), normalizers.get());
+ n, x, quantized_dis_tables.get(), normalizers.get(), context);
  }

  AlignedTable<uint8_t> LUT(n * dim12);
  pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());

  std::unique_ptr<RH> handler(
- make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
+ static_cast<RH*>(make_knn_handler(
+ C::is_max,
+ impl,
+ n,
+ k,
+ ntotal,
+ distances,
+ labels,
+ nullptr,
+ context)));
  handler->disable = bool(skip & 2);
  handler->normalizers = normalizers.get();

@@ -538,7 +614,7 @@ void IndexFastScan::search_implem_14(
  codes.get(),
  LUT.get(),
  *handler.get(),
- scaler);
+ context.norm_scaler);
  }
  if (!(skip & 8)) {
  handler->end();

@@ -551,7 +627,7 @@ template void IndexFastScan::search_dispatch_implem<true>(
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template void IndexFastScan::search_dispatch_implem<false>(
  idx_t n,

@@ -559,7 +635,7 @@ template void IndexFastScan::search_dispatch_implem<false>(
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
  std::vector<uint8_t> code(code_size, 0);
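
Note: the search paths above now take a FastScanDistancePostProcessing context where they previously took a bare NormTableScaler pointer. The struct itself ships in the new data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h (+53 lines, see the file list); the sketch below is only an illustration of what this diff implies about its shape. Field names follow the usage visible here (norm_scaler, query_factors); the field types are assumptions, not the shipped definition.

    // Hypothetical sketch, not the shipped header: field names follow the
    // usage visible in this diff, field types are assumed for illustration.
    #include <faiss/impl/LookupTableScaler.h>

    namespace faiss {

    struct FastScanDistancePostProcessing {
        // replaces the old `const NormTableScaler* scaler` parameter;
        // when set, the last `nscale` sub-quantizers are rescaled
        const NormTableScaler* norm_scaler = nullptr;

        // optional per-query correction data; the slicing code in
        // search_dispatch_implem / search_implem_12 / search_implem_14
        // advances this pointer by the query offset i0, so entries are
        // laid out one per query (element type assumed here)
        const float* query_factors = nullptr;
    };

    } // namespace faiss
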
data/vendor/faiss/faiss/IndexFastScan.h

@@ -8,6 +8,7 @@
  #pragma once

  #include <faiss/Index.h>
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
  #include <faiss/utils/AlignedTable.h>

  namespace faiss {

@@ -15,6 +16,13 @@ namespace faiss {
  struct CodePacker;
  struct NormTableScaler;

+ // Forward declarations for result handlers
+ namespace simd_result_handlers {
+ template <class C, bool with_id_map>
+ struct ResultHandlerCompare;
+ }
+ struct IDSelector;
+
  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
  *
  * The codes are not stored sequentially but grouped in blocks of size bbs.

@@ -54,6 +62,14 @@ struct IndexFastScan : Index {
  // (set when initialized by IndexPQ or IndexAQ)
  const uint8_t* orig_codes = nullptr;

+ /** Initialize the fast scan index
+ *
+ * @param d dimensionality of vectors
+ * @param M number of subquantizers
+ * @param nbits number of bits per subquantizer
+ * @param metric distance metric to use
+ * @param bbs block size for SIMD processing
+ */
  void init_fastscan(
  int d,
  size_t M,

@@ -65,6 +81,15 @@ struct IndexFastScan : Index {

  void reset() override;

+ /** Search for k nearest neighbors
+ *
+ * @param n number of query vectors
+ * @param x query vectors (n * d)
+ * @param k number of nearest neighbors to find
+ * @param distances output distances (n * k)
+ * @param labels output labels/indices (n * k)
+ * @param params optional search parameters
+ */
  void search(
  idx_t n,
  const float* x,

@@ -73,20 +98,70 @@ struct IndexFastScan : Index {
  idx_t* labels,
  const SearchParameters* params = nullptr) const override;

+ /** Add vectors to the index
+ *
+ * @param n number of vectors to add
+ * @param x vectors to add (n * d)
+ */
  void add(idx_t n, const float* x) override;

+ /** Compute codes for vectors
+ *
+ * @param codes output codes
+ * @param n number of vectors to encode
+ * @param x vectors to encode (n * d)
+ */
  virtual void compute_codes(uint8_t* codes, idx_t n, const float* x)
  const = 0;

- virtual void compute_float_LUT(float* lut, idx_t n, const float* x)
- const = 0;
+ /** Compute floating-point lookup table for distance computation
+ *
+ * @param lut output lookup table
+ * @param n number of query vectors
+ * @param x query vectors (n * d)
+ * @param context processing context containing all processors
+ */
+ virtual void compute_float_LUT(
+ float* lut,
+ idx_t n,
+ const float* x,
+ const FastScanDistancePostProcessing& context) const = 0;
+
+ /** Create a KNN handler for this index type
+ *
+ * This method can be overridden by derived classes to provide
+ * specialized handlers (e.g., RaBitQHeapHandler for RaBitQ indexes).
+ * Base implementation creates standard handlers based on k and impl.
+ *
+ * @param is_max whether to use CMax comparator (true) or CMin (false)
+ * @param impl implementation number
+ * @param n number of queries
+ * @param k number of neighbors to find
+ * @param ntotal total number of vectors in database
+ * @param distances output distances array
+ * @param labels output labels array
+ * @param sel optional ID selector
+ * @param query_offset query offset for batch processing
+ * @return pointer to created handler (never returns nullptr)
+ */
+ virtual void* make_knn_handler(
+ bool is_max,
+ int impl,
+ idx_t n,
+ idx_t k,
+ size_t ntotal,
+ float* distances,
+ idx_t* labels,
+ const IDSelector* sel,
+ const FastScanDistancePostProcessing& context) const;

  // called by search function
  void compute_quantized_LUT(
  idx_t n,
  const float* x,
  uint8_t* lut,
- float* normalizers) const;
+ float* normalizers,
+ const FastScanDistancePostProcessing& context) const;

  template <bool is_max>
  void search_dispatch_implem(

@@ -95,7 +170,7 @@ struct IndexFastScan : Index {
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template <class Cfloat>
  void search_implem_234(

@@ -104,7 +179,7 @@ struct IndexFastScan : Index {
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template <class C>
  void search_implem_12(

@@ -114,7 +189,7 @@ struct IndexFastScan : Index {
  float* distances,
  idx_t* labels,
  int impl,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template <class C>
  void search_implem_14(

@@ -124,14 +199,39 @@ struct IndexFastScan : Index {
  float* distances,
  idx_t* labels,
  int impl,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

+ /** Reconstruct a vector from its code
+ *
+ * @param key index of vector to reconstruct
+ * @param recons output reconstructed vector
+ */
  void reconstruct(idx_t key, float* recons) const override;
+
+ /** Remove vectors by ID selector
+ *
+ * @param sel selector defining which vectors to remove
+ * @return number of vectors removed
+ */
  size_t remove_ids(const IDSelector& sel) override;

+ /** Get the code packer for this index
+ *
+ * @return pointer to the code packer
+ */
  CodePacker* get_CodePacker() const;

+ /** Merge another index into this one
+ *
+ * @param otherIndex index to merge from
+ * @param add_id ID offset to add to merged vectors
+ */
  void merge_from(Index& otherIndex, idx_t add_id = 0) override;
+
+ /** Check if another index is compatible for merging
+ *
+ * @param otherIndex index to check compatibility with
+ */
  void check_compatible_for_merge(const Index& otherIndex) const override;

  /// standalone codes interface (but the codes are flattened)
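
make_knn_handler is now a virtual hook on IndexFastScan, so subclasses can return specialized SIMD result handlers (the doc comment above cites RaBitQHeapHandler for the RaBitQ indexes). A hypothetical sketch of an override, assuming only the signature declared above, could delegate to the base implementation like this:

    // Hypothetical subclass, for illustration only; compute_codes and
    // compute_float_LUT are omitted, so this sketch stays abstract.
    #include <faiss/IndexFastScan.h>

    namespace faiss {

    struct MyFastScanIndex : IndexFastScan {
        void* make_knn_handler(
                bool is_max,
                int impl,
                idx_t n,
                idx_t k,
                size_t ntotal,
                float* distances,
                idx_t* labels,
                const IDSelector* sel,
                const FastScanDistancePostProcessing& context) const override {
            // A specialized index would construct its own handler here and
            // return it as void*; falling back to the base class reproduces
            // the default Heap / Reservoir / SingleResult selection.
            return IndexFastScan::make_knn_handler(
                    is_max, impl, n, k, ntotal, distances, labels, sel, context);
        }
    };

    } // namespace faiss
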
data/vendor/faiss/faiss/IndexFlat.h

@@ -66,7 +66,7 @@ struct IndexFlat : IndexFlatCodes {

  FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;

- /* The stanadlone codec interface (just memcopies in this case) */
+ /* The standalone codec interface (just memcopies in this case) */
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;

  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
data/vendor/faiss/faiss/IndexHNSW.cpp

@@ -450,7 +450,9 @@ void IndexHNSW::search_level_0(
  vt.advance();
  }
  #pragma omp critical
- { hnsw_stats.combine(search_stats); }
+ {
+ hnsw_stats.combine(search_stats);
+ }
  }
  if (is_similarity_metric(this->metric_type)) {
  // we need to revert the negated distances
data/vendor/faiss/faiss/IndexHNSW.h

@@ -43,7 +43,7 @@ struct IndexHNSW : Index {

  // When set to true, all neighbors in level 0 are filled up
  // to the maximum size allowed (2 * M). This option is used by
- // IndexHHNSWCagra to create a full base layer graph that is
+ // IndexHNSWCagra to create a full base layer graph that is
  // used when GpuIndexCagra::copyFrom(IndexHNSWCagra*) is invoked.
  bool keep_max_size_level0 = false;
49
49