faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -43,6 +43,8 @@ IndexIVFFastScan::IndexIVFFastScan(
43
43
  size_t code_size,
44
44
  MetricType metric)
45
45
  : IndexIVF(quantizer, d, nlist, code_size, metric) {
46
+ // unlike other indexes, we prefer no residuals for performance reasons.
47
+ by_residual = false;
46
48
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
47
49
  }
48
50
 
@@ -50,6 +52,7 @@ IndexIVFFastScan::IndexIVFFastScan() {
50
52
  bbs = 0;
51
53
  M2 = 0;
52
54
  is_trained = false;
55
+ by_residual = false;
53
56
  }
54
57
 
55
58
  void IndexIVFFastScan::init_fastscan(
@@ -79,7 +82,7 @@ void IndexIVFFastScan::init_code_packer() {
79
82
  bil->packer = get_CodePacker();
80
83
  }
81
84
 
82
- IndexIVFFastScan::~IndexIVFFastScan() {}
85
+ IndexIVFFastScan::~IndexIVFFastScan() = default;
83
86
 
84
87
  /*********************************************************
85
88
  * Code management functions
@@ -195,7 +198,7 @@ CodePacker* IndexIVFFastScan::get_CodePacker() const {
195
198
 
196
199
  namespace {
197
200
 
198
- template <class C, typename dis_t, class Scaler>
201
+ template <class C, typename dis_t>
199
202
  void estimators_from_tables_generic(
200
203
  const IndexIVFFastScan& index,
201
204
  const uint8_t* codes,
@@ -206,22 +209,26 @@ void estimators_from_tables_generic(
206
209
  size_t k,
207
210
  typename C::T* heap_dis,
208
211
  int64_t* heap_ids,
209
- const Scaler& scaler) {
212
+ const NormTableScaler* scaler) {
210
213
  using accu_t = typename C::T;
214
+ int nscale = scaler ? scaler->nscale : 0;
211
215
  for (size_t j = 0; j < ncodes; ++j) {
212
216
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
213
217
  accu_t dis = bias;
214
218
  const dis_t* __restrict dt = dis_table;
215
- for (size_t m = 0; m < index.M - scaler.nscale; m++) {
219
+
220
+ for (size_t m = 0; m < index.M - nscale; m++) {
216
221
  uint64_t c = bsr.read(index.nbits);
217
222
  dis += dt[c];
218
223
  dt += index.ksub;
219
224
  }
220
225
 
221
- for (size_t m = 0; m < scaler.nscale; m++) {
222
- uint64_t c = bsr.read(index.nbits);
223
- dis += scaler.scale_one(dt[c]);
224
- dt += index.ksub;
226
+ if (scaler) {
227
+ for (size_t m = 0; m < nscale; m++) {
228
+ uint64_t c = bsr.read(index.nbits);
229
+ dis += scaler->scale_one(dt[c]);
230
+ dt += index.ksub;
231
+ }
225
232
  }
226
233
 
227
234
  if (C::cmp(heap_dis[0], dis)) {
@@ -242,18 +249,15 @@ using namespace quantize_lut;
242
249
  void IndexIVFFastScan::compute_LUT_uint8(
243
250
  size_t n,
244
251
  const float* x,
245
- const idx_t* coarse_ids,
246
- const float* coarse_dis,
252
+ const CoarseQuantized& cq,
247
253
  AlignedTable<uint8_t>& dis_tables,
248
254
  AlignedTable<uint16_t>& biases,
249
255
  float* normalizers) const {
250
256
  AlignedTable<float> dis_tables_float;
251
257
  AlignedTable<float> biases_float;
252
258
 
253
- uint64_t t0 = get_cy();
254
- compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
255
- IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
256
-
259
+ compute_LUT(n, x, cq, dis_tables_float, biases_float);
260
+ size_t nprobe = cq.nprobe;
257
261
  bool lut_is_3d = lookup_table_is_3d();
258
262
  size_t dim123 = ksub * M;
259
263
  size_t dim123_2 = ksub * M2;
@@ -265,7 +269,6 @@ void IndexIVFFastScan::compute_LUT_uint8(
265
269
  if (biases_float.get()) {
266
270
  biases.resize(n * nprobe);
267
271
  }
268
- uint64_t t1 = get_cy();
269
272
 
270
273
  #pragma omp parallel for if (n > 100)
271
274
  for (int64_t i = 0; i < n; i++) {
@@ -291,7 +294,6 @@ void IndexIVFFastScan::compute_LUT_uint8(
291
294
  normalizers + 2 * i,
292
295
  normalizers + 2 * i + 1);
293
296
  }
294
- IVFFastScan_stats.t_round += get_cy() - t1;
295
297
  }
296
298
 
297
299
  /*********************************************************
@@ -305,44 +307,161 @@ void IndexIVFFastScan::search(
305
307
  float* distances,
306
308
  idx_t* labels,
307
309
  const SearchParameters* params) const {
310
+ auto paramsi = dynamic_cast<const SearchParametersIVF*>(params);
311
+ FAISS_THROW_IF_NOT_MSG(!params || paramsi, "need IVFSearchParameters");
312
+ search_preassigned(
313
+ n, x, k, nullptr, nullptr, distances, labels, false, paramsi);
314
+ }
315
+
316
+ void IndexIVFFastScan::search_preassigned(
317
+ idx_t n,
318
+ const float* x,
319
+ idx_t k,
320
+ const idx_t* assign,
321
+ const float* centroid_dis,
322
+ float* distances,
323
+ idx_t* labels,
324
+ bool store_pairs,
325
+ const IVFSearchParameters* params,
326
+ IndexIVFStats* stats) const {
327
+ size_t nprobe = this->nprobe;
328
+ if (params) {
329
+ FAISS_THROW_IF_NOT_MSG(
330
+ !params->quantizer_params, "quantizer params not supported");
331
+ FAISS_THROW_IF_NOT(params->max_codes == 0);
332
+ nprobe = params->nprobe;
333
+ }
308
334
  FAISS_THROW_IF_NOT_MSG(
309
- !params, "search params not supported for this index");
335
+ !store_pairs, "store_pairs not supported for this index");
336
+ FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
310
337
  FAISS_THROW_IF_NOT(k > 0);
311
338
 
312
- DummyScaler scaler;
313
- if (metric_type == METRIC_L2) {
314
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
339
+ const CoarseQuantized cq = {nprobe, centroid_dis, assign};
340
+ search_dispatch_implem(n, x, k, distances, labels, cq, nullptr);
341
+ }
342
+
343
+ void IndexIVFFastScan::range_search(
344
+ idx_t n,
345
+ const float* x,
346
+ float radius,
347
+ RangeSearchResult* result,
348
+ const SearchParameters* params) const {
349
+ FAISS_THROW_IF_NOT(!params);
350
+ const CoarseQuantized cq = {nprobe, nullptr, nullptr};
351
+ range_search_dispatch_implem(n, x, radius, *result, cq, nullptr);
352
+ }
353
+
354
+ namespace {
355
+
356
+ template <class C>
357
+ ResultHandlerCompare<C, true>* make_knn_handler_fixC(
358
+ int impl,
359
+ idx_t n,
360
+ idx_t k,
361
+ float* distances,
362
+ idx_t* labels) {
363
+ using HeapHC = HeapHandler<C, true>;
364
+ using ReservoirHC = ReservoirHandler<C, true>;
365
+ using SingleResultHC = SingleResultHandler<C, true>;
366
+
367
+ if (k == 1) {
368
+ return new SingleResultHC(n, 0, distances, labels);
369
+ } else if (impl % 2 == 0) {
370
+ return new HeapHC(n, 0, k, distances, labels);
371
+ } else /* if (impl % 2 == 1) */ {
372
+ return new ReservoirHC(n, 0, k, 2 * k, distances, labels);
373
+ }
374
+ }
375
+
376
+ SIMDResultHandlerToFloat* make_knn_handler(
377
+ bool is_max,
378
+ int impl,
379
+ idx_t n,
380
+ idx_t k,
381
+ float* distances,
382
+ idx_t* labels) {
383
+ if (is_max) {
384
+ return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
385
+ impl, n, k, distances, labels);
315
386
  } else {
316
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
387
+ return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
388
+ impl, n, k, distances, labels);
317
389
  }
