faiss 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
@@ -7,7 +7,6 @@
7
7
 
8
8
  #include <faiss/IndexIVFFastScan.h>
9
9
 
10
- #include <cassert>
11
10
  #include <cstdio>
12
11
  #include <set>
13
12
 
@@ -18,7 +17,9 @@
18
17
  #include <faiss/IndexIVFPQ.h>
19
18
  #include <faiss/impl/AuxIndexStructures.h>
20
19
  #include <faiss/impl/FaissAssert.h>
20
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
21
21
  #include <faiss/impl/LookupTableScaler.h>
22
+ #include <faiss/impl/RaBitQUtils.h>
22
23
  #include <faiss/impl/pq4_fast_scan.h>
23
24
  #include <faiss/impl/simd_result_handlers.h>
24
25
  #include <faiss/invlists/BlockInvertedLists.h>
@@ -94,6 +95,18 @@ IndexIVFFastScan::~IndexIVFFastScan() = default;
94
95
  * Code management functions
95
96
  *********************************************************/
96
97
 
98
+ void IndexIVFFastScan::preprocess_code_metadata(
99
+ idx_t /* n */,
100
+ const uint8_t* /* flat_codes */,
101
+ idx_t /* start_global_idx */) {
102
+ // Default: no-op
103
+ }
104
+
105
+ size_t IndexIVFFastScan::code_packing_stride() const {
106
+ // Default: use standard M-byte stride
107
+ return 0;
108
+ }
109
+
97
110
  void IndexIVFFastScan::add_with_ids(
98
111
  idx_t n,
99
112
  const float* x,
@@ -135,6 +148,9 @@ void IndexIVFFastScan::add_with_ids(
135
148
  AlignedTable<uint8_t> flat_codes(n * code_size);
136
149
  encode_vectors(n, x, idx.get(), flat_codes.get());
137
150
 
151
+ // Allow subclasses to preprocess metadata before packing
152
+ preprocess_code_metadata(n, flat_codes.get(), ntotal);
153
+
138
154
  DirectMapAdd dm_adder(direct_map, n, xids);
139
155
  BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
140
156
  FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
@@ -150,6 +166,9 @@ void IndexIVFFastScan::add_with_ids(
150
166
  return idx[a] < idx[b];
151
167
  });
152
168
 
169
+ // Get stride for packing codes with potential embedded metadata
170
+ size_t pack_stride = code_packing_stride();
171
+
153
172
  // TODO parallelize
154
173
  idx_t i0 = 0;
155
174
  while (i0 < n) {
@@ -186,7 +205,8 @@ void IndexIVFFastScan::add_with_ids(
186
205
  list_size + i1 - i0,
187
206
  bbs,
188
207
  M2,
189
- bil->codes[list_no].data());
208
+ bil->codes[list_no].data(),
209
+ pack_stride);
190
210
 
191
211
  i0 = i1;
192
212
  }
