faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -23,6 +23,7 @@
23
23
  #include <faiss/IndexFlat.h>
24
24
  #include <faiss/impl/AuxIndexStructures.h>
25
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
26
27
 
27
28
  namespace faiss {
28
29
 
@@ -107,8 +108,15 @@ void Level1Quantizer::train_q1(
107
108
  } else {
108
109
  clus.train(n, x, *clustering_index);
109
110
  }
110
- if (verbose)
111
+ if (verbose) {
111
112
  printf("Adding centroids to quantizer\n");
113
+ }
114
+ if (!quantizer->is_trained) {
115
+ if (verbose) {
116
+ printf("But training it first on centroids table...\n");
117
+ }
118
+ quantizer->train(nlist, clus.centroids.data());
119
+ }
112
120
  quantizer->add(nlist, clus.centroids.data());
113
121
  }
114
122
  }
@@ -190,6 +198,20 @@ void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
190
198
  add_core(n, x, xids, coarse_idx.get());
191
199
  }
192
200
 
201
+ void IndexIVF::add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) {
202
+ size_t coarse_size = coarse_code_size();
203
+ DirectMapAdd dm_adder(direct_map, n, xids);
204
+
205
+ for (idx_t i = 0; i < n; i++) {
206
+ const uint8_t* code = codes + (code_size + coarse_size) * i;
207
+ idx_t list_no = decode_listno(code);
208
+ idx_t id = xids ? xids[i] : ntotal + i;
209
+ size_t ofs = invlists->add_entry(list_no, id, code + coarse_size);
210
+ dm_adder.add(i, list_no, ofs);
211
+ }
212
+ ntotal += n;
213
+ }
214
+
193
215
  void IndexIVF::add_core(
194
216
  idx_t n,
195
217
  const float* x,
@@ -282,14 +304,20 @@ void IndexIVF::search(
282
304
  const float* x,
283
305
  idx_t k,
284
306
  float* distances,
285
- idx_t* labels) const {
307
+ idx_t* labels,
308
+ const SearchParameters* params_in) const {
286
309
  FAISS_THROW_IF_NOT(k > 0);
287
-
288
- const size_t nprobe = std::min(nlist, this->nprobe);
310
+ const IVFSearchParameters* params = nullptr;
311
+ if (params_in) {
312
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
313
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
314
+ }
315
+ const size_t nprobe =
316
+ std::min(nlist, params ? params->nprobe : this->nprobe);
289
317
  FAISS_THROW_IF_NOT(nprobe > 0);
290
318
 
291
319
  // search function for a subset of queries
292
- auto sub_search_func = [this, k, nprobe](
320
+ auto sub_search_func = [this, k, nprobe, params](
293
321
  idx_t n,
294
322
  const float* x,
295
323
  float* distances,
@@ -299,7 +327,13 @@ void IndexIVF::search(
299
327
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
300
328
 
301
329
  double t0 = getmillisecs();
302
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
330
+ quantizer->search(
331
+ n,
332
+ x,
333
+ nprobe,
334
+ coarse_dis.get(),
335
+ idx.get(),
336
+ params ? params->quantizer_params : nullptr);
303
337
 
304
338
  double t1 = getmillisecs();
305
339
  invlists->prefetch_lists(idx.get(), n * nprobe);
@@ -313,7 +347,7 @@ void IndexIVF::search(
313
347
  distances,
314
348
  labels,
315
349
  false,
316
- nullptr,
350
+ params,
317
351
  ivf_stats);
318
352
  double t2 = getmillisecs();
319
353
  ivf_stats->quantization_time += t1 - t0;
@@ -379,6 +413,19 @@ void IndexIVF::search_preassigned(
379
413
  FAISS_THROW_IF_NOT(nprobe > 0);
380
414
 
381
415
  idx_t max_codes = params ? params->max_codes : this->max_codes;
416
+ IDSelector* sel = params ? params->sel : nullptr;
417
+ const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
418
+ if (selr) {
419
+ if (selr->assume_sorted) {
420
+ sel = nullptr; // use special IDSelectorRange processing
421
+ } else {
422
+ selr = nullptr; // use generic processing
423
+ }
424
+ }
425
+
426
+ FAISS_THROW_IF_NOT_MSG(
427
+ !(sel && store_pairs),
428
+ "selector and store_pairs cannot be combined");
382
429
 
383
430
  size_t nlistv = 0, ndis = 0, nheap = 0;
384
431
 
@@ -400,7 +447,8 @@ void IndexIVF::search_preassigned(
400
447
 
401
448
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
402
449
  {
403
- InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
450
+ InvertedListScanner* scanner =
451
+ get_InvertedListScanner(store_pairs, sel);
404
452
  ScopeDeleter1<InvertedListScanner> del(scanner);
405
453
 
406
454
  /*****************************************************
@@ -471,6 +519,7 @@ void IndexIVF::search_preassigned(
471
519
 
472
520
  try {
473
521
  InvertedLists::ScopedCodes scodes(invlists, key);
522
+ const uint8_t* codes = scodes.get();
474
523
 
475
524
  std::unique_ptr<InvertedLists::ScopedIds> sids;
476
525
  const Index::idx_t* ids = nullptr;
@@ -480,8 +529,20 @@ void IndexIVF::search_preassigned(
480
529
  ids = sids->get();
481
530
  }
482
531
 
532
+ if (selr) { // IDSelectorRange
533
+ // restrict search to a section of the inverted list
534
+ size_t jmin, jmax;
535
+ selr->find_sorted_ids_bounds(list_size, ids, &jmin, &jmax);
536
+ list_size = jmax - jmin;
537
+ if (list_size == 0) {
538
+ return (size_t)0;
539
+ }
540
+ codes += jmin * code_size;
541
+ ids += jmin;
542
+ }
543
+
483
544
  nheap += scanner->scan_codes(
484
- list_size, scodes.get(), ids, simi, idxi, k);
545
+ list_size, codes, ids, simi, idxi, k);
485
546
 
486
547
  } catch (const std::exception& e) {
487
548
  std::lock_guard<std::mutex> lock(exception_mutex);
@@ -630,13 +691,23 @@ void IndexIVF::range_search(
630
691
  idx_t nx,
631
692
  const float* x,
632
693
  float radius,
633
- RangeSearchResult* result) const {
634
- const size_t nprobe = std::min(nlist, this->nprobe);
694
+ RangeSearchResult* result,
695
+ const SearchParameters* params_in) const {
696
+ const IVFSearchParameters* params = nullptr;
697
+ const SearchParameters* quantizer_params = nullptr;
698
+ if (params_in) {
699
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
700
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
701
+ quantizer_params = params->quantizer_params;
702
+ }
703
+ const size_t nprobe =
704
+ std::min(nlist, params ? params->nprobe : this->nprobe);
635
705
  std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
636
706
  std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
637
707
 
638
708
  double t0 = getmillisecs();
639
- quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
709
+ quantizer->search(
710
+ nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
640
711
  indexIVF_stats.quantization_time += getmillisecs() - t0;
641
712
 
642
713
  t0 = getmillisecs();
@@ -650,7 +721,7 @@ void IndexIVF::range_search(
650
721
  coarse_dis.get(),
651
722
  result,
652
723
  false,
653
- nullptr,
724
+ params,
654
725
  &indexIVF_stats);
655
726
 
656
727
  indexIVF_stats.search_time += getmillisecs() - t0;
@@ -668,7 +739,10 @@ void IndexIVF::range_search_preassigned(
668
739
  IndexIVFStats* stats) const {
669
740
  idx_t nprobe = params ? params->nprobe : this->nprobe;
670
741
  nprobe = std::min((idx_t)nlist, nprobe);
742
+ FAISS_THROW_IF_NOT(nprobe > 0);
743
+
671
744
  idx_t max_codes = params ? params->max_codes : this->max_codes;
745
+ IDSelector* sel = params ? params->sel : nullptr;
672
746
 
673
747
  size_t nlistv = 0, ndis = 0;
674
748
 
@@ -690,7 +764,7 @@ void IndexIVF::range_search_preassigned(
690
764
  {
691
765
  RangeSearchPartialResult pres(result);
692
766
  std::unique_ptr<InvertedListScanner> scanner(
693
- get_InvertedListScanner(store_pairs));
767
+ get_InvertedListScanner(store_pairs, sel));
694
768
  FAISS_THROW_IF_NOT(scanner.get());
695
769
  all_pres[omp_get_thread_num()] = &pres;
696
770
 
@@ -753,7 +827,6 @@ void IndexIVF::range_search_preassigned(
753
827
  }
754
828
  }
755
829
  } else if (parallel_mode == 2) {
756
- std::vector<RangeQueryResult*> all_qres(nx);
757
830
  RangeQueryResult* qres = nullptr;
758
831
 
759
832
  #pragma omp for schedule(dynamic)
@@ -761,7 +834,6 @@ void IndexIVF::range_search_preassigned(
761
834
  idx_t i = iik / (idx_t)nprobe;
762
835
  idx_t ik = iik % (idx_t)nprobe;
763
836
  if (qres == nullptr || qres->qno != i) {
764
- FAISS_ASSERT(!qres || i > qres->qno);
765
837
  qres = &pres.new_result(i);
766
838
  scanner->set_query(x + i * d);
767
839
  }
@@ -797,7 +869,8 @@ void IndexIVF::range_search_preassigned(
797
869
  }
798
870
 
799
871
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
800
- bool /*store_pairs*/) const {
872
+ bool /*store_pairs*/,
873
+ const IDSelector* /* sel */) const {
801
874
  return nullptr;
802
875
  }
803
876
 
@@ -825,6 +898,21 @@ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
825
898
  }
826
899
  }
827
900
 
901
+ bool IndexIVF::check_ids_sorted() const {
902
+ size_t nflip = 0;
903
+
904
+ for (size_t i = 0; i < nlist; i++) {
905
+ size_t list_size = invlists->list_size(i);
906
+ InvertedLists::ScopedIds ids(invlists, i);
907
+ for (size_t j = 0; j + 1 < list_size; j++) {
908
+ if (ids[j + 1] < ids[j]) {
909
+ nflip++;
910
+ }
911
+ }
912
+ }
913
+ return nflip == 0;
914
+ }
915
+
828
916
  /* standalone codec interface */
829
917
  size_t IndexIVF::sa_code_size() const {
830
918
  size_t coarse_size = coarse_code_size();
@@ -844,10 +932,15 @@ void IndexIVF::search_and_reconstruct(
844
932
  idx_t k,
845
933
  float* distances,
846
934
  idx_t* labels,
847
- float* recons) const {
848
- FAISS_THROW_IF_NOT(k > 0);
849
-
850
- const size_t nprobe = std::min(nlist, this->nprobe);
935
+ float* recons,
936
+ const SearchParameters* params_in) const {
937
+ const IVFSearchParameters* params = nullptr;
938
+ if (params_in) {
939
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
940
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
941
+ }
942
+ const size_t nprobe =
943
+ std::min(nlist, params ? params->nprobe : this->nprobe);
851
944
  FAISS_THROW_IF_NOT(nprobe > 0);
852
945
 
853
946
  idx_t* idx = new idx_t[n * nprobe];
@@ -869,7 +962,8 @@ void IndexIVF::search_and_reconstruct(
869
962
  coarse_dis,
870
963
  distances,
871
964
  labels,
872
- true /* store_pairs */);
965
+ true /* store_pairs */,
966
+ params);
873
967
  for (idx_t i = 0; i < n; ++i) {
874
968
  for (idx_t j = 0; j < k; ++j) {
875
969
  idx_t ij = i * k + j;
@@ -955,26 +1049,41 @@ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
955
1049
  // does nothing by default
956
1050
  }
957
1051
 
958
- void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
1052
+ bool check_compatible_for_merge_expensive_check = true;
1053
+
1054
+ void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
959
1055
  // minimal sanity checks
960
- FAISS_THROW_IF_NOT(other.d == d);
961
- FAISS_THROW_IF_NOT(other.nlist == nlist);
962
- FAISS_THROW_IF_NOT(other.code_size == code_size);
1056
+ const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
1057
+ FAISS_THROW_IF_NOT(other);
1058
+ FAISS_THROW_IF_NOT(other->d == d);
1059
+ FAISS_THROW_IF_NOT(other->nlist == nlist);
1060
+ FAISS_THROW_IF_NOT(quantizer->ntotal == other->quantizer->ntotal);
1061
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
963
1062
  FAISS_THROW_IF_NOT_MSG(
964
- typeid(*this) == typeid(other),
1063
+ typeid(*this) == typeid(*other),
965
1064
  "can only merge indexes of the same type");
966
1065
  FAISS_THROW_IF_NOT_MSG(
967
- this->direct_map.no() && other.direct_map.no(),
1066
+ this->direct_map.no() && other->direct_map.no(),
968
1067
  "merge direct_map not implemented");
969
- }
970
1068
 
971
- void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
972
- check_compatible_for_merge(other);
1069
+ if (check_compatible_for_merge_expensive_check) {
1070
+ std::vector<float> v(d), v2(d);
1071
+ for (size_t i = 0; i < nlist; i++) {
1072
+ quantizer->reconstruct(i, v.data());
1073
+ other->quantizer->reconstruct(i, v2.data());
1074
+ FAISS_THROW_IF_NOT_MSG(
1075
+ v == v2, "coarse quantizers should be the same");
1076
+ }
1077
+ }
1078
+ }
973
1079
 
974
- invlists->merge_from(other.invlists, add_id);
1080
+ void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
1081
+ check_compatible_for_merge(otherIndex);
1082
+ IndexIVF* other = static_cast<IndexIVF*>(&otherIndex);
1083
+ invlists->merge_from(other->invlists, add_id);
975
1084
 
976
- ntotal += other.ntotal;
977
- other.ntotal = 0;
1085
+ ntotal += other->ntotal;
1086
+ other->ntotal = 0;
978
1087
  }
979
1088
 
980
1089
  void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
@@ -1068,6 +1177,10 @@ IndexIVF::~IndexIVF() {
1068
1177
  }
1069
1178
  }
1070
1179
 
1180
+ /*************************************************************************
1181
+ * IndexIVFStats
1182
+ *************************************************************************/
1183
+
1071
1184
  void IndexIVFStats::reset() {
1072
1185
  memset((void*)this, 0, sizeof(*this));
1073
1186
  }
@@ -1083,13 +1196,60 @@ void IndexIVFStats::add(const IndexIVFStats& other) {
1083
1196
 
1084
1197
  IndexIVFStats indexIVF_stats;
1085
1198
 
1199
+ /*************************************************************************
1200
+ * InvertedListScanner
1201
+ *************************************************************************/
1202
+
1203
+ size_t InvertedListScanner::scan_codes(
1204
+ size_t list_size,
1205
+ const uint8_t* codes,
1206
+ const idx_t* ids,
1207
+ float* simi,
1208
+ idx_t* idxi,
1209
+ size_t k) const {
1210
+ size_t nup = 0;
1211
+
1212
+ if (!keep_max) {
1213
+ for (size_t j = 0; j < list_size; j++) {
1214
+ float dis = distance_to_code(codes);
1215
+ if (dis < simi[0]) {
1216
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1217
+ maxheap_replace_top(k, simi, idxi, dis, id);
1218
+ nup++;
1219
+ }
1220
+ codes += code_size;
1221
+ }
1222
+ } else {
1223
+ for (size_t j = 0; j < list_size; j++) {
1224
+ float dis = distance_to_code(codes);
1225
+ if (dis > simi[0]) {
1226
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1227
+ minheap_replace_top(k, simi, idxi, dis, id);
1228
+ nup++;
1229
+ }
1230
+ codes += code_size;
1231
+ }
1232
+ }
1233
+ return nup;
1234
+ }
1235
+
1086
1236
  void InvertedListScanner::scan_codes_range(
1087
- size_t,
1088
- const uint8_t*,
1089
- const idx_t*,
1090
- float,
1091
- RangeQueryResult&) const {
1092
- FAISS_THROW_MSG("scan_codes_range not implemented");
1237
+ size_t list_size,
1238
+ const uint8_t* codes,
1239
+ const idx_t* ids,
1240
+ float radius,
1241
+ RangeQueryResult& res) const {
1242
+ for (size_t j = 0; j < list_size; j++) {
1243
+ float dis = distance_to_code(codes);
1244
+ bool keep = !keep_max
1245
+ ? dis < radius
1246
+ : dis > radius; // TODO templatize to remove this test
1247
+ if (keep) {
1248
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1249
+ res.add(dis, id);
1250
+ }
1251
+ codes += code_size;
1252
+ }
1093
1253
  }
1094
1254
 
1095
1255
  } // namespace faiss
@@ -11,11 +11,13 @@
11
11
  #define FAISS_INDEX_IVF_H
12
12
 
13
13
  #include <stdint.h>
14
+ #include <memory>
14
15
  #include <unordered_map>
15
16
  #include <vector>
16
17
 
17
18
  #include <faiss/Clustering.h>
18
19
  #include <faiss/Index.h>
20
+ #include <faiss/impl/IDSelector.h>
19
21
  #include <faiss/impl/platform_macros.h>
20
22
  #include <faiss/invlists/DirectMap.h>
21
23
  #include <faiss/invlists/InvertedLists.h>
@@ -38,7 +40,7 @@ struct Level1Quantizer {
38
40
  * = 2: kmeans training on a flat index + add the centroids to the quantizer
39
41
  */
40
42
  char quantizer_trains_alone;
41
- bool own_fields; ///< whether object owns the quantizer
43
+ bool own_fields; ///< whether object owns the quantizer (false by default)
42
44
 
43
45
  ClusteringParameters cp; ///< to override default clustering params
44
46
  Index* clustering_index; ///< to override index used during clustering
@@ -62,13 +64,18 @@ struct Level1Quantizer {
62
64
  ~Level1Quantizer();
63
65
  };
64
66
 
65
- struct IVFSearchParameters {
67
+ struct SearchParametersIVF : SearchParameters {
66
68
  size_t nprobe; ///< number of probes at query time
67
69
  size_t max_codes; ///< max nb of codes to visit to do a query
68
- IVFSearchParameters() : nprobe(1), max_codes(0) {}
69
- virtual ~IVFSearchParameters() {}
70
+ SearchParameters* quantizer_params = nullptr;
71
+
72
+ SearchParametersIVF() : nprobe(1), max_codes(0) {}
73
+ virtual ~SearchParametersIVF() {}
70
74
  };
71
75
 
76
+ // the new convention puts the index type after SearchParameters
77
+ using IVFSearchParameters = SearchParametersIVF;
78
+
72
79
  struct InvertedListScanner;
73
80
  struct IndexIVFStats;
74
81
 
@@ -121,8 +128,7 @@ struct IndexIVF : Index, Level1Quantizer {
121
128
 
122
129
  /** The Inverted file takes a quantizer (an Index) on input,
123
130
  * which implements the function mapping a vector to a list
124
- * identifier. The pointer is borrowed: the quantizer should not
125
- * be deleted while the IndexIVF is in use.
131
+ * identifier.
126
132
  */
127
133
  IndexIVF(
128
134
  Index* quantizer,
@@ -171,6 +177,13 @@ struct IndexIVF : Index, Level1Quantizer {
171
177
  uint8_t* codes,
172
178
  bool include_listno = false) const = 0;
173
179
 
180
+ /** Add vectors that are computed with the standalone codec
181
+ *
182
+ * @param codes codes to add size n * sa_code_size()
183
+ * @param xids corresponding ids, size n
184
+ */
185
+ void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
186
+
174
187
  /// Sub-classes that encode the residuals can train their encoders here
175
188
  /// does nothing by default
176
189
  virtual void train_residual(idx_t n, const float* x);
@@ -212,13 +225,15 @@ struct IndexIVF : Index, Level1Quantizer {
212
225
  const float* x,
213
226
  idx_t k,
214
227
  float* distances,
215
- idx_t* labels) const override;
228
+ idx_t* labels,
229
+ const SearchParameters* params = nullptr) const override;
216
230
 
217
231
  void range_search(
218
232
  idx_t n,
219
233
  const float* x,
220
234
  float radius,
221
- RangeSearchResult* result) const override;
235
+ RangeSearchResult* result,
236
+ const SearchParameters* params = nullptr) const override;
222
237
 
223
238
  void range_search_preassigned(
224
239
  idx_t nx,
@@ -231,9 +246,13 @@ struct IndexIVF : Index, Level1Quantizer {
231
246
  const IVFSearchParameters* params = nullptr,
232
247
  IndexIVFStats* stats = nullptr) const;
233
248
 
234
- /// get a scanner for this index (store_pairs means ignore labels)
249
+ /** Get a scanner for this index (store_pairs means ignore labels)
250
+ *
251
+ * The default search implementation uses this to compute the distances
252
+ */
235
253
  virtual InvertedListScanner* get_InvertedListScanner(
236
- bool store_pairs = false) const;
254
+ bool store_pairs = false,
255
+ const IDSelector* sel = nullptr) const;
237
256
 
238
257
  /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
239
258
  */
@@ -275,7 +294,8 @@ struct IndexIVF : Index, Level1Quantizer {
275
294
  idx_t k,
276
295
  float* distances,
277
296
  idx_t* labels,
278
- float* recons) const override;
297
+ float* recons,
298
+ const SearchParameters* params = nullptr) const override;
279
299
 
280
300
  /** Reconstruct a vector given the location in terms of (inv list index +
281
301
  * inv list offset) instead of the id.
@@ -293,15 +313,9 @@ struct IndexIVF : Index, Level1Quantizer {
293
313
 
294
314
  size_t remove_ids(const IDSelector& sel) override;
295
315
 
296
- /** check that the two indexes are compatible (ie, they are
297
- * trained in the same way and have the same
298
- * parameters). Otherwise throw. */
299
- void check_compatible_for_merge(const IndexIVF& other) const;
316
+ void check_compatible_for_merge(const Index& otherIndex) const override;
300
317
 
301
- /** moves the entries from another dataset to self. On output,
302
- * other is empty. add_id is added to all moved ids (for
303
- * sequential ids, this would be this->ntotal */
304
- virtual void merge_from(IndexIVF& other, idx_t add_id);
318
+ virtual void merge_from(Index& otherIndex, idx_t add_id) override;
305
319
 
306
320
  /** copy a subset of the entries index to the other index
307
321
  *
@@ -322,6 +336,9 @@ struct IndexIVF : Index, Level1Quantizer {
322
336
  return invlists->list_size(list_no);
323
337
  }
324
338
 
339
+ /// are the ids sorted?
340
+ bool check_ids_sorted() const;
341
+
325
342
  /** intialize a direct map
326
343
  *
327
344
  * @param new_maintain_direct_map if true, create a direct map,
@@ -351,6 +368,22 @@ struct RangeQueryResult;
351
368
  struct InvertedListScanner {
352
369
  using idx_t = Index::idx_t;
353
370
 
371
+ idx_t list_no = -1; ///< remember current list
372
+ bool keep_max = false; ///< keep maximum instead of minimum
373
+ /// store positions in invlists rather than labels
374
+ bool store_pairs;
375
+
376
+ /// search in this subset of ids
377
+ const IDSelector* sel;
378
+
379
+ InvertedListScanner(
380
+ bool store_pairs = false,
381
+ const IDSelector* sel = nullptr)
382
+ : store_pairs(store_pairs), sel(sel) {}
383
+
384
+ /// used in default implementation of scan_codes
385
+ size_t code_size = 0;
386
+
354
387
  /// from now on we handle this query.
355
388
  virtual void set_query(const float* query_vector) = 0;
356
389
 
@@ -361,7 +394,8 @@ struct InvertedListScanner {
361
394
  virtual float distance_to_code(const uint8_t* code) const = 0;
362
395
 
363
396
  /** scan a set of codes, compute distances to current query and
364
- * update heap of results if necessary.
397
+ * update heap of results if necessary. Default implemetation
398
+ * calls distance_to_code.
365
399
  *
366
400
  * @param n number of codes to scan
367
401
  * @param codes codes to scan (n * code_size)
@@ -377,7 +411,7 @@ struct InvertedListScanner {
377
411
  const idx_t* ids,
378
412
  float* distances,
379
413
  idx_t* labels,
380
- size_t k) const = 0;
414
+ size_t k) const;
381
415
 
382
416
  /** scan a set of codes, compute distances to current query and
383
417
  * update results if distances are below radius
@@ -393,10 +427,13 @@ struct InvertedListScanner {
393
427
  virtual ~InvertedListScanner() {}
394
428
  };
395
429
 
430
+ // whether to check that coarse quantizers are the same
431
+ FAISS_API extern bool check_compatible_for_merge_expensive_check;
432
+
396
433
  struct IndexIVFStats {
397
434
  size_t nq; // nb of queries run
398
435
  size_t nlist; // nb of inverted lists scanned
399
- size_t ndis; // nb of distancs computed
436
+ size_t ndis; // nb of distances computed
400
437
  size_t nheap_updates; // nb of times the heap was updated
401
438
  double quantization_time; // time spent quantizing vectors (in ms)
402
439
  double search_time; // time spent searching lists (in ms)