faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -36,7 +36,7 @@ IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
36
36
  cp.niter = 10;
37
37
  }
38
38
 
39
- IndexBinaryIVF::IndexBinaryIVF() {}
39
+ IndexBinaryIVF::IndexBinaryIVF() = default;
40
40
 
41
41
  void IndexBinaryIVF::add(idx_t n, const uint8_t* x) {
42
42
  add_with_ids(n, x, nullptr);
@@ -119,16 +119,16 @@ void IndexBinaryIVF::search(
119
119
  FAISS_THROW_IF_NOT(k > 0);
120
120
  FAISS_THROW_IF_NOT(nprobe > 0);
121
121
 
122
- const size_t nprobe = std::min(nlist, this->nprobe);
123
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
124
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
122
+ const size_t nprobe_2 = std::min(nlist, this->nprobe);
123
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
124
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
125
125
 
126
126
  double t0 = getmillisecs();
127
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
127
+ quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
128
128
  indexIVF_stats.quantization_time += getmillisecs() - t0;
129
129
 
130
130
  t0 = getmillisecs();
131
- invlists->prefetch_lists(idx.get(), n * nprobe);
131
+ invlists->prefetch_lists(idx.get(), n * nprobe_2);
132
132
 
133
133
  search_preassigned(
134
134
  n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
@@ -169,16 +169,16 @@ void IndexBinaryIVF::search_and_reconstruct(
169
169
  const SearchParameters* params) const {
170
170
  FAISS_THROW_IF_NOT_MSG(
171
171
  !params, "search params not supported for this index");
172
- const size_t nprobe = std::min(nlist, this->nprobe);
172
+ const size_t nprobe_2 = std::min(nlist, this->nprobe);
173
173
  FAISS_THROW_IF_NOT(k > 0);
174
- FAISS_THROW_IF_NOT(nprobe > 0);
174
+ FAISS_THROW_IF_NOT(nprobe_2 > 0);
175
175
 
176
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
177
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
176
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
177
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
178
178
 
179
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
179
+ quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
180
180
 
181
- invlists->prefetch_lists(idx.get(), n * nprobe);
181
+ invlists->prefetch_lists(idx.get(), n * nprobe_2);
182
182
 
183
183
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
184
184
  // and offset into `codes` for reconstruction
@@ -321,8 +321,8 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
321
321
  }
322
322
 
323
323
  idx_t list_no;
324
- void set_list(idx_t list_no, uint8_t /* coarse_dis */) override {
325
- this->list_no = list_no;
324
+ void set_list(idx_t list_no_2, uint8_t /* coarse_dis */) override {
325
+ this->list_no = list_no_2;
326
326
  }
327
327
 
328
328
  uint32_t distance_to_code(const uint8_t* code) const override {
@@ -357,7 +357,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
357
357
  const idx_t* __restrict ids,
358
358
  int radius,
359
359
  RangeQueryResult& result) const override {
360
- size_t nup = 0;
361
360
  for (size_t j = 0; j < n; j++) {
362
361
  uint32_t dis = hc.hamming(codes);
363
362
  if (dis < radius) {
@@ -370,7 +369,7 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
370
369
  };
371
370
 
372
371
  void search_knn_hamming_heap(
373
- const IndexBinaryIVF& ivf,
372
+ const IndexBinaryIVF* ivf,
374
373
  size_t n,
375
374
  const uint8_t* __restrict x,
376
375
  idx_t k,
@@ -380,10 +379,10 @@ void search_knn_hamming_heap(
380
379
  idx_t* __restrict labels,
381
380
  bool store_pairs,
382
381
  const IVFSearchParameters* params) {
383
- idx_t nprobe = params ? params->nprobe : ivf.nprobe;
384
- nprobe = std::min((idx_t)ivf.nlist, nprobe);
385
- idx_t max_codes = params ? params->max_codes : ivf.max_codes;
386
- MetricType metric_type = ivf.metric_type;
382
+ idx_t nprobe = params ? params->nprobe : ivf->nprobe;
383
+ nprobe = std::min((idx_t)ivf->nlist, nprobe);
384
+ idx_t max_codes = params ? params->max_codes : ivf->max_codes;
385
+ MetricType metric_type = ivf->metric_type;
387
386
 
388
387
  // almost verbatim copy from IndexIVF::search_preassigned
389
388
 
@@ -394,11 +393,11 @@ void search_knn_hamming_heap(
394
393
  #pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap)
395
394
  {
396
395
  std::unique_ptr<BinaryInvertedListScanner> scanner(
397
- ivf.get_InvertedListScanner(store_pairs));
396
+ ivf->get_InvertedListScanner(store_pairs));
398
397
 
399
398
  #pragma omp for
400
399
  for (idx_t i = 0; i < n; i++) {
401
- const uint8_t* xi = x + i * ivf.code_size;
400
+ const uint8_t* xi = x + i * ivf->code_size;
402
401
  scanner->set_query(xi);
403
402
 
404
403
  const idx_t* keysi = keys + i * nprobe;
@@ -420,23 +419,24 @@ void search_knn_hamming_heap(
420
419
  continue;
421
420
  }
422
421
  FAISS_THROW_IF_NOT_FMT(
423
- key < (idx_t)ivf.nlist,
422
+ key < (idx_t)ivf->nlist,
424
423
  "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
425
424
  key,
426
425
  ik,
427
- ivf.nlist);
426
+ ivf->nlist);
428
427
 
429
428
  scanner->set_list(key, coarse_dis[i * nprobe + ik]);
430
429
 
431
430
  nlistv++;
432
431
 
433
- size_t list_size = ivf.invlists->list_size(key);
434
- InvertedLists::ScopedCodes scodes(ivf.invlists, key);
432
+ size_t list_size = ivf->invlists->list_size(key);
433
+ InvertedLists::ScopedCodes scodes(ivf->invlists, key);
435
434
  std::unique_ptr<InvertedLists::ScopedIds> sids;
436
435
  const idx_t* ids = nullptr;
437
436
 
438
437
  if (!store_pairs) {
439
- sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key));
438
+ sids = std::make_unique<InvertedLists::ScopedIds>(
439
+ ivf->invlists, key);
440
440
  ids = sids->get();
441
441
  }
442
442
 
@@ -466,7 +466,7 @@ void search_knn_hamming_heap(
466
466
 
467
467
  template <class HammingComputer, bool store_pairs>
468
468
  void search_knn_hamming_count(
469
- const IndexBinaryIVF& ivf,
469
+ const IndexBinaryIVF* ivf,
470
470
  size_t nx,
471
471
  const uint8_t* __restrict x,
472
472
  const idx_t* __restrict keys,
@@ -474,21 +474,21 @@ void search_knn_hamming_count(
474
474
  int32_t* __restrict distances,
475
475
  idx_t* __restrict labels,
476
476
  const IVFSearchParameters* params) {
477
- const int nBuckets = ivf.d + 1;
477
+ const int nBuckets = ivf->d + 1;
478
478
  std::vector<int> all_counters(nx * nBuckets, 0);
479
479
  std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
480
480
 
481
- idx_t nprobe = params ? params->nprobe : ivf.nprobe;
482
- nprobe = std::min((idx_t)ivf.nlist, nprobe);
483
- idx_t max_codes = params ? params->max_codes : ivf.max_codes;
481
+ idx_t nprobe = params ? params->nprobe : ivf->nprobe;
482
+ nprobe = std::min((idx_t)ivf->nlist, nprobe);
483
+ idx_t max_codes = params ? params->max_codes : ivf->max_codes;
484
484
 
485
485
  std::vector<HCounterState<HammingComputer>> cs;
486
486
  for (size_t i = 0; i < nx; ++i) {
487
487
  cs.push_back(HCounterState<HammingComputer>(
488
488
  all_counters.data() + i * nBuckets,
489
489
  all_ids_per_dis.get() + i * nBuckets * k,
490
- x + i * ivf.code_size,
491
- ivf.d,
490
+ x + i * ivf->code_size,
491
+ ivf->d,
492
492
  k));
493
493
  }
494
494
 
@@ -508,27 +508,28 @@ void search_knn_hamming_count(
508
508
  continue;
509
509
  }
510
510
  FAISS_THROW_IF_NOT_FMT(
511
- key < (idx_t)ivf.nlist,
511
+ key < (idx_t)ivf->nlist,
512
512
  "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
513
513
  key,
514
514
  ik,
515
- ivf.nlist);
515
+ ivf->nlist);
516
516
 
517
517
  nlistv++;
518
- size_t list_size = ivf.invlists->list_size(key);
519
- InvertedLists::ScopedCodes scodes(ivf.invlists, key);
518
+ size_t list_size = ivf->invlists->list_size(key);
519
+ InvertedLists::ScopedCodes scodes(ivf->invlists, key);
520
520
  const uint8_t* list_vecs = scodes.get();
521
521
  const idx_t* ids =
522
- store_pairs ? nullptr : ivf.invlists->get_ids(key);
522
+ store_pairs ? nullptr : ivf->invlists->get_ids(key);
523
523
 
524
524
  for (size_t j = 0; j < list_size; j++) {
525
- const uint8_t* yj = list_vecs + ivf.code_size * j;
525
+ const uint8_t* yj = list_vecs + ivf->code_size * j;
526
526
 
527
527
  idx_t id = store_pairs ? (key << 32 | j) : ids[j];
528
528
  csi.update_counter(yj, id);
529
529
  }
530
- if (ids)
531
- ivf.invlists->release_ids(key, ids);
530
+ if (ids) {
531
+ ivf->invlists->release_ids(key, ids);
532
+ }
532
533
 
533
534
  nscan += list_size;
534
535
  if (max_codes && nscan >= max_codes)
@@ -634,7 +635,7 @@ struct BlockSearchVariableK {
634
635
 
635
636
  template <class HammingComputer>
636
637
  void search_knn_hamming_per_invlist(
637
- const IndexBinaryIVF& ivf,
638
+ const IndexBinaryIVF* ivf,
638
639
  size_t n,
639
640
  const uint8_t* __restrict x,
640
641
  idx_t k,
@@ -644,12 +645,11 @@ void search_knn_hamming_per_invlist(
644
645
  idx_t* __restrict labels,
645
646
  bool store_pairs,
646
647
  const IVFSearchParameters* params) {
647
- idx_t nprobe = params ? params->nprobe : ivf.nprobe;
648
- nprobe = std::min((idx_t)ivf.nlist, nprobe);
649
- idx_t max_codes = params ? params->max_codes : ivf.max_codes;
648
+ idx_t nprobe = params ? params->nprobe : ivf->nprobe;
649
+ nprobe = std::min((idx_t)ivf->nlist, nprobe);
650
+ idx_t max_codes = params ? params->max_codes : ivf->max_codes;
650
651
  FAISS_THROW_IF_NOT(max_codes == 0);
651
652
  FAISS_THROW_IF_NOT(!store_pairs);
652
- MetricType metric_type = ivf.metric_type;
653
653
 
654
654
  // reorder buckets
655
655
  std::vector<int64_t> lims(n + 1);
@@ -658,18 +658,18 @@ void search_knn_hamming_per_invlist(
658
658
  for (idx_t i = 0; i < n * nprobe; i++) {
659
659
  keys[i] = keys_in[i];
660
660
  }
661
- matrix_bucket_sort_inplace(n, nprobe, keys, ivf.nlist, lims.data(), 0);
661
+ matrix_bucket_sort_inplace(n, nprobe, keys, ivf->nlist, lims.data(), 0);
662
662
 
663
663
  using C = CMax<int32_t, idx_t>;
664
664
  heap_heapify<C>(n * k, distances, labels);
665
- const size_t code_size = ivf.code_size;
665
+ const size_t code_size = ivf->code_size;
666
666
 
667
- for (idx_t l = 0; l < ivf.nlist; l++) {
667
+ for (idx_t l = 0; l < ivf->nlist; l++) {
668
668
  idx_t l0 = lims[l], nq = lims[l + 1] - l0;
669
669
 
670
- InvertedLists::ScopedCodes scodes(ivf.invlists, l);
671
- InvertedLists::ScopedIds sidx(ivf.invlists, l);
672
- idx_t nb = ivf.invlists->list_size(l);
670
+ InvertedLists::ScopedCodes scodes(ivf->invlists, l);
671
+ InvertedLists::ScopedIds sidx(ivf->invlists, l);
672
+ idx_t nb = ivf->invlists->list_size(l);
673
673
  const uint8_t* bcodes = scodes.get();
674
674
  const idx_t* ids = sidx.get();
675
675
 
@@ -735,151 +735,70 @@ void search_knn_hamming_per_invlist(
735
735
  }
736
736
  }
737
737
 
738
+ struct Run_search_knn_hamming_per_invlist {
739
+ using T = void;
740
+
741
+ template <class HammingComputer, class... Types>
742
+ void f(Types... args) {
743
+ search_knn_hamming_per_invlist<HammingComputer>(args...);
744
+ }
745
+ };
746
+
738
747
  template <bool store_pairs>
739
- void search_knn_hamming_count_1(
740
- const IndexBinaryIVF& ivf,
741
- size_t nx,
742
- const uint8_t* x,
743
- const idx_t* keys,
744
- int k,
745
- int32_t* distances,
746
- idx_t* labels,
747
- const IVFSearchParameters* params) {
748
- switch (ivf.code_size) {
749
- #define HANDLE_CS(cs) \
750
- case cs: \
751
- search_knn_hamming_count<HammingComputer##cs, store_pairs>( \
752
- ivf, nx, x, keys, k, distances, labels, params); \
753
- break;
754
- HANDLE_CS(4);
755
- HANDLE_CS(8);
756
- HANDLE_CS(16);
757
- HANDLE_CS(20);
758
- HANDLE_CS(32);
759
- HANDLE_CS(64);
760
- #undef HANDLE_CS
761
- default:
762
- search_knn_hamming_count<HammingComputerDefault, store_pairs>(
763
- ivf, nx, x, keys, k, distances, labels, params);
764
- break;
748
+ struct Run_search_knn_hamming_count {
749
+ using T = void;
750
+
751
+ template <class HammingComputer, class... Types>
752
+ void f(Types... args) {
753
+ search_knn_hamming_count<HammingComputer, store_pairs>(args...);
765
754
  }
766
- }
755
+ };
767
756
 
768
- void search_knn_hamming_per_invlist_1(
769
- const IndexBinaryIVF& ivf,
770
- size_t n,
771
- const uint8_t* x,
772
- idx_t k,
773
- const idx_t* keys,
774
- const int32_t* coarse_dis,
775
- int32_t* distances,
776
- idx_t* labels,
777
- bool store_pairs,
778
- const IVFSearchParameters* params) {
779
- switch (ivf.code_size) {
780
- #define HANDLE_CS(cs) \
781
- case cs: \
782
- search_knn_hamming_per_invlist<HammingComputer##cs>( \
783
- ivf, \
784
- n, \
785
- x, \
786
- k, \
787
- keys, \
788
- coarse_dis, \
789
- distances, \
790
- labels, \
791
- store_pairs, \
792
- params); \
793
- break;
794
- HANDLE_CS(4);
795
- HANDLE_CS(8);
796
- HANDLE_CS(16);
797
- HANDLE_CS(20);
798
- HANDLE_CS(32);
799
- HANDLE_CS(64);
800
- #undef HANDLE_CS
801
- default:
802
- search_knn_hamming_per_invlist<HammingComputerDefault>(
803
- ivf,
804
- n,
805
- x,
806
- k,
807
- keys,
808
- coarse_dis,
809
- distances,
810
- labels,
811
- store_pairs,
812
- params);
813
- break;
757
+ struct BuildScanner {
758
+ using T = BinaryInvertedListScanner*;
759
+
760
+ template <class HammingComputer>
761
+ T f(size_t code_size, bool store_pairs) {
762
+ return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs);
814
763
  }
815
- }
764
+ };
816
765
 
817
766
  } // anonymous namespace
818
767
 
819
768
  BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
820
769
  bool store_pairs) const {
821
- #define HC(name) return new IVFBinaryScannerL2<name>(code_size, store_pairs)
822
- switch (code_size) {
823
- case 4:
824
- HC(HammingComputer4);
825
- case 8:
826
- HC(HammingComputer8);
827
- case 16:
828
- HC(HammingComputer16);
829
- case 20:
830
- HC(HammingComputer20);
831
- case 32:
832
- HC(HammingComputer32);
833
- case 64:
834
- HC(HammingComputer64);
835
- default:
836
- HC(HammingComputerDefault);
837
- }
838
- #undef HC
770
+ BuildScanner bs;
771
+ return dispatch_HammingComputer(code_size, bs, code_size, store_pairs);
839
772
  }
840
773
 
841
774
  void IndexBinaryIVF::search_preassigned(
842
775
  idx_t n,
843
776
  const uint8_t* x,
844
777
  idx_t k,
845
- const idx_t* idx,
846
- const int32_t* coarse_dis,
847
- int32_t* distances,
848
- idx_t* labels,
778
+ const idx_t* cidx,
779
+ const int32_t* cdis,
780
+ int32_t* dis,
781
+ idx_t* idx,
849
782
  bool store_pairs,
850
783
  const IVFSearchParameters* params) const {
851
784
  if (per_invlist_search) {
852
- search_knn_hamming_per_invlist_1(
853
- *this,
854
- n,
855
- x,
856
- k,
857
- idx,
858
- coarse_dis,
859
- distances,
860
- labels,
861
- store_pairs,
862
- params);
785
+ Run_search_knn_hamming_per_invlist r;
786
+ // clang-format off
787
+ dispatch_HammingComputer(
788
+ code_size, r, this, n, x, k,
789
+ cidx, cdis, dis, idx, store_pairs, params);
790
+ // clang-format on
863
791
  } else if (use_heap) {
864
792
  search_knn_hamming_heap(
865
- *this,
866
- n,
867
- x,
868
- k,
869
- idx,
870
- coarse_dis,
871
- distances,
872
- labels,
873
- store_pairs,
874
- params);
875
- } else {
876
- if (store_pairs) {
877
- search_knn_hamming_count_1<true>(
878
- *this, n, x, idx, k, distances, labels, params);
879
- } else {
880
- search_knn_hamming_count_1<false>(
881
- *this, n, x, idx, k, distances, labels, params);
882
- }
793
+ this, n, x, k, cidx, cdis, dis, idx, store_pairs, params);
794
+ } else if (store_pairs) { // !use_heap && store_pairs
795
+ Run_search_knn_hamming_count<true> r;
796
+ dispatch_HammingComputer(
797
+ code_size, r, this, n, x, cidx, k, dis, idx, params);
798
+ } else { // !use_heap && !store_pairs
799
+ Run_search_knn_hamming_count<false> r;
800
+ dispatch_HammingComputer(
801
+ code_size, r, this, n, x, cidx, k, dis, idx, params);
883
802
  }
884
803
  }
885
804
 
@@ -891,16 +810,16 @@ void IndexBinaryIVF::range_search(
891
810
  const SearchParameters* params) const {
892
811
  FAISS_THROW_IF_NOT_MSG(
893
812
  !params, "search params not supported for this index");
894
- const size_t nprobe = std::min(nlist, this->nprobe);
895
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
896
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
813
+ const size_t nprobe_2 = std::min(nlist, this->nprobe);
814
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
815
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
897
816
 
898
817
  double t0 = getmillisecs();
899
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
818
+ quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
900
819
  indexIVF_stats.quantization_time += getmillisecs() - t0;
901
820
 
902
821
  t0 = getmillisecs();
903
- invlists->prefetch_lists(idx.get(), n * nprobe);
822
+ invlists->prefetch_lists(idx.get(), n * nprobe_2);
904
823
 
905
824
  range_search_preassigned(n, x, radius, idx.get(), coarse_dis.get(), res);
906
825
 
@@ -914,7 +833,7 @@ void IndexBinaryIVF::range_search_preassigned(
914
833
  const idx_t* __restrict assign,
915
834
  const int32_t* __restrict centroid_dis,
916
835
  RangeSearchResult* __restrict res) const {
917
- const size_t nprobe = std::min(nlist, this->nprobe);
836
+ const size_t nprobe_2 = std::min(nlist, this->nprobe);
918
837
  bool store_pairs = false;
919
838
  size_t nlistv = 0, ndis = 0;
920
839
 
@@ -930,7 +849,7 @@ void IndexBinaryIVF::range_search_preassigned(
930
849
  all_pres[omp_get_thread_num()] = &pres;
931
850
 
932
851
  auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
933
- idx_t key = assign[i * nprobe + ik]; /* select the list */
852
+ idx_t key = assign[i * nprobe_2 + ik]; /* select the list */
934
853
  if (key < 0)
935
854
  return;
936
855
  FAISS_THROW_IF_NOT_FMT(
@@ -947,7 +866,7 @@ void IndexBinaryIVF::range_search_preassigned(
947
866
  InvertedLists::ScopedCodes scodes(invlists, key);
948
867
  InvertedLists::ScopedIds ids(invlists, key);
949
868
 
950
- scanner->set_list(key, assign[i * nprobe + ik]);
869
+ scanner->set_list(key, assign[i * nprobe_2 + ik]);
951
870
  nlistv++;
952
871
  ndis += list_size;
953
872
  scanner->scan_codes_range(
@@ -960,7 +879,7 @@ void IndexBinaryIVF::range_search_preassigned(
960
879
 
961
880
  RangeQueryResult& qres = pres.new_result(i);
962
881
 
963
- for (size_t ik = 0; ik < nprobe; ik++) {
882
+ for (size_t ik = 0; ik < nprobe_2; ik++) {
964
883
  scan_list_func(i, ik, qres);
965
884
  }
966
885
  }