@@ -215,9 +235,9 @@ void estimators_from_tables_generic(
215
235
  size_t k,
216
236
  typename C::T* heap_dis,
217
237
  int64_t* heap_ids,
218
- const NormTableScaler* scaler) {
238
+ const FastScanDistancePostProcessing& context) {
219
239
  using accu_t = typename C::T;
220
- size_t nscale = scaler ? scaler->nscale : 0;
240
+ size_t nscale = context.norm_scaler ? context.norm_scaler->nscale : 0;
221
241
  for (size_t j = 0; j < ncodes; ++j) {
222
242
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
223
243
  accu_t dis = bias;
@@ -229,10 +249,10 @@ void estimators_from_tables_generic(
229
249
  dt += index.ksub;
230
250
  }
231
251
 
232
- if (scaler) {
252
+ if (context.norm_scaler) {
233
253
  for (size_t m = 0; m < nscale; m++) {
234
254
  uint64_t c = bsr.read(index.nbits);
235
- dis += scaler->scale_one(dt[c]);
255
+ dis += context.norm_scaler->scale_one(dt[c]);
236
256
  dt += index.ksub;
237
257
  }
238
258
  }
@@ -244,13 +264,12 @@ void estimators_from_tables_generic(
244
264
  }
245
265
  }
246
266
 
247
- using namespace quantize_lut;
248
-
249
267
  } // anonymous namespace
250
268
 
251
269
  /*********************************************************
252
270
  * Look-Up Table functions
253
271
  *********************************************************/
272
+ using namespace quantize_lut;
254
273
 
255
274
  void IndexIVFFastScan::compute_LUT_uint8(
256
275
  size_t n,
@@ -258,11 +277,12 @@ void IndexIVFFastScan::compute_LUT_uint8(
258
277
  const CoarseQuantized& cq,
259
278
  AlignedTable<uint8_t>& dis_tables,
260
279
  AlignedTable<uint16_t>& biases,
261
- float* normalizers) const {
280
+ float* normalizers,
281
+ const FastScanDistancePostProcessing& context) const {
262
282
  AlignedTable<float> dis_tables_float;
263
283
  AlignedTable<float> biases_float;
264
284
 
265
- compute_LUT(n, x, cq, dis_tables_float, biases_float);
285
+ compute_LUT(n, x, cq, dis_tables_float, biases_float, context);
266
286
  size_t nprobe = cq.nprobe;
267
287
  bool lut_is_3d = lookup_table_is_3d();
268
288
  size_t dim123 = ksub * M;
@@ -346,9 +366,11 @@ void IndexIVFFastScan::search_preassigned(
346
366
  !store_pairs, "store_pairs not supported for this index");
347
367
  FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
348
368
  FAISS_THROW_IF_NOT(k > 0);
369
+ FastScanDistancePostProcessing empty_context{};
349
370
 
350
371
  const CoarseQuantized cq = {nprobe, centroid_dis, assign};
351
- search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params);
372
+ search_dispatch_implem(
373
+ n, x, k, distances, labels, cq, empty_context, params);
352
374
  }
353
375
 
354
376
  void IndexIVFFastScan::range_search(
@@ -365,9 +387,11 @@ void IndexIVFFastScan::range_search(
365
387
  params, "IndexIVFFastScan params have incorrect type");
366
388
  nprobe = params->nprobe;
367
389
  }
390
+ FastScanDistancePostProcessing empty_context{};
368
391
 
369
392
  const CoarseQuantized cq = {nprobe, nullptr, nullptr};
370
- range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params);
393
+ range_search_dispatch_implem(
394
+ n, x, radius, *result, cq, empty_context, params);
371
395
  }
372
396
 
373
397
  namespace {
@@ -379,7 +403,8 @@ ResultHandlerCompare<C, true>* make_knn_handler_fixC(
379
403
  idx_t k,
380
404
  float* distances,
381
405
  idx_t* labels,
382
- const IDSelector* sel) {
406
+ const IDSelector* sel,
407
+ const float* normalizers) {
383
408
  using HeapHC = HeapHandler<C, true>;
384
409
  using ReservoirHC = ReservoirHandler<C, true>;
385
410
  using SingleResultHC = SingleResultHandler<C, true>;
@@ -387,29 +412,12 @@ ResultHandlerCompare<C, true>* make_knn_handler_fixC(
387
412
  if (k == 1) {
388
413
  return new SingleResultHC(n, 0, distances, labels, sel);
389
414
  } else if (impl % 2 == 0) {
390
- return new HeapHC(n, 0, k, distances, labels, sel);
415
+ return new HeapHC(n, 0, k, distances, labels, sel, normalizers);
391
416
  } else /* if (impl % 2 == 1) */ {
392
417
  return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel);
393
418
  }
394
419
  }
395
420
 
396
- SIMDResultHandlerToFloat* make_knn_handler(
397
- bool is_max,
398
- int impl,
399
- idx_t n,
400
- idx_t k,
401
- float* distances,
402
- idx_t* labels,
403
- const IDSelector* sel) {
404
- if (is_max) {
405
- return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
406
- impl, n, k, distances, labels, sel);
407
- } else {
408
- return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
409
- impl, n, k, distances, labels, sel);
410
- }
411
- }
412
-
413
421
  using CoarseQuantized = IndexIVFFastScan::CoarseQuantized;