318
390
  }
319
391
 
320
- void IndexIVFFastScan::range_search(
321
- idx_t,
322
- const float*,
323
- float,
324
- RangeSearchResult*,
325
- const SearchParameters*) const {
326
- FAISS_THROW_MSG("not implemented");
392
+ using CoarseQuantized = IndexIVFFastScan::CoarseQuantized;
393
+
394
+ struct CoarseQuantizedWithBuffer : CoarseQuantized {
395
+ explicit CoarseQuantizedWithBuffer(const CoarseQuantized& cq)
396
+ : CoarseQuantized(cq) {}
397
+
398
+ bool done() const {
399
+ return ids != nullptr;
400
+ }
401
+
402
+ std::vector<idx_t> ids_buffer;
403
+ std::vector<float> dis_buffer;
404
+
405
+ void quantize(const Index* quantizer, idx_t n, const float* x) {
406
+ dis_buffer.resize(nprobe * n);
407
+ ids_buffer.resize(nprobe * n);
408
+ quantizer->search(n, x, nprobe, dis_buffer.data(), ids_buffer.data());
409
+ dis = dis_buffer.data();
410
+ ids = ids_buffer.data();
411
+ }
412
+ };
413
+
414
+ struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
415
+ size_t i0, i1;
416
+ CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1)
417
+ : CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) {
418
+ if (done()) {
419
+ dis += nprobe * i0;
420
+ ids += nprobe * i0;
421
+ }
422
+ }
423
+
424
+ void quantize_slice(const Index* quantizer, const float* x) {
425
+ quantize(quantizer, i1 - i0, x + quantizer->d * i0);
426
+ }
427
+ };
428
+
429
+ int compute_search_nslice(
430
+ const IndexIVFFastScan* index,
431
+ size_t n,
432
+ size_t nprobe) {
433
+ int nslice;
434
+ if (n <= omp_get_max_threads()) {
435
+ nslice = n;
436
+ } else if (index->lookup_table_is_3d()) {
437
+ // make sure we don't make too big LUT tables
438
+ size_t lut_size_per_query = index->M * index->ksub * nprobe *
439
+ (sizeof(float) + sizeof(uint8_t));
440
+
441
+ size_t max_lut_size = precomputed_table_max_bytes;
442
+ // how many queries we can handle within mem budget
443
+ size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
444
+ nslice = roundup(
445
+ std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
446
+ } else {
447
+ // LUTs unlikely to be a limiting factor
448
+ nslice = omp_get_max_threads();
449
+ }
450
+ return nslice;
327
451
  }
328
452
 
329
- template <bool is_max, class Scaler>
453
+ } // namespace
454
+
330
455
  void IndexIVFFastScan::search_dispatch_implem(
331
456
  idx_t n,
332
457
  const float* x,
333
458
  idx_t k,
334
459
  float* distances,
335
460
  idx_t* labels,
336
- const Scaler& scaler) const {
337
- using Cfloat = typename std::conditional<
338
- is_max,
339
- CMax<float, int64_t>,
340
- CMin<float, int64_t>>::type;
341
-
342
- using C = typename std::conditional<
343
- is_max,
344
- CMax<uint16_t, int64_t>,
345
- CMin<uint16_t, int64_t>>::type;
461
+ const CoarseQuantized& cq_in,
462
+ const NormTableScaler* scaler) const {
463
+ bool is_max = !is_similarity_metric(metric_type);
464
+ using RH = SIMDResultHandlerToFloat;
346
465
 
347
466
  if (n == 0) {
348
467
  return;
@@ -357,70 +476,74 @@ void IndexIVFFastScan::search_dispatch_implem(
357
476
  } else {
358
477
  impl = 10;
359
478
  }
360
- if (k > 20) {
479
+ if (k > 20) { // use reservoir rather than heap
361
480
  impl++;
362
481
  }
363
482
  }
364
483
 
484
+ bool multiple_threads =
485
+ n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
486
+ if (impl >= 100) {
487
+ multiple_threads = false;
488
+ impl -= 100;
489
+ }
490
+
491
+ CoarseQuantizedWithBuffer cq(cq_in);
492
+
493
+ if (!cq.done() && !multiple_threads) {
494
+ // we do the coarse quantization here except when search is
495
+ // sliced over threads (then it is more efficient to have each thread do
496
+ // its own coarse quantization)
497
+ cq.quantize(quantizer, n, x);
498
+ }
499
+
365
500
  if (impl == 1) {
366
- search_implem_1<Cfloat>(n, x, k, distances, labels, scaler);
501
+ if (is_max) {
502
+ search_implem_1<CMax<float, int64_t>>(
503
+ n, x, k, distances, labels, cq, scaler);
504
+ } else {
505
+ search_implem_1<CMin<float, int64_t>>(
506
+ n, x, k, distances, labels, cq, scaler);
507
+ }
367
508
  } else if (impl == 2) {
368
- search_implem_2<C>(n, x, k, distances, labels, scaler);
509
+ if (is_max) {
510
+ search_implem_2<CMax<uint16_t, int64_t>>(
511
+ n, x, k, distances, labels, cq, scaler);
512
+ } else {
513
+ search_implem_2<CMin<uint16_t, int64_t>>(
514
+ n, x, k, distances, labels, cq, scaler);
515
+ }
369
516
 
370
517
  } else if (impl >= 10 && impl <= 15) {
371
518
  size_t ndis = 0, nlist_visited = 0;
372
519
 
373
- if (n < 2) {
520
+ if (!multiple_threads) {
521
+ // clang-format off
374
522
  if (impl == 12 || impl == 13) {
375
- search_implem_12<C>(
376
- n,
377
- x,
378
- k,
379
- distances,
380
- labels,
381
- impl,
382
- &ndis,
383
- &nlist_visited,
384
- scaler);
523
+ std::unique_ptr<RH> handler(make_knn_handler(is_max, impl, n, k, distances, labels));
524
+ search_implem_12(
525
+ n, x, *handler.get(),
526
+ cq, &ndis, &nlist_visited, scaler);
527
+
385
528
  } else if (impl == 14 || impl == 15) {
386
- search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
529
+
530
+ search_implem_14(
531
+ n, x, k, distances, labels,
532
+ cq, impl, scaler);
387
533
  } else {
388
- search_implem_10<C>(
389
- n,
390
- x,
391
- k,
392
- distances,
393
- labels,
394
- impl,
395
- &ndis,
396
- &nlist_visited,
397
- scaler);
534
+ std::unique_ptr<RH> handler(make_knn_handler(is_max, impl, n, k, distances, labels));
535
+ search_implem_10(
536
+ n, x, *handler.get(), cq,
537
+ &ndis, &nlist_visited, scaler);
398
538
  }
539
+ // clang-format on
399
540
  } else {
400
541
  // explicitly slice over threads
401
- int nslice;
402
- if (n <= omp_get_max_threads()) {
403
- nslice = n;
404
- } else if (lookup_table_is_3d()) {
405
- // make sure we don't make too big LUT tables
406
- size_t lut_size_per_query =
407
- M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
408
-
409
- size_t max_lut_size = precomputed_table_max_bytes;
410
- // how many queries we can handle within mem budget
411
- size_t nq_ok =
412
- std::max(max_lut_size / lut_size_per_query, size_t(1));
413
- nslice =
414
- roundup(std::max(size_t(n / nq_ok), size_t(1)),
415
- omp_get_max_threads());
416
- } else {
417
- // LUTs unlikely to be a limiting factor
418
- nslice = omp_get_max_threads();
419
- }
420
- if (impl == 14 ||
421
- impl == 15) { // this might require slicing if there are too
422
- // many queries (for now we keep this simple)
423
- search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
542
+ int nslice = compute_search_nslice(this, n, cq.nprobe);
543
+ if (impl == 14 || impl == 15) {
544
+ // this might require slicing if there are too
545
+ // many queries (for now we keep this simple)
546
+ search_implem_14(n, x, k, distances, labels, cq, impl, scaler);
424
547
  } else {
425
548
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
426
549
  for (int slice = 0; slice < nslice; slice++) {
@@ -428,29 +551,23 @@ void IndexIVFFastScan::search_dispatch_implem(
428
551
  idx_t i1 = n * (slice + 1) / nslice;
429
552
  float* dis_i = distances + i0 * k;
430
553
  idx_t* lab_i = labels + i0 * k;
554
+ CoarseQuantizedSlice cq_i(cq, i0, i1);
555
+ if (!cq_i.done()) {
556
+ cq_i.quantize_slice(quantizer, x);
557
+ }
558
+ std::unique_ptr<RH> handler(make_knn_handler(
559
+ is_max, impl, i1 - i0, k, dis_i, lab_i));
560
+ // clang-format off
431
561
  if (impl == 12 || impl == 13) {
432
- search_implem_12<C>(
433
- i1 - i0,
434
- x + i0 * d,
435
- k,
436
- dis_i,
437
- lab_i,
438
- impl,
439
- &ndis,
440
- &nlist_visited,
441
- scaler);
562
+ search_implem_12(
563
+ i1 - i0, x + i0 * d, *handler.get(),
564
+ cq_i, &ndis, &nlist_visited, scaler);
442
565
  } else {
443
- search_implem_10<C>(
444
- i1 - i0,
445
- x + i0 * d,
446
- k,
447
- dis_i,
448
- lab_i,
449
- impl,
450
- &ndis,
451
- &nlist_visited,
452
- scaler);
566
+ search_implem_10(
567
+ i1 - i0, x + i0 * d, *handler.get(),
568
+ cq_i, &ndis, &nlist_visited, scaler);
453
569
  }
570
+ // clang-format on
454
571
  }
455
572
  }
456
573
  }
@@ -462,31 +579,139 @@ void IndexIVFFastScan::search_dispatch_implem(
462
579
  }
463
580
  }
464
581
 
465
- template <class C, class Scaler>
582
+ void IndexIVFFastScan::range_search_dispatch_implem(
583
+ idx_t n,
584
+ const float* x,
585
+ float radius,
586
+ RangeSearchResult& rres,
587
+ const CoarseQuantized& cq_in,
588
+ const NormTableScaler* scaler) const {
589
+ bool is_max = !is_similarity_metric(metric_type);
590
+
591
+ if (n == 0) {
592
+ return;
593
+ }
594
+
595
+ // actual implementation used
596
+ int impl = implem;
597
+
598
+ if (impl == 0) {
599
+ if (bbs == 32) {
600
+ impl = 12;
601
+ } else {
602
+ impl = 10;
603
+ }
604
+ }
605
+
606
+ CoarseQuantizedWithBuffer cq(cq_in);
607
+
608
+ bool multiple_threads =
609
+ n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
610
+ if (impl >= 100) {
611
+ multiple_threads = false;
612
+ impl -= 100;
613
+ }
614
+
615
+ if (!multiple_threads && !cq.done()) {
616
+ cq.quantize(quantizer, n, x);
617
+ }
618
+
619
+ size_t ndis = 0, nlist_visited = 0;
620
+
621
+ if (!multiple_threads) { // single thread
622
+ std::unique_ptr<SIMDResultHandlerToFloat> handler;
623
+ if (is_max) {
624
+ handler.reset(new RangeHandler<CMax<uint16_t, int64_t>, true>(
625
+ rres, radius, 0));
626
+ } else {
627
+ handler.reset(new RangeHandler<CMin<uint16_t, int64_t>, true>(
628
+ rres, radius, 0));
629
+ }
630
+ if (impl == 12) {
631
+ search_implem_12(
632
+ n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
633
+ } else if (impl == 10) {
634
+ search_implem_10(
635
+ n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
636
+ } else {
637
+ FAISS_THROW_FMT("Range search implem %d not implemented", impl);
638
+ }
639
+ } else {
640
+ // explicitly slice over threads
641
+ int nslice = compute_search_nslice(this, n, cq.nprobe);
642
+ #pragma omp parallel
643
+ {
644
+ RangeSearchPartialResult pres(&rres);
645
+
646
+ #pragma omp for reduction(+ : ndis, nlist_visited)
647
+ for (int slice = 0; slice < nslice; slice++) {
648
+ idx_t i0 = n * slice / nslice;
649
+ idx_t i1 = n * (slice + 1) / nslice;
650
+ CoarseQuantizedSlice cq_i(cq, i0, i1);
651
+ if (!cq_i.done()) {
652
+ cq_i.quantize_slice(quantizer, x);
653
+ }
654
+ std::unique_ptr<SIMDResultHandlerToFloat> handler;
655
+ if (is_max) {
656
+ handler.reset(new PartialRangeHandler<
657
+ CMax<uint16_t, int64_t>,
658
+ true>(pres, radius, 0, i0, i1));
659
+ } else {
660
+ handler.reset(new PartialRangeHandler<
661
+ CMin<uint16_t, int64_t>,
662
+ true>(pres, radius, 0, i0, i1));
663
+ }
664
+
665
+ if (impl == 12 || impl == 13) {
666
+ search_implem_12(
667
+ i1 - i0,
668
+ x + i0 * d,
669
+ *handler.get(),
670
+ cq_i,
671
+ &ndis,
672
+ &nlist_visited,
673
+ scaler);
674
+ } else {
675
+ search_implem_10(
676
+ i1 - i0,
677
+ x + i0 * d,
678
+ *handler.get(),
679
+ cq_i,
680
+ &ndis,
681
+ &nlist_visited,
682
+ scaler);
683
+ }
684
+ }
685
+ pres.finalize();
686
+ }
687
+ }
688
+
689
+ indexIVF_stats.nq += n;
690
+ indexIVF_stats.ndis += ndis;
691
+ indexIVF_stats.nlist += nlist_visited;
692
+ }
693
+
694
+ template <class C>
466
695
  void IndexIVFFastScan::search_implem_1(
467
696
  idx_t n,
468
697
  const float* x,
469
698
  idx_t k,
470
699
  float* distances,
471
700
  idx_t* labels,
472
- const Scaler& scaler) const {
701
+ const CoarseQuantized& cq,
702
+ const NormTableScaler* scaler) const {
473
703
  FAISS_THROW_IF_NOT(orig_invlists);
474
704
 
475
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
476
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
477
-
478
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
479
-
480
705
  size_t dim12 = ksub * M;
481
706
  AlignedTable<float> dis_tables;
482
707
  AlignedTable<float> biases;
483
708
 
484
- compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
709
+ compute_LUT(n, x, cq, dis_tables, biases);
485
710
 
486
711
  bool single_LUT = !lookup_table_is_3d();
487
712
 
488
713
  size_t ndis = 0, nlist_visited = 0;
489
-
714
+ size_t nprobe = cq.nprobe;
490
715
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
491
716
  for (idx_t i = 0; i < n; i++) {
492
717
  int64_t* heap_ids = labels + i * k;
@@ -501,7 +726,7 @@ void IndexIVFFastScan::search_implem_1(
501
726
  if (!single_LUT) {
502
727
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
503
728
  }
504
- idx_t list_no = coarse_ids[i * nprobe + j];
729
+ idx_t list_no = cq.ids[i * nprobe + j];
505
730
  if (list_no < 0)
506
731
  continue;
507
732
  size_t ls = orig_invlists->list_size(list_no);
@@ -533,38 +758,28 @@ void IndexIVFFastScan::search_implem_1(
533
758
  indexIVF_stats.nlist += nlist_visited;
534
759
  }
535
760
 
536
- template <class C, class Scaler>
761
+ template <class C>
537
762
  void IndexIVFFastScan::search_implem_2(
538
763
  idx_t n,
539
764
  const float* x,
540
765
  idx_t k,
541
766
  float* distances,
542
767
  idx_t* labels,
543
- const Scaler& scaler) const {
768
+ const CoarseQuantized& cq,
769
+ const NormTableScaler* scaler) const {
544
770
  FAISS_THROW_IF_NOT(orig_invlists);
545
771
 
546
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
547
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
548
-
549
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
550
-
551
772
  size_t dim12 = ksub * M2;
552
773
  AlignedTable<uint8_t> dis_tables;
553
774
  AlignedTable<uint16_t> biases;
554
775
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
555
776
 
556
- compute_LUT_uint8(
557
- n,
558
- x,
559
- coarse_ids.get(),
560
- coarse_dis.get(),
561
- dis_tables,
562
- biases,
563
- normalizers.get());
777
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
564
778
 
565
779
  bool single_LUT = !lookup_table_is_3d();
566
780
 
567
781
  size_t ndis = 0, nlist_visited = 0;
782
+ size_t nprobe = cq.nprobe;
568
783
 
569
784
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
570
785
  for (idx_t i = 0; i < n; i++) {
@@ -581,7 +796,7 @@ void IndexIVFFastScan::search_implem_2(
581
796
  if (!single_LUT) {
582
797
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
583
798
  }
584
- idx_t list_no = coarse_ids[i * nprobe + j];
799
+ idx_t list_no = cq.ids[i * nprobe + j];
585
800
  if (list_no < 0)
586
801
  continue;
587
802
  size_t ls = orig_invlists->list_size(list_no);
@@ -626,171 +841,99 @@ void IndexIVFFastScan::search_implem_2(
626
841
  indexIVF_stats.nlist += nlist_visited;
627
842
  }
628
843
 
629
- template <class C, class Scaler>
630
844
  void IndexIVFFastScan::search_implem_10(
631
845
  idx_t n,
632
846
  const float* x,
633
- idx_t k,
634
- float* distances,
635
- idx_t* labels,
636
- int impl,
847
+ SIMDResultHandlerToFloat& handler,
848
+ const CoarseQuantized& cq,
637
849
  size_t* ndis_out,
638
850
  size_t* nlist_out,
639
- const Scaler& scaler) const {
640
- memset(distances, -1, sizeof(float) * k * n);
641
- memset(labels, -1, sizeof(idx_t) * k * n);
642
-
643
- using HeapHC = HeapHandler<C, true>;
644
- using ReservoirHC = ReservoirHandler<C, true>;
645
- using SingleResultHC = SingleResultHandler<C, true>;
646
-
647
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
648
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
649
-
650
- uint64_t times[10];
651
- memset(times, 0, sizeof(times));
652
- int ti = 0;
653
- #define TIC times[ti++] = get_cy()
654
- TIC;
655
-
656
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
657
-
658
- TIC;
659
-
851
+ const NormTableScaler* scaler) const {
660
852
  size_t dim12 = ksub * M2;
661
853
  AlignedTable<uint8_t> dis_tables;
662
854
  AlignedTable<uint16_t> biases;
663
855
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
664
856
 
665
- compute_LUT_uint8(
666
- n,
667
- x,
668
- coarse_ids.get(),
669
- coarse_dis.get(),
670
- dis_tables,
671
- biases,
672
- normalizers.get());
673
-
674
- TIC;
857
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
675
858
 
676
859
  bool single_LUT = !lookup_table_is_3d();
677
860
 
678
- TIC;
679
- size_t ndis = 0, nlist_visited = 0;
861
+ size_t ndis = 0;
862
+ int qmap1[1];
680
863
 
681
- {
682
- AlignedTable<uint16_t> tmp_distances(k);
683
- for (idx_t i = 0; i < n; i++) {
684
- const uint8_t* LUT = nullptr;
685
- int qmap1[1] = {0};
686
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
687
-
688
- if (k == 1) {
689
- handler.reset(new SingleResultHC(1, 0));
690
- } else if (impl == 10) {
691
- handler.reset(new HeapHC(
692
- 1, tmp_distances.get(), labels + i * k, k, 0));
693
- } else if (impl == 11) {
694
- handler.reset(new ReservoirHC(1, 0, k, 2 * k));
695
- } else {
696
- FAISS_THROW_MSG("invalid");
697
- }
864
+ handler.q_map = qmap1;
865
+ handler.begin(skip & 16 ? nullptr : normalizers.get());
866
+ size_t nprobe = cq.nprobe;
698
867
 
699
- handler->q_map = qmap1;
868
+ for (idx_t i = 0; i < n; i++) {
869
+ const uint8_t* LUT = nullptr;
870
+ qmap1[0] = i;
700
871
 
701
- if (single_LUT) {
702
- LUT = dis_tables.get() + i * dim12;
872
+ if (single_LUT) {
873
+ LUT = dis_tables.get() + i * dim12;
874
+ }
875
+ for (idx_t j = 0; j < nprobe; j++) {
876
+ size_t ij = i * nprobe + j;
877
+ if (!single_LUT) {
878
+ LUT = dis_tables.get() + ij * dim12;
879
+ }
880
+ if (biases.get()) {
881
+ handler.dbias = biases.get() + ij;
703
882
  }
704
- for (idx_t j = 0; j < nprobe; j++) {
705
- size_t ij = i * nprobe + j;
706
- if (!single_LUT) {
707
- LUT = dis_tables.get() + ij * dim12;
708
- }
709
- if (biases.get()) {
710
- handler->dbias = biases.get() + ij;
711
- }
712
-
713
- idx_t list_no = coarse_ids[ij];
714
- if (list_no < 0)
715
- continue;
716
- size_t ls = invlists->list_size(list_no);
717
- if (ls == 0)
718
- continue;
719
883
 
720
- InvertedLists::ScopedCodes codes(invlists, list_no);
721
- InvertedLists::ScopedIds ids(invlists, list_no);
884
+ idx_t list_no = cq.ids[ij];
885
+ if (list_no < 0) {
886
+ continue;
887
+ }
888
+ size_t ls = invlists->list_size(list_no);
889
+ if (ls == 0) {
890
+ continue;
891
+ }
722
892
 
723
- handler->ntotal = ls;
724
- handler->id_map = ids.get();
893
+ InvertedLists::ScopedCodes codes(invlists, list_no);
894
+ InvertedLists::ScopedIds ids(invlists, list_no);
725
895
 
726
- #define DISPATCH(classHC) \
727
- if (dynamic_cast<classHC*>(handler.get())) { \
728
- auto* res = static_cast<classHC*>(handler.get()); \
729
- pq4_accumulate_loop( \
730
- 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \
731
- }
732
- DISPATCH(HeapHC)
733
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
734
- #undef DISPATCH
896
+ handler.ntotal = ls;
897
+ handler.id_map = ids.get();
735
898
 
736
- nlist_visited++;
737
- ndis++;
738
- }
899
+ pq4_accumulate_loop(
900
+ 1,
901
+ roundup(ls, bbs),
902
+ bbs,
903
+ M2,
904
+ codes.get(),
905
+ LUT,
906
+ handler,
907
+ scaler);
739
908
 
740
- handler->to_flat_arrays(
741
- distances + i * k,
742
- labels + i * k,
743
- skip & 16 ? nullptr : normalizers.get() + i * 2);
909
+ ndis++;
744
910
  }
745
911
  }
912
+ handler.end();
746
913
  *ndis_out = ndis;
747
914
  *nlist_out = nlist;
748
915
  }
749
916
 
750
- template <class C, class Scaler>
751
917
  void IndexIVFFastScan::search_implem_12(
752
918
  idx_t n,
753
919
  const float* x,
754
- idx_t k,
755
- float* distances,
756
- idx_t* labels,
757
- int impl,
920
+ SIMDResultHandlerToFloat& handler,
921
+ const CoarseQuantized& cq,
758
922
  size_t* ndis_out,
759
923
  size_t* nlist_out,
760
- const Scaler& scaler) const {
924
+ const NormTableScaler* scaler) const {
761
925
  if (n == 0) { // does not work well with reservoir
762
926
  return;
763
927
  }
764
928
  FAISS_THROW_IF_NOT(bbs == 32);
765
929
 
766
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
767
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
768
-
769
- uint64_t times[10];
770
- memset(times, 0, sizeof(times));
771
- int ti = 0;
772
- #define TIC times[ti++] = get_cy()
773
- TIC;
774
-
775
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
776
-
777
- TIC;
778
-
779
930
  size_t dim12 = ksub * M2;
780
931
  AlignedTable<uint8_t> dis_tables;
781
932
  AlignedTable<uint16_t> biases;
782
933
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
783
934
 
784
- compute_LUT_uint8(
785
- n,
786
- x,
787
- coarse_ids.get(),
788
- coarse_dis.get(),
789
- dis_tables,
790
- biases,
791
- normalizers.get());
792
-
793
- TIC;
935
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
936
+ handler.begin(skip & 16 ? nullptr : normalizers.get());
794
937
 
795
938
  struct QC {
796
939
  int qno; // sequence number of the query
@@ -798,14 +941,15 @@ void IndexIVFFastScan::search_implem_12(
798
941
  int rank; // this is the rank'th result of the coarse quantizer
799
942
  };
800
943
  bool single_LUT = !lookup_table_is_3d();
944
+ size_t nprobe = cq.nprobe;
801
945
 
802
946
  std::vector<QC> qcs;
803
947
  {
804
948
  int ij = 0;
805
949
  for (int i = 0; i < n; i++) {
806
950
  for (int j = 0; j < nprobe; j++) {
807
- if (coarse_ids[ij] >= 0) {
808
- qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
951
+ if (cq.ids[ij] >= 0) {
952
+ qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
809
953
  }
810
954
  ij++;
811
955
  }
@@ -814,42 +958,21 @@ void IndexIVFFastScan::search_implem_12(
814
958
  return a.list_no < b.list_no;
815
959
  });
816
960
  }
817
- TIC;
818
-
819
961
  // prepare the result handlers
820
962
 
821
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
822
- AlignedTable<uint16_t> tmp_distances;
823
-
824
- using HeapHC = HeapHandler<C, true>;
825
- using ReservoirHC = ReservoirHandler<C, true>;
826
- using SingleResultHC = SingleResultHandler<C, true>;
827
-
828
- if (k == 1) {
829
- handler.reset(new SingleResultHC(n, 0));
830
- } else if (impl == 12) {
831
- tmp_distances.resize(n * k);
832
- handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
833
- } else if (impl == 13) {
834
- handler.reset(new ReservoirHC(n, 0, k, 2 * k));
835
- }
836
-
837
963
  int qbs2 = this->qbs2 ? this->qbs2 : 11;
838
964
 
839
965
  std::vector<uint16_t> tmp_bias;
840
966
  if (biases.get()) {
841
967
  tmp_bias.resize(qbs2);
842
- handler->dbias = tmp_bias.data();
968
+ handler.dbias = tmp_bias.data();
843
969
  }
844
- TIC;
845
970
 
846
971
  size_t ndis = 0;
847
972
 
848
973
  size_t i0 = 0;
849
974
  uint64_t t_copy_pack = 0, t_scan = 0;
850
975
  while (i0 < qcs.size()) {
851
- uint64_t tt0 = get_cy();
852
-
853
976
  // find all queries that access this inverted list
854
977
  int list_no = qcs[i0].list_no;
855
978
  size_t i1 = i0 + 1;
@@ -897,93 +1020,47 @@ void IndexIVFFastScan::search_implem_12(
897
1020
 
898
1021
  // prepare the handler
899
1022
 
900
- handler->ntotal = list_size;
901
- handler->q_map = q_map.data();
902
- handler->id_map = ids.get();
903
- uint64_t tt1 = get_cy();
904
-
905
- #define DISPATCH(classHC) \
906
- if (dynamic_cast<classHC*>(handler.get())) { \
907
- auto* res = static_cast<classHC*>(handler.get()); \
908
- pq4_accumulate_loop_qbs( \
909
- qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
910
- }
911
- DISPATCH(HeapHC)
912
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
913
-
914
- // prepare for next loop
915
- i0 = i1;
1023
+ handler.ntotal = list_size;
1024
+ handler.q_map = q_map.data();
1025
+ handler.id_map = ids.get();
916
1026
 
917
- uint64_t tt2 = get_cy();
918
- t_copy_pack += tt1 - tt0;
919
- t_scan += tt2 - tt1;
1027
+ pq4_accumulate_loop_qbs(
1028
+ qbs, list_size, M2, codes.get(), LUT.get(), handler, scaler);
1029
+ // prepare for next loop
1030
+ i0 = i1;
920
1031
  }
921
- TIC;
922
-
923
- // labels is in-place for HeapHC
924
- handler->to_flat_arrays(
925
- distances, labels, skip & 16 ? nullptr : normalizers.get());
926
1032
 
927
- TIC;
1033
+ handler.end();
928
1034
 
929
1035
  // these stats are not thread-safe
930
1036
 
931
- for (int i = 1; i < ti; i++) {
932
- IVFFastScan_stats.times[i] += times[i] - times[i - 1];
933
- }
934
1037
  IVFFastScan_stats.t_copy_pack += t_copy_pack;
935
1038
  IVFFastScan_stats.t_scan += t_scan;
936
1039
 
937
- if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
938
- for (int i = 0; i < 4; i++) {
939
- IVFFastScan_stats.reservoir_times[i] += rh->times[i];
940
- }
941
- }
942
-
943
1040
  *ndis_out = ndis;
944
1041
  *nlist_out = nlist;
945
1042
  }
946
1043
 
947
- template <class C, class Scaler>
948
1044
  void IndexIVFFastScan::search_implem_14(
949
1045
  idx_t n,
950
1046
  const float* x,
951
1047
  idx_t k,
952
1048
  float* distances,
953
1049
  idx_t* labels,
1050
+ const CoarseQuantized& cq,
954
1051
  int impl,
955
- const Scaler& scaler) const {
1052
+ const NormTableScaler* scaler) const {
956
1053
  if (n == 0) { // does not work well with reservoir
957
1054
  return;
958
1055
  }
959
1056
  FAISS_THROW_IF_NOT(bbs == 32);
960
1057
 
961
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
962
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
963
-
964
- uint64_t ttg0 = get_cy();
965
-
966
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
967
-
968
- uint64_t ttg1 = get_cy();
969
- uint64_t coarse_search_tt = ttg1 - ttg0;
970
-
971
1058
  size_t dim12 = ksub * M2;
972
1059
  AlignedTable<uint8_t> dis_tables;
973
1060
  AlignedTable<uint16_t> biases;
974
1061
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
975
1062
 
976
- compute_LUT_uint8(
977
- n,
978
- x,
979
- coarse_ids.get(),
980
- coarse_dis.get(),
981
- dis_tables,
982
- biases,
983
- normalizers.get());
984
-
985
- uint64_t ttg2 = get_cy();
986
- uint64_t lut_compute_tt = ttg2 - ttg1;
1063
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
987
1064
 
988
1065
  struct QC {
989
1066
  int qno; // sequence number of the query
@@ -991,14 +1068,15 @@ void IndexIVFFastScan::search_implem_14(
991
1068
  int rank; // this is the rank'th result of the coarse quantizer
992
1069
  };
993
1070
  bool single_LUT = !lookup_table_is_3d();
1071
+ size_t nprobe = cq.nprobe;
994
1072
 
995
1073
  std::vector<QC> qcs;
996
1074
  {
997
1075
  int ij = 0;
998
1076
  for (int i = 0; i < n; i++) {
999
1077
  for (int j = 0; j < nprobe; j++) {
1000
- if (coarse_ids[ij] >= 0) {
1001
- qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
1078
+ if (cq.ids[ij] >= 0) {
1079
+ qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
1002
1080
  }
1003
1081
  ij++;
1004
1082
  }
@@ -1036,14 +1114,13 @@ void IndexIVFFastScan::search_implem_14(
1036
1114
  ses.push_back(SE{i0_l, i1, list_size});
1037
1115
  i0_l = i1;
1038
1116
  }
1039
- uint64_t ttg3 = get_cy();
1040
- uint64_t compute_clusters_tt = ttg3 - ttg2;
1041
1117
 
1042
1118
  // function to handle the global heap
1119
+ bool is_max = !is_similarity_metric(metric_type);
1043
1120
  using HeapForIP = CMin<float, idx_t>;
1044
1121
  using HeapForL2 = CMax<float, idx_t>;
1045
1122
  auto init_result = [&](float* simi, idx_t* idxi) {
1046
- if (metric_type == METRIC_INNER_PRODUCT) {
1123
+ if (!is_max) {
1047
1124
  heap_heapify<HeapForIP>(k, simi, idxi);
1048
1125
  } else {
1049
1126
  heap_heapify<HeapForL2>(k, simi, idxi);
@@ -1054,7 +1131,7 @@ void IndexIVFFastScan::search_implem_14(
1054
1131
  const idx_t* local_idx,
1055
1132
  float* simi,
1056
1133
  idx_t* idxi) {
1057
- if (metric_type == METRIC_INNER_PRODUCT) {
1134
+ if (!is_max) {
1058
1135
  heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
1059
1136
  } else {
1060
1137
  heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
@@ -1062,14 +1139,12 @@ void IndexIVFFastScan::search_implem_14(
1062
1139
  };
1063
1140
 
1064
1141
  auto reorder_result = [&](float* simi, idx_t* idxi) {
1065
- if (metric_type == METRIC_INNER_PRODUCT) {
1142
+ if (!is_max) {
1066
1143
  heap_reorder<HeapForIP>(k, simi, idxi);
1067
1144
  } else {
1068
1145
  heap_reorder<HeapForL2>(k, simi, idxi);
1069
1146
  }
1070
1147
  };
1071
- uint64_t ttg4 = get_cy();
1072
- uint64_t fn_tt = ttg4 - ttg3;
1073
1148
 
1074
1149
  size_t ndis = 0;
1075
1150
  size_t nlist_visited = 0;
@@ -1081,22 +1156,9 @@ void IndexIVFFastScan::search_implem_14(
1081
1156
  std::vector<float> local_dis(k * n);
1082
1157
 
1083
1158
  // prepare the result handlers
1084
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
1085
- AlignedTable<uint16_t> tmp_distances;
1086
-
1087
- using HeapHC = HeapHandler<C, true>;
1088
- using ReservoirHC = ReservoirHandler<C, true>;
1089
- using SingleResultHC = SingleResultHandler<C, true>;
1090
-
1091
- if (k == 1) {
1092
- handler.reset(new SingleResultHC(n, 0));
1093
- } else if (impl == 14) {
1094
- tmp_distances.resize(n * k);
1095
- handler.reset(
1096
- new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
1097
- } else if (impl == 15) {
1098
- handler.reset(new ReservoirHC(n, 0, k, 2 * k));
1099
- }
1159
+ std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
1160
+ is_max, impl, n, k, local_dis.data(), local_idx.data()));
1161
+ handler->begin(normalizers.get());
1100
1162
 
1101
1163
  int qbs2 = this->qbs2 ? this->qbs2 : 11;
1102
1164
 
@@ -1105,15 +1167,10 @@ void IndexIVFFastScan::search_implem_14(
1105
1167
  tmp_bias.resize(qbs2);
1106
1168
  handler->dbias = tmp_bias.data();
1107
1169
  }
1108
-
1109
- uint64_t ttg5 = get_cy();
1110
- uint64_t handler_tt = ttg5 - ttg4;
1111
-
1112
1170
  std::set<int> q_set;
1113
1171
  uint64_t t_copy_pack = 0, t_scan = 0;
1114
1172
  #pragma omp for schedule(dynamic)
1115
1173
  for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
1116
- uint64_t tt0 = get_cy();
1117
1174
  size_t i0 = ses[cluster].start;
1118
1175
  size_t i1 = ses[cluster].end;
1119
1176
  size_t list_size = ses[cluster].list_size;
@@ -1153,28 +1210,21 @@ void IndexIVFFastScan::search_implem_14(
1153
1210
  handler->ntotal = list_size;
1154
1211
  handler->q_map = q_map.data();
1155
1212
  handler->id_map = ids.get();
1156
- uint64_t tt1 = get_cy();
1157
-
1158
- #define DISPATCH(classHC) \
1159
- if (dynamic_cast<classHC*>(handler.get())) { \
1160
- auto* res = static_cast<classHC*>(handler.get()); \
1161
- pq4_accumulate_loop_qbs( \
1162
- qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
1163
- }
1164
- DISPATCH(HeapHC)
1165
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
1166
1213
 
1167
- uint64_t tt2 = get_cy();
1168
- t_copy_pack += tt1 - tt0;
1169
- t_scan += tt2 - tt1;
1214
+ pq4_accumulate_loop_qbs(
1215
+ qbs,
1216
+ list_size,
1217
+ M2,
1218
+ codes.get(),
1219
+ LUT.get(),
1220
+ *handler.get(),
1221
+ scaler);
1170
1222
  }
1171
1223
 
1172
1224
  // labels is in-place for HeapHC
1173
- handler->to_flat_arrays(
1174
- local_dis.data(),
1175
- local_idx.data(),
1176
- skip & 16 ? nullptr : normalizers.get());
1225
+ handler->end();
1177
1226
 
1227
+ // merge per-thread results
1178
1228
  #pragma omp single
1179
1229
  {
1180
1230
  // we init the results as a heap
@@ -1197,12 +1247,6 @@ void IndexIVFFastScan::search_implem_14(
1197
1247
 
1198
1248
  IVFFastScan_stats.t_copy_pack += t_copy_pack;
1199
1249
  IVFFastScan_stats.t_scan += t_scan;
1200
-
1201
- if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
1202
- for (int i = 0; i < 4; i++) {
1203
- IVFFastScan_stats.reservoir_times[i] += rh->times[i];
1204
- }
1205
- }
1206
1250
  }
1207
1251
  #pragma omp barrier
1208
1252
  #pragma omp single
@@ -1272,20 +1316,4 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
1272
1316
 
1273
1317
  IVFFastScanStats IVFFastScan_stats;
1274
1318
 
1275
- template void IndexIVFFastScan::search_dispatch_implem<true, NormTableScaler>(
1276
- idx_t n,
1277
- const float* x,
1278
- idx_t k,
1279
- float* distances,
1280
- idx_t* labels,
1281
- const NormTableScaler& scaler) const;
1282
-
1283
- template void IndexIVFFastScan::search_dispatch_implem<false, NormTableScaler>(
1284
- idx_t n,
1285
- const float* x,
1286
- idx_t k,
1287
- float* distances,
1288
- idx_t* labels,
1289
- const NormTableScaler& scaler) const;
1290
-
1291
1319
  } // namespace faiss