faiss 0.5.3 → 0.6.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +5 -6
  5. data/ext/faiss/index_binary.cpp +38 -28
  6. data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
  7. data/ext/faiss/kmeans.cpp +10 -9
  8. data/ext/faiss/pca_matrix.cpp +10 -8
  9. data/ext/faiss/product_quantizer.cpp +14 -12
  10. data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
  11. data/ext/faiss/{utils.h → utils_rb.h} +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  17. data/vendor/faiss/faiss/Clustering.h +12 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  19. data/vendor/faiss/faiss/Index.cpp +20 -8
  20. data/vendor/faiss/faiss/Index.h +25 -3
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  22. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  25. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  26. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  27. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  28. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  29. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  30. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  31. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  32. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  33. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  35. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  42. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  43. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  44. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  45. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  46. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  47. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  48. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  49. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  50. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  51. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  52. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  53. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  54. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  56. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  57. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  58. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  59. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  60. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  61. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  62. data/vendor/faiss/faiss/MetricType.h +16 -0
  63. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  64. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  65. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  66. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  67. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  68. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  69. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  70. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  71. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  72. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  73. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  74. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  75. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  76. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  77. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  78. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  79. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  80. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  81. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  82. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  83. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  84. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  85. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  86. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  87. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  88. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  89. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  90. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  91. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  92. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  93. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  94. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  95. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  96. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  97. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  98. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  99. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  100. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  101. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  102. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  103. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  104. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  105. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  106. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  109. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  110. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  111. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  112. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  113. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  114. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  115. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  116. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  126. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  127. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  128. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  129. data/vendor/faiss/faiss/index_io.h +29 -3
  130. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  131. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  132. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  133. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  134. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  135. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  136. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  137. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  138. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  139. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  140. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  141. data/vendor/faiss/faiss/utils/distances.h +98 -0
  142. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  143. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  144. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  145. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  146. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  147. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  148. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  149. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  150. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  151. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  152. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  157. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  158. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  159. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  160. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  161. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  162. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  163. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  164. metadata +47 -18
  165. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  166. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  167. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
data/lib/faiss/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Faiss
2
- VERSION = "0.5.3"
2
+ VERSION = "0.6.0"
3
3
  end
data/lib/faiss.rb CHANGED
@@ -1,5 +1,5 @@
1
1
  # dependencies
2
- require "numo/narray"
2
+ require "numo/narray/alt"
3
3
 
4
4
  # ext
5
5
  require "faiss/ext"
@@ -19,6 +19,8 @@
19
19
  #include <faiss/utils/random.h>
20
20
  #include <faiss/utils/utils.h>
21
21
 
22
+ #include <faiss/IndexBinaryHNSW.h>
23
+ #include <faiss/IndexBinaryIVF.h>
22
24
  #include <faiss/IndexHNSW.h>
23
25
  #include <faiss/IndexIDMap.h>
24
26
  #include <faiss/IndexIVF.h>