414
422
 
415
423
  struct CoarseQuantizedWithBuffer : CoarseQuantized {
@@ -443,7 +451,7 @@ struct CoarseQuantizedWithBuffer : CoarseQuantized {
443
451
  };
444
452
 
445
453
  struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
446
- size_t i0, i1;
454
+ const size_t i0, i1;
447
455
  CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1)
448
456
  : CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) {
449
457
  if (done()) {
@@ -486,6 +494,25 @@ int compute_search_nslice(
486
494
 
487
495
  } // namespace
488
496
 
497
+ SIMDResultHandlerToFloat* IndexIVFFastScan::make_knn_handler(
498
+ bool is_max,
499
+ int impl,
500
+ idx_t n,
501
+ idx_t k,
502
+ float* distances,
503
+ idx_t* labels,
504
+ const IDSelector* sel,
505
+ const FastScanDistancePostProcessing&,
506
+ const float* normalizers) const {
507
+ if (is_max) {
508
+ return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
509
+ impl, n, k, distances, labels, sel, normalizers);
510
+ } else {
511
+ return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
512
+ impl, n, k, distances, labels, sel, normalizers);
513
+ }
514
+ }
515
+
489
516
  void IndexIVFFastScan::search_dispatch_implem(
490
517
  idx_t n,
491
518
  const float* x,
@@ -493,7 +520,7 @@ void IndexIVFFastScan::search_dispatch_implem(
493
520
  float* distances,
494
521
  idx_t* labels,
495
522
  const CoarseQuantized& cq_in,
496
- const NormTableScaler* scaler,
523
+ const FastScanDistancePostProcessing& context,
497
524
  const IVFSearchParameters* params) const {
498
525
  const idx_t nprobe = params ? params->nprobe : this->nprobe;
499
526
  const IDSelector* sel = (params) ? params->sel : nullptr;
@@ -542,18 +569,18 @@ void IndexIVFFastScan::search_dispatch_implem(
542
569
  if (impl == 1) {
543
570
  if (is_max) {
544
571
  search_implem_1<CMax<float, int64_t>>(
545
- n, x, k, distances, labels, cq, scaler, params);
572
+ n, x, k, distances, labels, cq, context, params);
546
573
  } else {
547
574
  search_implem_1<CMin<float, int64_t>>(
548
- n, x, k, distances, labels, cq, scaler, params);
575
+ n, x, k, distances, labels, cq, context, params);
549
576
  }
550
577
  } else if (impl == 2) {
551
578
  if (is_max) {
552
579
  search_implem_2<CMax<uint16_t, int64_t>>(
553
- n, x, k, distances, labels, cq, scaler, params);
580
+ n, x, k, distances, labels, cq, context, params);
554
581
  } else {
555
582
  search_implem_2<CMin<uint16_t, int64_t>>(
556
- n, x, k, distances, labels, cq, scaler, params);
583
+ n, x, k, distances, labels, cq, context, params);
557
584
  }
