faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -23,6 +23,7 @@
23
23
  #include <faiss/IndexFlat.h>
24
24
  #include <faiss/impl/AuxIndexStructures.h>
25
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
26
27
 
27
28
  namespace faiss {
28
29
 
@@ -107,8 +108,15 @@ void Level1Quantizer::train_q1(
107
108
  } else {
108
109
  clus.train(n, x, *clustering_index);
109
110
  }
110
- if (verbose)
111
+ if (verbose) {
111
112
  printf("Adding centroids to quantizer\n");
113
+ }
114
+ if (!quantizer->is_trained) {
115
+ if (verbose) {
116
+ printf("But training it first on centroids table...\n");
117
+ }
118
+ quantizer->train(nlist, clus.centroids.data());
119
+ }
112
120
  quantizer->add(nlist, clus.centroids.data());
113
121
  }
114
122
  }
@@ -190,6 +198,20 @@ void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
190
198
  add_core(n, x, xids, coarse_idx.get());
191
199
  }
192
200
 
201
+ void IndexIVF::add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) {
202
+ size_t coarse_size = coarse_code_size();
203
+ DirectMapAdd dm_adder(direct_map, n, xids);
204
+
205
+ for (idx_t i = 0; i < n; i++) {
206
+ const uint8_t* code = codes + (code_size + coarse_size) * i;
207
+ idx_t list_no = decode_listno(code);
208
+ idx_t id = xids ? xids[i] : ntotal + i;
209
+ size_t ofs = invlists->add_entry(list_no, id, code + coarse_size);
210
+ dm_adder.add(i, list_no, ofs);
211
+ }
212
+ ntotal += n;
213
+ }
214
+
193
215
  void IndexIVF::add_core(
194
216
  idx_t n,
195
217
  const float* x,
@@ -282,14 +304,20 @@ void IndexIVF::search(
282
304
  const float* x,
283
305
  idx_t k,
284
306
  float* distances,
285
- idx_t* labels) const {
307
+ idx_t* labels,
308
+ const SearchParameters* params_in) const {
286
309
  FAISS_THROW_IF_NOT(k > 0);
287
-
288
- const size_t nprobe = std::min(nlist, this->nprobe);
310
+ const IVFSearchParameters* params = nullptr;
311
+ if (params_in) {
312
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
313
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
314
+ }
315
+ const size_t nprobe =
316
+ std::min(nlist, params ? params->nprobe : this->nprobe);
289
317
  FAISS_THROW_IF_NOT(nprobe > 0);
290
318
 
291
319
  // search function for a subset of queries
292
- auto sub_search_func = [this, k, nprobe](
320
+ auto sub_search_func = [this, k, nprobe, params](
293
321
  idx_t n,
294
322
  const float* x,
295
323
  float* distances,
@@ -299,7 +327,13 @@ void IndexIVF::search(
299
327
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
300
328
 
301
329
  double t0 = getmillisecs();
302
- quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
330
+ quantizer->search(
331
+ n,
332
+ x,
333
+ nprobe,
334
+ coarse_dis.get(),
335
+ idx.get(),
336
+ params ? params->quantizer_params : nullptr);
303
337
 
304
338
  double t1 = getmillisecs();
305
339
  invlists->prefetch_lists(idx.get(), n * nprobe);
@@ -313,7 +347,7 @@ void IndexIVF::search(
313
347
  distances,
314
348
  labels,
315
349
  false,
316
- nullptr,
350
+ params,
317
351
  ivf_stats);
318
352
  double t2 = getmillisecs();
319
353
  ivf_stats->quantization_time += t1 - t0;
@@ -379,6 +413,19 @@ void IndexIVF::search_preassigned(
379
413
  FAISS_THROW_IF_NOT(nprobe > 0);
380
414
 
381
415
  idx_t max_codes = params ? params->max_codes : this->max_codes;
416
+ IDSelector* sel = params ? params->sel : nullptr;
417
+ const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
418
+ if (selr) {
419
+ if (selr->assume_sorted) {
420
+ sel = nullptr; // use special IDSelectorRange processing
421
+ } else {
422
+ selr = nullptr; // use generic processing
423
+ }
424
+ }
425
+
426
+ FAISS_THROW_IF_NOT_MSG(
427
+ !(sel && store_pairs),
428
+ "selector and store_pairs cannot be combined");
382
429
 
383
430
  size_t nlistv = 0, ndis = 0, nheap = 0;
384
431
 
@@ -400,7 +447,8 @@ void IndexIVF::search_preassigned(
400
447
 
401
448
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
402
449
  {
403
- InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
450
+ InvertedListScanner* scanner =
451
+ get_InvertedListScanner(store_pairs, sel);
404
452
  ScopeDeleter1<InvertedListScanner> del(scanner);
405
453
 
406
454
  /*****************************************************
@@ -471,6 +519,7 @@ void IndexIVF::search_preassigned(
471
519
 
472
520
  try {
473
521
  InvertedLists::ScopedCodes scodes(invlists, key);
522
+ const uint8_t* codes = scodes.get();
474
523
 
475
524
  std::unique_ptr<InvertedLists::ScopedIds> sids;
476
525
  const Index::idx_t* ids = nullptr;
@@ -480,8 +529,20 @@ void IndexIVF::search_preassigned(
480
529
  ids = sids->get();
481
530
  }
482
531
 
532
+ if (selr) { // IDSelectorRange
533
+ // restrict search to a section of the inverted list
534
+ size_t jmin, jmax;
535
+ selr->find_sorted_ids_bounds(list_size, ids, &jmin, &jmax);
536
+ list_size = jmax - jmin;
537
+ if (list_size == 0) {
538
+ return (size_t)0;
539
+ }
540
+ codes += jmin * code_size;
541
+ ids += jmin;
542
+ }
543
+
483
544
  nheap += scanner->scan_codes(
484
- list_size, scodes.get(), ids, simi, idxi, k);
545
+ list_size, codes, ids, simi, idxi, k);
485
546
 
486
547
  } catch (const std::exception& e) {
487
548
  std::lock_guard<std::mutex> lock(exception_mutex);
@@ -630,13 +691,23 @@ void IndexIVF::range_search(
630
691
  idx_t nx,
631
692
  const float* x,
632
693
  float radius,
633
- RangeSearchResult* result) const {
634
- const size_t nprobe = std::min(nlist, this->nprobe);
694
+ RangeSearchResult* result,
695
+ const SearchParameters* params_in) const {
696
+ const IVFSearchParameters* params = nullptr;
697
+ const SearchParameters* quantizer_params = nullptr;
698
+ if (params_in) {
699
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
700
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
701
+ quantizer_params = params->quantizer_params;
702
+ }
703
+ const size_t nprobe =
704
+ std::min(nlist, params ? params->nprobe : this->nprobe);
635
705
  std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
636
706
  std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
637
707
 
638
708
  double t0 = getmillisecs();
639
- quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
709
+ quantizer->search(
710
+ nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
640
711
  indexIVF_stats.quantization_time += getmillisecs() - t0;
641
712
 
642
713
  t0 = getmillisecs();
@@ -650,7 +721,7 @@ void IndexIVF::range_search(
650
721
  coarse_dis.get(),
651
722
  result,
652
723
  false,
653
- nullptr,
724
+ params,
654
725
  &indexIVF_stats);
655
726
 
656
727
  indexIVF_stats.search_time += getmillisecs() - t0;
@@ -668,7 +739,10 @@ void IndexIVF::range_search_preassigned(
668
739
  IndexIVFStats* stats) const {
669
740
  idx_t nprobe = params ? params->nprobe : this->nprobe;
670
741
  nprobe = std::min((idx_t)nlist, nprobe);
742
+ FAISS_THROW_IF_NOT(nprobe > 0);
743
+
671
744
  idx_t max_codes = params ? params->max_codes : this->max_codes;
745
+ IDSelector* sel = params ? params->sel : nullptr;
672
746
 
673
747
  size_t nlistv = 0, ndis = 0;
674
748
 
@@ -690,7 +764,7 @@ void IndexIVF::range_search_preassigned(
690
764
  {
691
765
  RangeSearchPartialResult pres(result);
692
766
  std::unique_ptr<InvertedListScanner> scanner(
693
- get_InvertedListScanner(store_pairs));
767
+ get_InvertedListScanner(store_pairs, sel));
694
768
  FAISS_THROW_IF_NOT(scanner.get());
695
769
  all_pres[omp_get_thread_num()] = &pres;
696
770
 
@@ -753,7 +827,6 @@ void IndexIVF::range_search_preassigned(
753
827
  }
754
828
  }
755
829
  } else if (parallel_mode == 2) {
756
- std::vector<RangeQueryResult*> all_qres(nx);
757
830
  RangeQueryResult* qres = nullptr;
758
831
 
759
832
  #pragma omp for schedule(dynamic)
@@ -761,7 +834,6 @@ void IndexIVF::range_search_preassigned(
761
834
  idx_t i = iik / (idx_t)nprobe;
762
835
  idx_t ik = iik % (idx_t)nprobe;
763
836
  if (qres == nullptr || qres->qno != i) {
764
- FAISS_ASSERT(!qres || i > qres->qno);
765
837
  qres = &pres.new_result(i);
766
838
  scanner->set_query(x + i * d);
767
839
  }
@@ -797,7 +869,8 @@ void IndexIVF::range_search_preassigned(
797
869
  }
798
870
 
799
871
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
800
- bool /*store_pairs*/) const {
872
+ bool /*store_pairs*/,
873
+ const IDSelector* /* sel */) const {
801
874
  return nullptr;
802
875
  }
803
876
 
@@ -825,6 +898,21 @@ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
825
898
  }
826
899
  }
827
900
 
901
+ bool IndexIVF::check_ids_sorted() const {
902
+ size_t nflip = 0;
903
+
904
+ for (size_t i = 0; i < nlist; i++) {
905
+ size_t list_size = invlists->list_size(i);
906
+ InvertedLists::ScopedIds ids(invlists, i);
907
+ for (size_t j = 0; j + 1 < list_size; j++) {
908
+ if (ids[j + 1] < ids[j]) {
909
+ nflip++;
910
+ }
911
+ }
912
+ }
913
+ return nflip == 0;
914
+ }
915
+
828
916
  /* standalone codec interface */
829
917
  size_t IndexIVF::sa_code_size() const {
830
918
  size_t coarse_size = coarse_code_size();
@@ -844,10 +932,15 @@ void IndexIVF::search_and_reconstruct(
844
932
  idx_t k,
845
933
  float* distances,
846
934
  idx_t* labels,
847
- float* recons) const {
848
- FAISS_THROW_IF_NOT(k > 0);
849
-
850
- const size_t nprobe = std::min(nlist, this->nprobe);
935
+ float* recons,
936
+ const SearchParameters* params_in) const {
937
+ const IVFSearchParameters* params = nullptr;
938
+ if (params_in) {
939
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
940
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
941
+ }
942
+ const size_t nprobe =
943
+ std::min(nlist, params ? params->nprobe : this->nprobe);
851
944
  FAISS_THROW_IF_NOT(nprobe > 0);
852
945
 
853
946
  idx_t* idx = new idx_t[n * nprobe];
@@ -869,7 +962,8 @@ void IndexIVF::search_and_reconstruct(
869
962
  coarse_dis,
870
963
  distances,
871
964
  labels,
872
- true /* store_pairs */);
965
+ true /* store_pairs */,
966
+ params);
873
967
  for (idx_t i = 0; i < n; ++i) {
874
968
  for (idx_t j = 0; j < k; ++j) {
875
969
  idx_t ij = i * k + j;
@@ -955,26 +1049,41 @@ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
955
1049
  // does nothing by default
956
1050
  }
957
1051
 
958
- void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
1052
+ bool check_compatible_for_merge_expensive_check = true;
1053
+
1054
+ void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
959
1055
  // minimal sanity checks
960
- FAISS_THROW_IF_NOT(other.d == d);
961
- FAISS_THROW_IF_NOT(other.nlist == nlist);
962
- FAISS_THROW_IF_NOT(other.code_size == code_size);
1056
+ const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
1057
+ FAISS_THROW_IF_NOT(other);
1058
+ FAISS_THROW_IF_NOT(other->d == d);
1059
+ FAISS_THROW_IF_NOT(other->nlist == nlist);
1060
+ FAISS_THROW_IF_NOT(quantizer->ntotal == other->quantizer->ntotal);
1061
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
963
1062
  FAISS_THROW_IF_NOT_MSG(
964
- typeid(*this) == typeid(other),
1063
+ typeid(*this) == typeid(*other),
965
1064
  "can only merge indexes of the same type");
966
1065
  FAISS_THROW_IF_NOT_MSG(
967
- this->direct_map.no() && other.direct_map.no(),
1066
+ this->direct_map.no() && other->direct_map.no(),
968
1067
  "merge direct_map not implemented");
969
- }
970
1068
 
971
- void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
972
- check_compatible_for_merge(other);
1069
+ if (check_compatible_for_merge_expensive_check) {
1070
+ std::vector<float> v(d), v2(d);
1071
+ for (size_t i = 0; i < nlist; i++) {
1072
+ quantizer->reconstruct(i, v.data());
1073
+ other->quantizer->reconstruct(i, v2.data());
1074
+ FAISS_THROW_IF_NOT_MSG(
1075
+ v == v2, "coarse quantizers should be the same");
1076
+ }
1077
+ }
1078
+ }
973
1079
 
974
- invlists->merge_from(other.invlists, add_id);
1080
+ void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
1081
+ check_compatible_for_merge(otherIndex);
1082
+ IndexIVF* other = static_cast<IndexIVF*>(&otherIndex);
1083
+ invlists->merge_from(other->invlists, add_id);
975
1084
 
976
- ntotal += other.ntotal;
977
- other.ntotal = 0;
1085
+ ntotal += other->ntotal;
1086
+ other->ntotal = 0;
978
1087
  }
979
1088
 
980
1089
  void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
@@ -1068,6 +1177,10 @@ IndexIVF::~IndexIVF() {
1068
1177
  }
1069
1178
  }
1070
1179
 
1180
+ /*************************************************************************
1181
+ * IndexIVFStats
1182
+ *************************************************************************/
1183
+
1071
1184
  void IndexIVFStats::reset() {
1072
1185
  memset((void*)this, 0, sizeof(*this));
1073
1186
  }
@@ -1083,13 +1196,60 @@ void IndexIVFStats::add(const IndexIVFStats& other) {
1083
1196
 
1084
1197
  IndexIVFStats indexIVF_stats;
1085
1198
 
1199
+ /*************************************************************************
1200
+ * InvertedListScanner
1201
+ *************************************************************************/
1202
+
1203
+ size_t InvertedListScanner::scan_codes(
1204
+ size_t list_size,
1205
+ const uint8_t* codes,
1206
+ const idx_t* ids,
1207
+ float* simi,
1208
+ idx_t* idxi,
1209
+ size_t k) const {
1210
+ size_t nup = 0;
1211
+
1212
+ if (!keep_max) {
1213
+ for (size_t j = 0; j < list_size; j++) {
1214
+ float dis = distance_to_code(codes);
1215
+ if (dis < simi[0]) {
1216
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1217
+ maxheap_replace_top(k, simi, idxi, dis, id);
1218
+ nup++;
1219
+ }
1220
+ codes += code_size;
1221
+ }
1222
+ } else {
1223
+ for (size_t j = 0; j < list_size; j++) {
1224
+ float dis = distance_to_code(codes);
1225
+ if (dis > simi[0]) {
1226
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1227
+ minheap_replace_top(k, simi, idxi, dis, id);
1228
+ nup++;
1229
+ }
1230
+ codes += code_size;
1231
+ }
1232
+ }
1233
+ return nup;
1234
+ }
1235
+
1086
1236
  void InvertedListScanner::scan_codes_range(
1087
- size_t,
1088
- const uint8_t*,
1089
- const idx_t*,
1090
- float,
1091
- RangeQueryResult&) const {
1092
- FAISS_THROW_MSG("scan_codes_range not implemented");
1237
+ size_t list_size,
1238
+ const uint8_t* codes,
1239
+ const idx_t* ids,
1240
+ float radius,
1241
+ RangeQueryResult& res) const {
1242
+ for (size_t j = 0; j < list_size; j++) {
1243
+ float dis = distance_to_code(codes);
1244
+ bool keep = !keep_max
1245
+ ? dis < radius
1246
+ : dis > radius; // TODO templatize to remove this test
1247
+ if (keep) {
1248
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1249
+ res.add(dis, id);
1250
+ }
1251
+ codes += code_size;
1252
+ }
1093
1253
  }
1094
1254
 
1095
1255
  } // namespace faiss
@@ -11,11 +11,13 @@
11
11
  #define FAISS_INDEX_IVF_H
12
12
 
13
13
  #include <stdint.h>
14
+ #include <memory>
14
15
  #include <unordered_map>
15
16
  #include <vector>
16
17
 
17
18
  #include <faiss/Clustering.h>
18
19
  #include <faiss/Index.h>
20
+ #include <faiss/impl/IDSelector.h>
19
21
  #include <faiss/impl/platform_macros.h>
20
22
  #include <faiss/invlists/DirectMap.h>
21
23
  #include <faiss/invlists/InvertedLists.h>
@@ -38,7 +40,7 @@ struct Level1Quantizer {
38
40
  * = 2: kmeans training on a flat index + add the centroids to the quantizer
39
41
  */
40
42
  char quantizer_trains_alone;
41
- bool own_fields; ///< whether object owns the quantizer
43
+ bool own_fields; ///< whether object owns the quantizer (false by default)
42
44
 
43
45
  ClusteringParameters cp; ///< to override default clustering params
44
46
  Index* clustering_index; ///< to override index used during clustering
@@ -62,13 +64,18 @@ struct Level1Quantizer {
62
64
  ~Level1Quantizer();
63
65
  };
64
66
 
65
- struct IVFSearchParameters {
67
+ struct SearchParametersIVF : SearchParameters {
66
68
  size_t nprobe; ///< number of probes at query time
67
69
  size_t max_codes; ///< max nb of codes to visit to do a query
68
- IVFSearchParameters() : nprobe(1), max_codes(0) {}
69
- virtual ~IVFSearchParameters() {}
70
+ SearchParameters* quantizer_params = nullptr;
71
+
72
+ SearchParametersIVF() : nprobe(1), max_codes(0) {}
73
+ virtual ~SearchParametersIVF() {}
70
74
  };
71
75
 
76
+ // the new convention puts the index type after SearchParameters
77
+ using IVFSearchParameters = SearchParametersIVF;
78
+
72
79
  struct InvertedListScanner;
73
80
  struct IndexIVFStats;
74
81
 
@@ -121,8 +128,7 @@ struct IndexIVF : Index, Level1Quantizer {
121
128
 
122
129
  /** The Inverted file takes a quantizer (an Index) on input,
123
130
  * which implements the function mapping a vector to a list
124
- * identifier. The pointer is borrowed: the quantizer should not
125
- * be deleted while the IndexIVF is in use.
131
+ * identifier.
126
132
  */
127
133
  IndexIVF(
128
134
  Index* quantizer,
@@ -171,6 +177,13 @@ struct IndexIVF : Index, Level1Quantizer {
171
177
  uint8_t* codes,
172
178
  bool include_listno = false) const = 0;
173
179
 
180
+ /** Add vectors that are computed with the standalone codec
181
+ *
182
+ * @param codes codes to add size n * sa_code_size()
183
+ * @param xids corresponding ids, size n
184
+ */
185
+ void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
186
+
174
187
  /// Sub-classes that encode the residuals can train their encoders here
175
188
  /// does nothing by default
176
189
  virtual void train_residual(idx_t n, const float* x);
@@ -212,13 +225,15 @@ struct IndexIVF : Index, Level1Quantizer {
212
225
  const float* x,
213
226
  idx_t k,
214
227
  float* distances,
215
- idx_t* labels) const override;
228
+ idx_t* labels,
229
+ const SearchParameters* params = nullptr) const override;
216
230
 
217
231
  void range_search(
218
232
  idx_t n,
219
233
  const float* x,
220
234
  float radius,
221
- RangeSearchResult* result) const override;
235
+ RangeSearchResult* result,
236
+ const SearchParameters* params = nullptr) const override;
222
237
 
223
238
  void range_search_preassigned(
224
239
  idx_t nx,
@@ -231,9 +246,13 @@ struct IndexIVF : Index, Level1Quantizer {
231
246
  const IVFSearchParameters* params = nullptr,
232
247
  IndexIVFStats* stats = nullptr) const;
233
248
 
234
- /// get a scanner for this index (store_pairs means ignore labels)
249
+ /** Get a scanner for this index (store_pairs means ignore labels)
250
+ *
251
+ * The default search implementation uses this to compute the distances
252
+ */
235
253
  virtual InvertedListScanner* get_InvertedListScanner(
236
- bool store_pairs = false) const;
254
+ bool store_pairs = false,
255
+ const IDSelector* sel = nullptr) const;
237
256
 
238
257
  /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
239
258
  */
@@ -275,7 +294,8 @@ struct IndexIVF : Index, Level1Quantizer {
275
294
  idx_t k,
276
295
  float* distances,
277
296
  idx_t* labels,
278
- float* recons) const override;
297
+ float* recons,
298
+ const SearchParameters* params = nullptr) const override;
279
299
 
280
300
  /** Reconstruct a vector given the location in terms of (inv list index +
281
301
  * inv list offset) instead of the id.
@@ -293,15 +313,9 @@ struct IndexIVF : Index, Level1Quantizer {
293
313
 
294
314
  size_t remove_ids(const IDSelector& sel) override;
295
315
 
296
- /** check that the two indexes are compatible (ie, they are
297
- * trained in the same way and have the same
298
- * parameters). Otherwise throw. */
299
- void check_compatible_for_merge(const IndexIVF& other) const;
316
+ void check_compatible_for_merge(const Index& otherIndex) const override;
300
317
 
301
- /** moves the entries from another dataset to self. On output,
302
- * other is empty. add_id is added to all moved ids (for
303
- * sequential ids, this would be this->ntotal */
304
- virtual void merge_from(IndexIVF& other, idx_t add_id);
318
+ virtual void merge_from(Index& otherIndex, idx_t add_id) override;
305
319
 
306
320
  /** copy a subset of the entries index to the other index
307
321
  *
@@ -322,6 +336,9 @@ struct IndexIVF : Index, Level1Quantizer {
322
336
  return invlists->list_size(list_no);
323
337
  }
324
338
 
339
+ /// are the ids sorted?
340
+ bool check_ids_sorted() const;
341
+
325
342
  /** initialize a direct map
326
343
  *
327
344
  * @param new_maintain_direct_map if true, create a direct map,
@@ -351,6 +368,22 @@ struct RangeQueryResult;
351
368
  struct InvertedListScanner {
352
369
  using idx_t = Index::idx_t;
353
370
 
371
+ idx_t list_no = -1; ///< remember current list
372
+ bool keep_max = false; ///< keep maximum instead of minimum
373
+ /// store positions in invlists rather than labels
374
+ bool store_pairs;
375
+
376
+ /// search in this subset of ids
377
+ const IDSelector* sel;
378
+
379
+ InvertedListScanner(
380
+ bool store_pairs = false,
381
+ const IDSelector* sel = nullptr)
382
+ : store_pairs(store_pairs), sel(sel) {}
383
+
384
+ /// used in default implementation of scan_codes
385
+ size_t code_size = 0;
386
+
354
387
  /// from now on we handle this query.
355
388
  virtual void set_query(const float* query_vector) = 0;
356
389
 
@@ -361,7 +394,8 @@ struct InvertedListScanner {
361
394
  virtual float distance_to_code(const uint8_t* code) const = 0;
362
395
 
363
396
  /** scan a set of codes, compute distances to current query and
364
- * update heap of results if necessary.
397
+ * update heap of results if necessary. Default implementation
398
+ * calls distance_to_code.
365
399
  *
366
400
  * @param n number of codes to scan
367
401
  * @param codes codes to scan (n * code_size)
@@ -377,7 +411,7 @@ struct InvertedListScanner {
377
411
  const idx_t* ids,
378
412
  float* distances,
379
413
  idx_t* labels,
380
- size_t k) const = 0;
414
+ size_t k) const;
381
415
 
382
416
  /** scan a set of codes, compute distances to current query and
383
417
  * update results if distances are below radius
@@ -393,10 +427,13 @@ struct InvertedListScanner {
393
427
  virtual ~InvertedListScanner() {}
394
428
  };
395
429
 
430
+ // whether to check that coarse quantizers are the same
431
+ FAISS_API extern bool check_compatible_for_merge_expensive_check;
432
+
396
433
  struct IndexIVFStats {
397
434
  size_t nq; // nb of queries run
398
435
  size_t nlist; // nb of inverted lists scanned
399
- size_t ndis; // nb of distancs computed
436
+ size_t ndis; // nb of distances computed
400
437
  size_t nheap_updates; // nb of times the heap was updated
401
438
  double quantization_time; // time spent quantizing vectors (in ms)
402
439
  double search_time; // time spent searching lists (in ms)