@@ -408,21 +410,24 @@ void ParameterSpace::initialize(const Index* index) {
408
410
 
409
411
  #undef DC
410
412
 
411
- /// set a combination of parameters on an index
412
- void ParameterSpace::set_index_parameters(Index* index, size_t cno) const {
413
- for (int i = 0; i < parameter_ranges.size(); i++) {
414
- const ParameterRange& pr = parameter_ranges[i];
413
+ template <typename SetParamFunc>
414
+ static void set_index_parameters_common(
415
+ const ParameterSpace* ps,
416
+ size_t cno,
417
+ SetParamFunc set_param) {
418
+ for (int i = 0; i < ps->parameter_ranges.size(); i++) {
419
+ const ParameterRange& pr = ps->parameter_ranges[i];
415
420
  size_t j = cno % pr.values.size();
416
421
  cno /= pr.values.size();
417
422
  double val = pr.values[j];
418
- set_index_parameter(index, pr.name, val);
423
+ set_param(pr.name, val);
419
424
  }
420
425
  }
421
426
 
422
- /// set a combination of parameters on an index
423
- void ParameterSpace::set_index_parameters(
424
- Index* index,
425
- const char* description_in) const {
427
+ template <typename SetParamFunc>
428
+ static void set_index_parameters_string_common(
429
+ const char* description_in,
430
+ SetParamFunc set_param) {
426
431
  std::string description(description_in);
427
432
  char* ptr;
428
433
 
@@ -433,10 +438,47 @@ void ParameterSpace::set_index_parameters(
433
438
  int ret = sscanf(tok, "%99[^=]=%lf", name, &val);
434
439
  FAISS_THROW_IF_NOT_FMT(
435
440
  ret == 2, "could not interpret parameters %s", tok);
436
- set_index_parameter(index, name, val);
441
+ set_param(name, val);
437
442
  }
438
443
  }
439
444
 
445
+ /// set a combination of parameters on an index
446
+ void ParameterSpace::set_index_parameters(Index* index, size_t cno) const {
447
+ set_index_parameters_common(
448
+ this, cno, [this, index](const std::string& name, double val) {
449
+ this->set_index_parameter(index, name, val);
450
+ });
451
+ }
452
+
453
+ /// set a combination of parameters on an index
454
+ void ParameterSpace::set_index_parameters(
455
+ Index* index,
456
+ const char* description_in) const {
457
+ set_index_parameters_string_common(
458
+ description_in, [this, index](const std::string& name, double val) {
459
+ this->set_index_parameter(index, name, val);
460
+ });
461
+ }
462
+
463
+ /// set a combination of parameters on a binary index
464
+ void ParameterSpace::set_index_parameters(IndexBinary* index, size_t cno)
465
+ const {
466
+ set_index_parameters_common(
467
+ this, cno, [this, index](const std::string& name, double val) {
468
+ this->set_index_parameter(index, name, val);
469
+ });
470
+ }
471
+
472
+ /// set a combination of parameters on a binary index
473
+ void ParameterSpace::set_index_parameters(
474
+ IndexBinary* index,
475
+ const char* description_in) const {
476
+ set_index_parameters_string_common(
477
+ description_in, [this, index](const std::string& name, double val) {
478
+ this->set_index_parameter(index, name, val);
479
+ });
480
+ }
481
+
440
482
  // non-const version
441
483
  // Do not use this macro if ix will be unused
442
484
  #define DC(classname) classname* ix = dynamic_cast<classname*>(index)
@@ -490,7 +532,6 @@ void ParameterSpace::set_index_parameter(
490
532
  }
491
533
 
492
534
  if (name == "verbose") {
493
- index->verbose = int(val);
494
535
  return; // last verbose that we could find
495
536
  }
496
537
 
@@ -573,6 +614,84 @@ void ParameterSpace::set_index_parameter(
573
614
  name.c_str());
574
615
  }
575
616
 
617
+ void ParameterSpace::set_index_parameter(
618
+ IndexBinary* index,
619
+ const std::string& name,
620
+ double val) const {
621
+ if (verbose > 1) {
622
+ printf(" set_index_parameter (binary) %s=%g\n", name.c_str(), val);
623
+ }
624
+
625
+ if (name == "verbose") {
626
+ index->verbose = int(val);
627
+ // and fall through to also enable it on sub-indexes
628
+ }
629
+
630
+ if (DC(IndexBinaryIDMap)) {
631
+ set_index_parameter(ix->index, name, val);
632
+ return;
633
+ }
634
+
635
+ if (name == "verbose") {
636
+ return; // last verbose that we could find
637
+ }
638
+
639
+ if (name == "nprobe") {
640
+ if (DC(IndexBinaryIVF)) {
641
+ ix->nprobe = int(val);
642
+ return;
643
+ }
644
+ }
645
+
646
+ if (name == "max_codes") {
647
+ if (DC(IndexBinaryIVF)) {
648
+ ix->max_codes = std::isfinite(val) ? size_t(val) : 0;
649
+ return;
650
+ }
651
+ }
652
+
653
+ if (name == "efConstruction") {
654
+ if (DC(IndexBinaryHNSW)) {
655
+ ix->hnsw.efConstruction = int(val);
656
+ return;
657
+ }
658
+ if (DC(IndexBinaryIVF)) {
659
+ if (IndexBinaryHNSW* cq =
660
+ dynamic_cast<IndexBinaryHNSW*>(ix->quantizer)) {
661
+ cq->hnsw.efConstruction = int(val);
662
+ return;
663
+ }
664
+ }
665
+ }
666
+
667
+ if (name == "efSearch") {
668
+ if (DC(IndexBinaryHNSW)) {
669
+ ix->hnsw.efSearch = int(val);
670
+ return;
671
+ }
672
+ if (DC(IndexBinaryIVF)) {
673
+ if (IndexBinaryHNSW* cq =
674
+ dynamic_cast<IndexBinaryHNSW*>(ix->quantizer)) {
675
+ cq->hnsw.efSearch = int(val);
676
+ return;
677
+ }
678
+ }
679
+ }
680
+
681
+ if (name.find("quantizer_") == 0) {
682
+ if (DC(IndexBinaryIVF)) {
683
+ std::string sub_name = name.substr(strlen("quantizer_"));
684
+ set_index_parameter(ix->quantizer, sub_name, val);
685
+ return;
686
+ }
687
+ }
688
+
689
+ FAISS_THROW_FMT(
690
+ "ParameterSpace::set_index_parameter:"
691
+ "could not set parameter %s on binary index",
692
+ name.c_str());
693
+ }
694
+
576
695
  #undef DC
577
696
 
578
697
  void ParameterSpace::display() const {
@@ -177,12 +177,25 @@ struct ParameterSpace {
177
177
  /// set a combination of parameters described by a string
178
178
  void set_index_parameters(Index* index, const char* param_string) const;
179
179
 
180
- /// set one of the parameters, returns whether setting was successful
180
+ /// set one of the parameters
181
181
  virtual void set_index_parameter(
182
182
  Index* index,
183
183
  const std::string& name,
184
184
  double val) const;
185
185
 
186
+ /// set a combination of parameters on a binary index
187
+ void set_index_parameters(IndexBinary* index, size_t cno) const;
188
+
189
+ /// set a combination of parameters described by a string on a binary index
190
+ void set_index_parameters(IndexBinary* index, const char* param_string)
191
+ const;
192
+
193
+ /// set one of the parameters on a binary index
194
+ virtual void set_index_parameter(
195
+ IndexBinary* index,
196
+ const std::string& name,
197
+ double val) const;
198
+
186
199
  /** find an upper bound on the performance and a lower bound on t
187
200
  * for configuration cno given another operating point op */
188
201
  void update_bounds(
@@ -407,19 +407,52 @@ void Clustering::train_encoded(
407
407
  printf("Outer iteration %d / %d\n", redo, nredo);
408
408
  }
409
409
 
410
- // initialize (remaining) centroids with random points from the dataset
410
+ // initialize centroids using the selected method
411
411
  centroids.resize(d * k);
412
- std::vector<int> perm(nx);
413
412
 
414
- rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
413
+ size_t k_to_init = k - n_input_centroids;
414
+ if (k_to_init > 0) {
415
+ // Fast path for RANDOM initialization - preserves exact original
416
+ // behavior
417
+ if (init_method == ClusteringInitMethod::RANDOM) {
418
+ std::vector<int> perm(nx);
419
+ rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
420
+ for (size_t i = 0; i < k_to_init; i++) {
421
+ if (!codec) {
422
+ memcpy(centroids.data() + (n_input_centroids + i) * d,
423
+ x + perm[n_input_centroids + i] * line_size,
424
+ line_size);
425
+ } else {
426
+ codec->sa_decode(
427
+ 1,
428
+ x + perm[n_input_centroids + i] * line_size,
429
+ centroids.data() + (n_input_centroids + i) * d);
430
+ }
431
+ }
432
+ } else {
433
+ // For k-means++ and AFK-MC², we need all vectors decoded
434
+ const float* x_float = nullptr;
435
+ std::vector<float> x_decoded;
415
436
 
416
- if (!codec) {
417
- for (int i = n_input_centroids; i < k; i++) {
418
- memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
419
- }
420
- } else {
421
- for (int i = n_input_centroids; i < k; i++) {
422
- codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
437
+ if (!codec) {
438
+ x_float = reinterpret_cast<const float*>(x);
439
+ } else {
440
+ // Decode all vectors for initialization
441
+ x_decoded.resize(nx * d);
442
+ codec->sa_decode(nx, x, x_decoded.data());
443
+ x_float = x_decoded.data();
444
+ }
445
+
446
+ ClusteringInitialization initializer(d, k_to_init);
447
+ initializer.method = init_method;
448
+ initializer.seed = actual_seed + 1 + redo * 15486557L;
449
+ initializer.afkmc2_chain_length = afkmc2_chain_length;
450
+ initializer.init_centroids(
451
+ nx,
452
+ x_float,
453
+ centroids.data() + n_input_centroids * d,
454
+ n_input_centroids,
455
+ n_input_centroids > 0 ? centroids.data() : nullptr);
423
456
  }
424
457
  }
425
458
 
@@ -529,6 +562,22 @@ void Clustering::train_encoded(
529
562
 
530
563
  index.add(k, centroids.data());
531
564
  InterruptCallback::check();
565
+
566
+ // Early stopping: if objective didn't change, we've converged.
567
+ // Safe to access iteration_stats[size - 2] because we push_back
568
+ // above, so size >= i + 1, and when i > 0 we have size >= 2.
569
+ if (i > 0) {
570
+ float prev_obj =
571
+ iteration_stats[iteration_stats.size() - 2].obj;
572
+ if (obj == prev_obj) {
573
+ if (verbose) {
574
+ printf("\n Converged at iteration %d: "
575
+ "objective did not change\n",
576
+ i);
577
+ }
578
+ break;
579
+ }
580
+ }
532
581
  }
533
582
 
534
583
  if (verbose) {
@@ -10,6 +10,7 @@
10
10
  #ifndef FAISS_CLUSTERING_H
11
11
  #define FAISS_CLUSTERING_H
12
12
  #include <faiss/Index.h>
13
+ #include <faiss/impl/ClusteringInitialization.h>
13
14
 
14
15
  #include <vector>
15
16
 
@@ -57,6 +58,17 @@ struct ClusteringParameters {
57
58
  /// Whether to use splitmix64-based random number generator for subsampling,
58
59
  /// which is faster, but may pick duplicate points.
59
60
  bool use_faster_subsampling = false;
61
+
62
+ /// Initialization method for centroids.
63
+ /// RANDOM: uniform random sampling (default, current behavior)
64
+ /// KMEANS_PLUS_PLUS: k-means++ (O(nkd), better quality)
65
+ /// AFK_MC2: Assumption-Free K-MC² (O(nd) + O(mk²d), fast approximation)
66
+ ClusteringInitMethod init_method = ClusteringInitMethod::RANDOM;
67
+
68
+ /// Chain length for AFK-MC² initialization.
69
+ /// Only used when init_method = AFK_MC2.
70
+ /// Longer chains give better approximation but are slower.
71
+ uint16_t afkmc2_chain_length = 50;
60
72
  };
61
73
 
62
74
  struct ClusteringIterationStats {
@@ -18,6 +18,7 @@
18
18
  #include <faiss/MetaIndexes.h>
19
19
  #include <faiss/clone_index.h>
20
20
  #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/impl/simd_dispatch.h>
21
22
  #include <faiss/index_io.h>
22
23
  #include <faiss/utils/distances.h>
23
24
  #include <faiss/utils/hamming.h>
@@ -512,39 +513,41 @@ void ivf_residual_add_from_flat_codes(
512
513
  const ResidualQuantizer& rq = index->rq;
513
514
 
514
515
  // populate inverted lists
515
- #pragma omp parallel if (nb > 10000)
516
- {
517
- std::vector<uint8_t> tmp_code(index->code_size);
518
- std::vector<float> tmp(rq.d);
519
- int nt = omp_get_num_threads();
520
- int rank = omp_get_thread_num();
516
+ with_simd_level([&]<SIMDLevel SL>() {
517
+ #pragma omp parallel
518
+ {
519
+ std::vector<uint8_t> tmp_code(index->rq.code_size);
520
+ std::vector<float> tmp(rq.d);
521
+ int nt = omp_get_num_threads();
522
+ int rank = omp_get_thread_num();
521
523
 
522
524
  #pragma omp for
523
- for (idx_t i = 0; i < nb; i++) {
524
- const uint8_t* code = &raw_codes[i * code_size];
525
- BitstringReader rd(code, code_size);
526
- idx_t list_no = rd.read(rcq->rq.tot_bits);
527
-
528
- if (list_no % nt ==
529
- rank) { // each thread takes care of 1/nt of the invlists
530
- // copy AQ indexes one by one
531
- BitstringWriter wr(tmp_code.data(), tmp_code.size());
532
- for (int j = 0; j < rq.M; j++) {
533
- int nbit = rq.nbits[j];
534
- wr.write(rd.read(nbit), nbit);
525
+ for (idx_t i = 0; i < nb; i++) {
526
+ const uint8_t* code = &raw_codes[i * code_size];
527
+ BitstringReader rd(code, code_size);
528
+ idx_t list_no = rd.read(rcq->rq.tot_bits);
529
+
530
+ if (list_no % nt ==
531
+ rank) { // each thread takes care of 1/nt of the invlists
532
+ // copy AQ indexes one by one
533
+ BitstringWriter wr(tmp_code.data(), tmp_code.size());
534
+ for (int j = 0; j < rq.M; j++) {
535
+ int nbit = rq.nbits[j];
536
+ wr.write(rd.read(nbit), nbit);
537
+ }
538
+ // we need to recompute the norm
539
+ // decode first, does not use the norm component, so that's
540
+ // ok
541
+ index->rq.decode(tmp_code.data(), tmp.data(), 1);
542
+ float norm = fvec_norm_L2sqr<SL>(tmp.data(), rq.d);
543
+ wr.write(rq.encode_norm(norm), rq.norm_bits);
544
+
545
+ // add code to the inverted list
546
+ invlists.add_entry(list_no, i, tmp_code.data());
535
547
  }
536
- // we need to recompute the norm
537
- // decode first, does not use the norm component, so that's
538
- // ok
539
- index->rq.decode(tmp_code.data(), tmp.data(), 1);
540
- float norm = fvec_norm_L2sqr(tmp.data(), rq.d);
541
- wr.write(rq.encode_norm(norm), rq.norm_bits);
542
-
543
- // add code to the inverted list
544
- invlists.add_entry(list_no, i, tmp_code.data());
545
548
  }
546
549
  }
547
- }
550
+ });
548
551
  index->ntotal += nb;
549
552
  }
550
553
 
@@ -24,12 +24,20 @@ void Index::train(idx_t /*n*/, const float* /*x*/) {
24
24
  // does nothing by default
25
25
  }
26
26
 
27
+ void Index::train(
28
+ idx_t /*n*/,
29
+ const float* /*x*/,
30
+ idx_t /*n_train_q*/,
31
+ const float* /*xq_train*/) {
32
+ // does nothing by default
33
+ }
34
+
27
35
  void Index::range_search(
28
36
  idx_t,
29
37
  const float*,
30
38
  float,
31
39
  RangeSearchResult*,
32
- const SearchParameters* params) const {
40
+ const SearchParameters* /*params*/) const {
33
41
  FAISS_THROW_MSG("range search not implemented");
34
42
  }
35
43
 
@@ -105,16 +113,20 @@ void Index::search_and_reconstruct(
105
113
  }
106
114
 
107
115
  void Index::search_subset(
108
- idx_t n,
109
- const float* x,
110
- idx_t k_base,
111
- const idx_t* base_labels,
112
- idx_t k,
113
- float* distances,
114
- idx_t* labels) const {
116
+ idx_t /*n*/,
117
+ const float* /*x*/,
118
+ idx_t /*k_base*/,
119
+ const idx_t* /*base_labels*/,
120
+ idx_t /*k*/,
121
+ float* /*distances*/,
122
+ idx_t* /*labels*/) const {
115
123
  FAISS_THROW_MSG("search_subset not implemented for this type of index");
116
124
  }
117
125
 
126
+ void Index::search1(const float*, ResultHandler&, SearchParameters*) const {
127
+ FAISS_THROW_MSG("search1 not implemented for this type of index");
128
+ }
129
+
118
130
  void Index::compute_residual(const float* x, float* residual, idx_t key) const {
119
131
  reconstruct(key, residual);
120
132
  for (size_t i = 0; i < d; i++) {
@@ -14,11 +14,10 @@
14
14
  #include <faiss/impl/FaissAssert.h>
15
15
 
16
16
  #include <cstdio>
17
- #include <sstream>
18
17
 
19
18
  #define FAISS_VERSION_MAJOR 1
20
- #define FAISS_VERSION_MINOR 13
21
- #define FAISS_VERSION_PATCH 2
19
+ #define FAISS_VERSION_MINOR 14
20
+ #define FAISS_VERSION_PATCH 1
22
21
 
23
22
  // Macro to combine the version components into a single string
24
23
  #ifndef FAISS_STRINGIFY
@@ -55,6 +54,9 @@ namespace faiss {
55
54
  struct IDSelector;
56
55
  struct RangeSearchResult;
57
56
  struct DistanceComputer;
57
+ template <typename T, typename TI>
58
+ struct ResultHandlerUnordered;
59
+ using ResultHandler = ResultHandlerUnordered<float, idx_t>;
58
60
 
59
61
  enum NumericType {
60
62
  Float32,
@@ -129,6 +131,20 @@ struct Index {
129
131
  */
130
132
  virtual void train(idx_t n, const float* x);
131
133
 
134
+ /** Perfrom training on a representative set of vectors and a representative
135
+ * set of queries
136
+ *
137
+ * @param n nb of training vectors
138
+ * @param x training vectors, size n * d
139
+ * @param n_train_q nb of training queries
140
+ * @param xq_train training queries, size n_train_q * d
141
+ */
142
+ virtual void train(
143
+ idx_t n,
144
+ const float* x,
145
+ idx_t n_train_q,
146
+ const float* xq_train);
147
+
132
148
  virtual void train_ex(idx_t n, const void* x, NumericType numeric_type) {
133
149
  if (numeric_type == NumericType::Float32) {
134
150
  train(n, static_cast<const float*>(x));
@@ -216,6 +232,12 @@ struct Index {
216
232
  }
217
233
  }
218
234
 
235
+ /** search one vector with a custom result handler */
236
+ virtual void search1(
237
+ const float* x,
238
+ ResultHandler& handler,
239
+ SearchParameters* params = nullptr) const;
240
+
219
241
  /** query n vectors of dimension d to the index.
220
242
  *
221
243
  * return all vectors with distance < radius. Note that many
@@ -8,13 +8,11 @@
8
8
  #include <faiss/IndexAdditiveQuantizer.h>
9
9
 
10
10
  #include <algorithm>
11
- #include <cmath>
12
11
  #include <cstring>
13
12
 
14
13
  #include <faiss/impl/FaissAssert.h>
15
14
  #include <faiss/impl/ResidualQuantizer.h>
16
15
  #include <faiss/impl/ResultHandler.h>
17
- #include <faiss/utils/distances.h>
18
16
  #include <faiss/utils/extra_distances.h>
19
17
 
20
18
  namespace faiss {
@@ -189,17 +187,14 @@ void search_with_LUT(
189
187
  FlatCodesDistanceComputer* IndexAdditiveQuantizer::
190
188
  get_FlatCodesDistanceComputer() const {
191
189
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
192
- if (metric_type == METRIC_L2) {
193
- using VD = VectorDistance<METRIC_L2>;
194
- VD vd = {size_t(d), metric_arg};
195
- return new AQDistanceComputerDecompress<VD>(*this, vd);
196
- } else if (metric_type == METRIC_INNER_PRODUCT) {
197
- using VD = VectorDistance<METRIC_INNER_PRODUCT>;
198
- VD vd = {size_t(d), metric_arg};
199
- return new AQDistanceComputerDecompress<VD>(*this, vd);
200
- } else {
201
- FAISS_THROW_MSG("unsupported metric");
202
- }
190
+ return with_VectorDistance(
191
+ d,
192
+ metric_type,
193
+ metric_arg,
194
+ [&](auto vd) -> FlatCodesDistanceComputer* {
195
+ return new AQDistanceComputerDecompress<decltype(vd)>(
196
+ *this, vd);
197
+ });
203
198
  } else {
204
199
  if (metric_type == METRIC_INNER_PRODUCT) {
205
200
  return new AQDistanceComputerLUT<
@@ -242,17 +237,17 @@ void IndexAdditiveQuantizer::search(
242
237
  !params, "search params not supported for this index");
243
238
 
244
239
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
245
- if (metric_type == METRIC_L2) {
246
- using VD = VectorDistance<METRIC_L2>;
247
- VD vd = {size_t(d), metric_arg};
248
- HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
249
- search_with_decompress(*this, x, vd, rh);
250
- } else if (metric_type == METRIC_INNER_PRODUCT) {
251
- using VD = VectorDistance<METRIC_INNER_PRODUCT>;
252
- VD vd = {size_t(d), metric_arg};
253
- HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
254
- search_with_decompress(*this, x, vd, rh);
255
- }
240
+ with_VectorDistance(d, metric_type, metric_arg, [&](auto vd) {
241
+ if constexpr (decltype(vd)::is_similarity) {
242
+ HeapBlockResultHandler<CMin<float, idx_t>> rh(
243
+ n, distances, labels, k);
244
+ search_with_decompress(*this, x, vd, rh);
245
+ } else {
246
+ HeapBlockResultHandler<CMax<float, idx_t>> rh(
247
+ n, distances, labels, k);
248
+ search_with_decompress(*this, x, vd, rh);
249
+ }
250
+ });
256
251
  } else {
257
252
  if (metric_type == METRIC_INNER_PRODUCT) {
258
253
  HeapBlockResultHandler<CMin<float, idx_t>> rh(
@@ -12,6 +12,7 @@
12
12
 
13
13
  #include <cinttypes>
14
14
  #include <cstring>
15
+ #include <typeinfo>
15
16
 
16
17
  namespace faiss {
17
18
 
@@ -22,6 +22,7 @@
22
22
  #include <faiss/impl/DistanceComputer.h>
23
23
  #include <faiss/impl/FaissAssert.h>
24
24
  #include <faiss/impl/ResultHandler.h>
25
+ #include <faiss/impl/VisitedTable.h>
25
26
  #include <faiss/utils/Heap.h>
26
27
  #include <faiss/utils/hamming.h>
27
28
  #include <faiss/utils/random.h>
@@ -205,10 +206,14 @@ void IndexBinaryHNSW::search(
205
206
  idx_t k,
206
207
  int32_t* distances,
207
208
  idx_t* labels,
208
- const SearchParameters* params) const {
209
- FAISS_THROW_IF_NOT_MSG(
210
- !params, "search params not supported for this index");
209
+ const SearchParameters* params_in) const {
211
210
  FAISS_THROW_IF_NOT(k > 0);
211
+ const SearchParametersHNSW* params = nullptr;
212
+ if (params_in) {
213
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
214
+ FAISS_THROW_IF_NOT_MSG(
215
+ params, "IndexBinaryHNSW params have incorrect type");
216
+ }
212
217
 
213
218
  // we use the buffer for distances as float but convert them back
214
219
  // to int in the end
@@ -231,7 +236,7 @@ void IndexBinaryHNSW::search(
231
236
  // as the index parameter. This state does not get used in the
232
237
  // search function, as it is merely there to to enable Panorama
233
238
  // execution for IndexHNSWFlatPanorama.
234
- hnsw.search(*dis, nullptr, res, vt);
239
+ hnsw.search(*dis, nullptr, res, vt, params_in);
235
240
  res.end();
236
241
  }
237
242
  }