558
585
  } else if (impl >= 10 && impl <= 15) {
559
586
  size_t ndis = 0, nlist_visited = 0;
@@ -562,37 +589,38 @@ void IndexIVFFastScan::search_dispatch_implem(
562
589
  // clang-format off
563
590
  if (impl == 12 || impl == 13) {
564
591
  std::unique_ptr<RH> handler(
565
- make_knn_handler(
566
- is_max,
567
- impl,
568
- n,
569
- k,
570
- distances,
571
- labels, sel
572
- )
592
+ static_cast<RH*>(this->make_knn_handler(
593
+ is_max,
594
+ impl,
595
+ n,
596
+ k,
597
+ distances,
598
+ labels,
599
+ sel,
600
+ context))
573
601
  );
574
602
  search_implem_12(
575
603
  n, x, *handler.get(),
576
- cq, &ndis, &nlist_visited, scaler, params);
604
+ cq, &ndis, &nlist_visited, context, params);
577
605
  } else if (impl == 14 || impl == 15) {
578
606
  search_implem_14(
579
607
  n, x, k, distances, labels,
580
- cq, impl, scaler, params);
608
+ cq, impl, context, params);
581
609
  } else {
582
610
  std::unique_ptr<RH> handler(
583
- make_knn_handler(
584
- is_max,
585
- impl,
586
- n,
587
- k,
588
- distances,
611
+ static_cast<RH*>(this->make_knn_handler(
612
+ is_max,
613
+ impl,
614
+ n,
615
+ k,
616
+ distances,
589
617
  labels,
590
- sel
591
- )
618
+ sel,
619
+ context))
592
620
  );
593
621
  search_implem_10(
594
622
  n, x, *handler.get(), cq,
595
- &ndis, &nlist_visited, scaler, params);
623
+ &ndis, &nlist_visited, context, params);
596
624
  }
597
625
  // clang-format on
598
626
  } else {
@@ -602,7 +630,7 @@ void IndexIVFFastScan::search_dispatch_implem(
602
630
  // this might require slicing if there are too
603
631
  // many queries (for now we keep this simple)
604
632
  search_implem_14(
605
- n, x, k, distances, labels, cq, impl, scaler, params);
633
+ n, x, k, distances, labels, cq, impl, context, params);
606
634
  } else {
607
635
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
608
636
  for (int slice = 0; slice < nslice; slice++) {
@@ -614,17 +642,33 @@ void IndexIVFFastScan::search_dispatch_implem(
614
642
  if (!cq_i.done()) {
615
643
  cq_i.quantize_slice(quantizer, x, quantizer_params);
616
644
  }
617
- std::unique_ptr<RH> handler(make_knn_handler(
618
- is_max, impl, i1 - i0, k, dis_i, lab_i, sel));
645
+
646
+ // Create per-thread context with adjusted query_factors
647
+ // pointer
648
+ FastScanDistancePostProcessing thread_context = context;
649
+ if (thread_context.query_factors != nullptr) {
650
+ thread_context.query_factors += i0 * nprobe;
651
+ }
652
+
653
+ std::unique_ptr<RH> handler(
654
+ static_cast<RH*>(this->make_knn_handler(
655
+ is_max,
656
+ impl,
657
+ i1 - i0,
658
+ k,
659
+ dis_i,
660
+ lab_i,
661
+ sel,
662
+ thread_context)));
619
663
  // clang-format off
620
664
  if (impl == 12 || impl == 13) {
621
665
  search_implem_12(
622
666
  i1 - i0, x + i0 * d, *handler.get(),
623
- cq_i, &ndis, &nlist_visited, scaler, params);
667
+ cq_i, &ndis, &nlist_visited, thread_context, params);
624
668
  } else {
625
669
  search_implem_10(
626
670
  i1 - i0, x + i0 * d, *handler.get(),
627
- cq_i, &ndis, &nlist_visited, scaler, params);
671
+ cq_i, &ndis, &nlist_visited, thread_context, params);
628
672
  }
629
673
  // clang-format on
630
674
  }
@@ -644,7 +688,7 @@ void IndexIVFFastScan::range_search_dispatch_implem(
644
688
  float radius,
645
689
  RangeSearchResult& rres,
646
690
  const CoarseQuantized& cq_in,
647
- const NormTableScaler* scaler,
691
+ const FastScanDistancePostProcessing& context,
648
692
  const IVFSearchParameters* params) const {
649
693
  // const idx_t nprobe = params ? params->nprobe : this->nprobe;
650
694
  const IDSelector* sel = (params) ? params->sel : nullptr;
@@ -656,7 +700,6 @@ void IndexIVFFastScan::range_search_dispatch_implem(
656
700
  if (n == 0) {
657
701
  return;
658
702
  }
659
-
660
703
  // actual implementation used
661
704
  int impl = implem;
662
705
 
@@ -695,10 +738,10 @@ void IndexIVFFastScan::range_search_dispatch_implem(
695
738
  }
696
739
  if (impl == 12) {
697
740
  search_implem_12(
698
- n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
741
+ n, x, *handler.get(), cq, &ndis, &nlist_visited, context);
699
742
  } else if (impl == 10) {
700
743
  search_implem_10(
701
- n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
744
+ n, x, *handler.get(), cq, &ndis, &nlist_visited, context);
702
745
  } else {
703
746
  FAISS_THROW_FMT("Range search implem %d not implemented", impl);
704
747
  }
@@ -736,8 +779,7 @@ void IndexIVFFastScan::range_search_dispatch_implem(
736
779
  cq_i,
737
780
  &ndis,
738
781
  &nlist_visited,
739
- scaler,
740
- params);
782
+ context);
741
783
  } else {
742
784
  search_implem_10(
743
785
  i1 - i0,
@@ -746,8 +788,7 @@ void IndexIVFFastScan::range_search_dispatch_implem(
746
788
  cq_i,
747
789
  &ndis,
748
790
  &nlist_visited,
749
- scaler,
750
- params);
791
+ context);
751
792
  }
752
793
  }
753
794
  pres.finalize();
@@ -767,7 +808,7 @@ void IndexIVFFastScan::search_implem_1(
767
808
  float* distances,
768
809
  idx_t* labels,
769
810
  const CoarseQuantized& cq,
770
- const NormTableScaler* scaler,
811
+ const FastScanDistancePostProcessing& context,
771
812
  const IVFSearchParameters* params) const {
772
813
  FAISS_THROW_IF_NOT(orig_invlists);
773
814
 
@@ -775,7 +816,8 @@ void IndexIVFFastScan::search_implem_1(
775
816
  AlignedTable<float> dis_tables;
776
817
  AlignedTable<float> biases;
777
818
 
778
- compute_LUT(n, x, cq, dis_tables, biases);
819
+ FastScanDistancePostProcessing empty_context;
820
+ compute_LUT(n, x, cq, dis_tables, biases, empty_context);
779
821
 
780
822
  bool single_LUT = !lookup_table_is_3d();
781
823
 
@@ -818,7 +860,7 @@ void IndexIVFFastScan::search_implem_1(
818
860
  k,
819
861
  heap_dis,
820
862
  heap_ids,
821
- scaler);
863
+ context);
822
864
  nlist_visited++;
823
865
  ndis += ls;
824
866
  }
@@ -837,7 +879,7 @@ void IndexIVFFastScan::search_implem_2(
837
879
  float* distances,
838
880
  idx_t* labels,
839
881
  const CoarseQuantized& cq,
840
- const NormTableScaler* scaler,
882
+ const FastScanDistancePostProcessing& context,
841
883
  const IVFSearchParameters* params) const {
842
884
  FAISS_THROW_IF_NOT(orig_invlists);
843
885
 
@@ -846,7 +888,7 @@ void IndexIVFFastScan::search_implem_2(
846
888
  AlignedTable<uint16_t> biases;
847
889
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
848
890
 
849
- compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
891
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get(), context);
850
892
 
851
893
  bool single_LUT = !lookup_table_is_3d();
852
894
 
@@ -891,7 +933,7 @@ void IndexIVFFastScan::search_implem_2(
891
933
  k,
892
934
  heap_dis,
893
935
  heap_ids,
894
- scaler);
936
+ context);
895
937
 
896
938
  nlist_visited++;
897
939
  ndis += ls;
@@ -922,24 +964,27 @@ void IndexIVFFastScan::search_implem_10(
922
964
  const CoarseQuantized& cq,
923
965
  size_t* ndis_out,
924
966
  size_t* nlist_out,
925
- const NormTableScaler* scaler,
926
- const IVFSearchParameters* params) const {
967
+ const FastScanDistancePostProcessing& context,
968
+ const IVFSearchParameters* /* params */) const {
927
969
  size_t dim12 = ksub * M2;
928
970
  AlignedTable<uint8_t> dis_tables;
929
971
  AlignedTable<uint16_t> biases;
930
972
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
931
973
 
932
- compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
974
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get(), context);
933
975
 
934
976
  bool single_LUT = !lookup_table_is_3d();
935
977
 
936
978
  size_t ndis = 0, nlist_visited = 0;
937
979
  int qmap1[1];
938
-
939
980
  handler.q_map = qmap1;
940
981
  handler.begin(skip & 16 ? nullptr : normalizers.get());
941
982
  size_t nprobe = cq.nprobe;
942
983
 
984
+ // Allocate probe_map once and reuse it
985
+ std::vector<int> probe_map;
986
+ probe_map.reserve(1);
987
+
943
988
  for (idx_t i = 0; i < n; i++) {
944
989
  const uint8_t* LUT = nullptr;
945
990
  qmap1[0] = i;
@@ -971,6 +1016,11 @@ void IndexIVFFastScan::search_implem_10(
971
1016
  handler.ntotal = ls;
972
1017
  handler.id_map = ids.get();
973
1018
 
1019
+ // Set context information for handlers that need additional data
1020
+ probe_map.resize(1);
1021
+ probe_map[0] = static_cast<int>(j);
1022
+ handler.set_list_context(list_no, probe_map);
1023
+
974
1024
  pq4_accumulate_loop(
975
1025
  1,
976
1026
  roundup(ls, bbs),
@@ -979,7 +1029,7 @@ void IndexIVFFastScan::search_implem_10(
979
1029
  codes.get(),
980
1030
  LUT,
981
1031
  handler,
982
- scaler);
1032
+ context.norm_scaler);
983
1033
 
984
1034
  ndis += ls;
985
1035
  nlist_visited++;
@@ -998,8 +1048,8 @@ void IndexIVFFastScan::search_implem_12(
998
1048
  const CoarseQuantized& cq,
999
1049
  size_t* ndis_out,
1000
1050
  size_t* nlist_out,
1001
- const NormTableScaler* scaler,
1002
- const IVFSearchParameters* params) const {
1051
+ const FastScanDistancePostProcessing& context,
1052
+ const IVFSearchParameters* /* params */) const {
1003
1053
  if (n == 0) { // does not work well with reservoir
1004
1054
  return;
1005
1055
  }
@@ -1010,7 +1060,7 @@ void IndexIVFFastScan::search_implem_12(
1010
1060
  AlignedTable<uint16_t> biases;
1011
1061
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
1012
1062
 
1013
- compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
1063
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get(), context);
1014
1064
 
1015
1065
  handler.begin(skip & 16 ? nullptr : normalizers.get());
1016
1066
 
@@ -1050,6 +1100,10 @@ void IndexIVFFastScan::search_implem_12(
1050
1100
 
1051
1101
  size_t ndis = 0, nlist_visited = 0;
1052
1102
 
1103
+ // Allocate vectors once and reuse them
1104
+ std::vector<int> probe_map;
1105
+ probe_map.reserve(actual_qbs2);
1106
+
1053
1107
  size_t i0 = 0;
1054
1108
  uint64_t t_copy_pack = 0, t_scan = 0;
1055
1109
  while (i0 < qcs.size()) {
@@ -1109,6 +1163,16 @@ void IndexIVFFastScan::search_implem_12(
1109
1163
  handler.q_map = q_map.data();
1110
1164
  handler.id_map = ids.get();
1111
1165
 
1166
+ // Set context information for handlers that need additional data
1167
+ // All queries in this batch access the same list_no, but each
1168
+ // query has its own probe rank (qc.rank)
1169
+ probe_map.resize(nc);
1170
+ for (size_t i = i0; i < i1; i++) {
1171
+ const QC& qc = qcs[i];
1172
+ probe_map[i - i0] = qc.rank;
1173
+ }
1174
+ handler.set_list_context(list_no, probe_map);
1175
+
1112
1176
  pq4_accumulate_loop_qbs(
1113
1177
  qbs_for_list,
1114
1178
  list_size,
@@ -1116,11 +1180,10 @@ void IndexIVFFastScan::search_implem_12(
1116
1180
  codes.get(),
1117
1181
  LUT.get(),
1118
1182
  handler,
1119
- scaler);
1183
+ context.norm_scaler);
1120
1184
  // prepare for next loop
1121
1185
  i0 = i1;
1122
1186
  }
1123
-
1124
1187
  handler.end();
1125
1188
 
1126
1189
  // these stats are not thread-safe
@@ -1140,7 +1203,7 @@ void IndexIVFFastScan::search_implem_14(
1140
1203
  idx_t* labels,
1141
1204
  const CoarseQuantized& cq,
1142
1205
  int impl,
1143
- const NormTableScaler* scaler,
1206
+ const FastScanDistancePostProcessing& context,
1144
1207
  const IVFSearchParameters* params) const {
1145
1208
  if (n == 0) { // does not work well with reservoir
1146
1209
  return;
@@ -1154,7 +1217,7 @@ void IndexIVFFastScan::search_implem_14(
1154
1217
  AlignedTable<uint16_t> biases;
1155
1218
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
1156
1219
 
1157
- compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
1220
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get(), context);
1158
1221
 
1159
1222
  struct QC {
1160
1223
  int qno; // sequence number of the query
@@ -1250,8 +1313,16 @@ void IndexIVFFastScan::search_implem_14(
1250
1313
  std::vector<float> local_dis(k * n);
1251
1314
 
1252
1315
  // prepare the result handlers
1253
- std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
1254
- is_max, impl, n, k, local_dis.data(), local_idx.data(), sel));
1316
+ std::unique_ptr<SIMDResultHandlerToFloat> handler(
1317
+ this->make_knn_handler(
1318
+ is_max,
1319
+ impl,
1320
+ n,
1321
+ k,
1322
+ local_dis.data(),
1323
+ local_idx.data(),
1324
+ sel,
1325
+ context));
1255
1326
  handler->begin(normalizers.get());
1256
1327
 
1257
1328
  int actual_qbs2 = this->qbs2 ? this->qbs2 : 11;
@@ -1264,6 +1335,11 @@ void IndexIVFFastScan::search_implem_14(
1264
1335
 
1265
1336
  std::set<int> q_set;
1266
1337
  uint64_t t_copy_pack = 0, t_scan = 0;
1338
+
1339
+ // Allocate probe_map once per thread and reuse it
1340
+ std::vector<int> probe_map;
1341
+ probe_map.reserve(actual_qbs2);
1342
+
1267
1343
  #pragma omp for schedule(dynamic)
1268
1344
  for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
1269
1345
  size_t i0 = ses[cluster].start;
@@ -1310,6 +1386,16 @@ void IndexIVFFastScan::search_implem_14(
1310
1386
  handler->q_map = q_map.data();
1311
1387
  handler->id_map = ids.get();
1312
1388
 
1389
+ // Set context information for handlers that need additional data
1390
+ // All queries in this batch access the same list_no, but each
1391
+ // query has its own probe rank (qc.rank)
1392
+ probe_map.resize(nc);
1393
+ for (size_t i = i0; i < i1; i++) {
1394
+ const QC& qc = qcs[i];
1395
+ probe_map[i - i0] = qc.rank;
1396
+ }
1397
+ handler->set_list_context(list_no, probe_map);
1398
+
1313
1399
  pq4_accumulate_loop_qbs(
1314
1400
  qbs_for_list,
1315
1401
  list_size,
@@ -1317,7 +1403,7 @@ void IndexIVFFastScan::search_implem_14(
1317
1403
  codes.get(),
1318
1404
  LUT.get(),
1319
1405
  *handler.get(),
1320
- scaler);
1406
+ context.norm_scaler);
1321
1407
  }
1322
1408
 
1323
1409
  // labels is in-place for HeapHC