faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -18,21 +18,20 @@
18
18
  #include <algorithm>
19
19
 
20
20
  #include <faiss/utils/Heap.h>
21
- #include <faiss/utils/distances.h>
21
+ #include <faiss/utils/distances_dispatch.h>
22
22
  #include <faiss/utils/utils.h>
23
23
 
24
24
  #include <faiss/Clustering.h>
25
25
 
26
26
  #include <faiss/utils/hamming.h>
27
27
 
28
- #include <faiss/impl/FaissAssert.h>
29
-
30
28
  #include <faiss/impl/AuxIndexStructures.h>
29
+ #include <faiss/impl/FaissAssert.h>
31
30
  #include <faiss/impl/IDSelector.h>
32
-
33
31
  #include <faiss/impl/ProductQuantizer.h>
34
-
35
- #include <faiss/impl/code_distance/code_distance.h>
32
+ #include <faiss/impl/ResultHandler.h>
33
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
34
+ #include <faiss/impl/simd_dispatch.h>
36
35
 
37
36
  namespace faiss {
38
37
 
@@ -427,7 +426,7 @@ void initialize_IVFPQ_precomputed_table(
427
426
  for (int m = 0; m < pq.M; m++)
428
427
  for (int j = 0; j < pq.ksub; j++)
429
428
  r_norms[m * pq.ksub + j] =
430
- fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
429
+ fvec_norm_L2sqr_dispatch(pq.get_centroids(m, j), pq.dsub);
431
430
 
432
431
  if (use_precomputed_table == 1) {
433
432
  precomputed_table.resize(nlist * pq.M * pq.ksub);
@@ -438,7 +437,7 @@ void initialize_IVFPQ_precomputed_table(
438
437
 
439
438
  float* tab = &precomputed_table[i * pq.M * pq.ksub];
440
439
  pq.compute_inner_prod_table(centroid.data(), tab);
441
- fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
440
+ fvec_madd_dispatch(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
442
441
  }
443
442
  } else if (use_precomputed_table == 2) {
444
443
  const MultiIndexQuantizer* miq =
@@ -465,7 +464,7 @@ void initialize_IVFPQ_precomputed_table(
465
464
 
466
465
  for (size_t i = 0; i < cpq.ksub; i++) {
467
466
  float* tab = &precomputed_table[i * pq.M * pq.ksub];
468
- fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
467
+ fvec_madd_dispatch(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
469
468
  }
470
469
  }
471
470
  }
@@ -628,7 +627,7 @@ struct QueryTables {
628
627
  // and dis0, the initial value
629
628
  ivfpq.quantizer->reconstruct(key, decoded_vec);
630
629
  // decoded_vec = centroid
631
- float dis0 = fvec_inner_product(qi, decoded_vec, d);
630
+ float dis0 = fvec_inner_product_dispatch(qi, decoded_vec, d);
632
631
 
633
632
  if (polysemous_ht) {
634
633
  for (int i = 0; i < d; i++) {
@@ -657,7 +656,7 @@ struct QueryTables {
657
656
  } else if (use_precomputed_table == 1) {
658
657
  dis0 = coarse_dis;
659
658
 
660
- fvec_madd(
659
+ fvec_madd_dispatch(
661
660
  pq.M * pq.ksub,
662
661
  ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
663
662
  -2.0,
@@ -693,12 +692,12 @@ struct QueryTables {
693
692
 
694
693
  if (polysemous_ht == 0) {
695
694
  // sum up with query-specific table
696
- fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
695
+ fvec_madd_dispatch(Mf * pq.ksub, pc, -2.0, qtab, ltab);
697
696
  ltab += Mf * pq.ksub;
698
697
  qtab += Mf * pq.ksub;
699
698
  } else {
700
699
  for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
701
- q_code[m] = fvec_madd_and_argmin(
700
+ q_code[m] = fvec_madd_and_argmin_dispatch(
702
701
  pq.ksub, pc, -2, qtab, ltab);
703
702
  pc += pq.ksub;
704
703
  ltab += pq.ksub;
@@ -762,52 +761,31 @@ struct QueryTables {
762
761
  }
763
762
  };
764
763
 
765
- // This way of handling the selector is not optimal since all distances
766
- // are computed even if the id would filter it out.
767
764
  template <class C, bool use_sel>
768
- struct KnnSearchResults {
769
- idx_t key;
770
- const idx_t* ids;
771
- const IDSelector* sel;
772
-
773
- // heap params
774
- size_t k;
775
- float* heap_sim;
776
- idx_t* heap_ids;
777
-
778
- size_t nup;
779
-
780
- inline bool skip_entry(idx_t j) {
781
- return use_sel && !sel->is_member(ids[j]);
782
- }
783
-
784
- inline void add(idx_t j, float dis) {
785
- if (C::cmp(heap_sim[0], dis)) {
786
- idx_t id = ids ? ids[j] : lo_build(key, j);
787
- heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
788
- nup++;
789
- }
790
- }
791
- };
765
+ struct WrappedSearchResult {
766
+ ResultHandler& res;
767
+ size_t nup = 0;
768
+ idx_t list_no;
792
769
 
793
- template <class C, bool use_sel>
794
- struct RangeSearchResults {
795
- idx_t key;
796
770
  const idx_t* ids;
797
771
  const IDSelector* sel;
798
772
 
799
- // wrapped result structure
800
- float radius;
801
- RangeQueryResult& rres;
773
+ WrappedSearchResult(
774
+ idx_t list_no,
775
+ const idx_t* ids,
776
+ const IDSelector* sel,
777
+ ResultHandler& res)
778
+ : res(res), list_no(list_no), ids(ids), sel(sel) {}
802
779
 
803
780
  inline bool skip_entry(idx_t j) {
804
781
  return use_sel && !sel->is_member(ids[j]);
805
782
  }
806
783
 
807
784
  inline void add(idx_t j, float dis) {
808
- if (C::cmp(radius, dis)) {
809
- idx_t id = ids ? ids[j] : lo_build(key, j);
810
- rres.add(dis, id);
785
+ if (C::cmp(res.threshold, dis)) {
786
+ idx_t id = ids ? ids[j] : lo_build(this->list_no, j);
787
+ res.add_result(dis, id);
788
+ nup++;
811
789
  }
812
790
  }
813
791
  };
@@ -817,8 +795,9 @@ struct RangeSearchResults {
817
795
  * The scanning functions call their favorite precompute_*
818
796
  * function to precompute the tables they need.
819
797
  *****************************************************/
820
- template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
798
+ template <typename IDType, MetricType METRIC_TYPE, class PQCodeDist>
821
799
  struct IVFPQScannerT : QueryTables {
800
+ using PQDecoder = typename PQCodeDist::PQDecoder;
822
801
  const uint8_t* list_codes;
823
802
  const IDType* list_ids;
824
803
  size_t list_size;
@@ -859,7 +838,7 @@ struct IVFPQScannerT : QueryTables {
859
838
  // if (res.skip_entry(j)) {
860
839
  // continue;
861
840
  // }
862
- // float dis = dis0 + distance_single_code<PQDecoder>(
841
+ // float dis = dis0 + PQCodeDist::distance_single_code(
863
842
  // pq, sim_table, codes);
864
843
  // res.add(j, dis);
865
844
  // }
@@ -894,7 +873,7 @@ struct IVFPQScannerT : QueryTables {
894
873
  float distance_1 = 0;
895
874
  float distance_2 = 0;
896
875
  float distance_3 = 0;
897
- distance_four_codes<PQDecoder>(
876
+ PQCodeDist::distance_four_codes(
898
877
  pq.M,
899
878
  pq.nbits,
900
879
  sim_table,
@@ -917,7 +896,7 @@ struct IVFPQScannerT : QueryTables {
917
896
 
918
897
  if (counter >= 1) {
919
898
  float dis = dis0 +
920
- distance_single_code<PQDecoder>(
899
+ PQCodeDist::distance_single_code(
921
900
  pq.M,
922
901
  pq.nbits,
923
902
  sim_table,
@@ -926,7 +905,7 @@ struct IVFPQScannerT : QueryTables {
926
905
  }
927
906
  if (counter >= 2) {
928
907
  float dis = dis0 +
929
- distance_single_code<PQDecoder>(
908
+ PQCodeDist::distance_single_code(
930
909
  pq.M,
931
910
  pq.nbits,
932
911
  sim_table,
@@ -935,7 +914,7 @@ struct IVFPQScannerT : QueryTables {
935
914
  }
936
915
  if (counter >= 3) {
937
916
  float dis = dis0 +
938
- distance_single_code<PQDecoder>(
917
+ PQCodeDist::distance_single_code(
939
918
  pq.M,
940
919
  pq.nbits,
941
920
  sim_table,
@@ -979,7 +958,7 @@ struct IVFPQScannerT : QueryTables {
979
958
  if (by_residual) {
980
959
  if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
981
960
  ivfpq.quantizer->reconstruct(key, residual_vec);
982
- dis0 = fvec_inner_product(residual_vec, qi, d);
961
+ dis0 = fvec_inner_product_dispatch(residual_vec, qi, d);
983
962
  } else {
984
963
  ivfpq.quantizer->compute_residual(qi, residual_vec, key);
985
964
  }
@@ -997,9 +976,9 @@ struct IVFPQScannerT : QueryTables {
997
976
 
998
977
  float dis;
999
978
  if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
1000
- dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
979
+ dis = dis0 + fvec_inner_product_dispatch(decoded_vec, qi, d);
1001
980
  } else {
1002
- dis = fvec_L2sqr(decoded_vec, dvec, d);
981
+ dis = fvec_L2sqr_dispatch(decoded_vec, dvec, d);
1003
982
  }
1004
983
  res.add(j, dis);
1005
984
  }
@@ -1035,7 +1014,7 @@ struct IVFPQScannerT : QueryTables {
1035
1014
  //
1036
1015
  // float dis =
1037
1016
  // dis0 +
1038
- // distance_single_code<PQDecoder>(
1017
+ // PQCodeDist::distance_single_code(
1039
1018
  // pq, sim_table, codes);
1040
1019
  //
1041
1020
  // res.add(j, dis);
@@ -1101,7 +1080,7 @@ struct IVFPQScannerT : QueryTables {
1101
1080
  float distance_1 = dis0;
1102
1081
  float distance_2 = dis0;
1103
1082
  float distance_3 = dis0;
1104
- distance_four_codes<PQDecoder>(
1083
+ PQCodeDist::distance_four_codes(
1105
1084
  pq.M,
1106
1085
  pq.nbits,
1107
1086
  sim_table,
@@ -1132,7 +1111,7 @@ struct IVFPQScannerT : QueryTables {
1132
1111
  n_hamming_pass++;
1133
1112
 
1134
1113
  float dis = dis0 +
1135
- distance_single_code<PQDecoder>(
1114
+ PQCodeDist::distance_single_code(
1136
1115
  pq.M,
1137
1116
  pq.nbits,
1138
1117
  sim_table,
@@ -1152,7 +1131,7 @@ struct IVFPQScannerT : QueryTables {
1152
1131
  n_hamming_pass++;
1153
1132
 
1154
1133
  float dis = dis0 +
1155
- distance_single_code<PQDecoder>(
1134
+ PQCodeDist::distance_single_code(
1156
1135
  pq.M,
1157
1136
  pq.nbits,
1158
1137
  sim_table,
@@ -1199,8 +1178,8 @@ struct IVFPQScannerT : QueryTables {
1199
1178
  *
1200
1179
  * use_sel: store or ignore the IDSelector
1201
1180
  */
1202
- template <MetricType METRIC_TYPE, class C, class PQDecoder, bool use_sel>
1203
- struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1181
+ template <MetricType METRIC_TYPE, class C, class PQCodeDist, bool use_sel>
1182
+ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDist>,
1204
1183
  InvertedListScanner {
1205
1184
  int precompute_mode;
1206
1185
  const IDSelector* sel;
@@ -1210,7 +1189,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1210
1189
  bool store_pairs,
1211
1190
  int precompute_mode,
1212
1191
  const IDSelector* sel)
1213
- : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
1192
+ : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDist>(ivfpq, nullptr),
1214
1193
  precompute_mode(precompute_mode),
1215
1194
  sel(sel) {
1216
1195
  this->store_pairs = store_pairs;
@@ -1230,7 +1209,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1230
1209
  float distance_to_code(const uint8_t* code) const override {
1231
1210
  assert(precompute_mode == 2);
1232
1211
  float dis = this->dis0 +
1233
- distance_single_code<PQDecoder>(
1212
+ PQCodeDist::distance_single_code(
1234
1213
  this->pq.M, this->pq.nbits, this->sim_table, code);
1235
1214
  return dis;
1236
1215
  }
@@ -1239,17 +1218,12 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1239
1218
  size_t ncode,
1240
1219
  const uint8_t* codes,
1241
1220
  const idx_t* ids,
1242
- float* heap_sim,
1243
- idx_t* heap_ids,
1244
- size_t k) const override {
1245
- KnnSearchResults<C, use_sel> res = {
1246
- /* key */ this->key,
1247
- /* ids */ this->store_pairs ? nullptr : ids,
1248
- /* sel */ this->sel,
1249
- /* k */ k,
1250
- /* heap_sim */ heap_sim,
1251
- /* heap_ids */ heap_ids,
1252
- /* nup */ 0};
1221
+ ResultHandler& handler) const override {
1222
+ WrappedSearchResult<C, use_sel> res(
1223
+ this->key,
1224
+ this->store_pairs ? nullptr : ids,
1225
+ this->sel,
1226
+ handler);
1253
1227
 
1254
1228
  if (this->polysemous_ht > 0) {
1255
1229
  assert(precompute_mode == 2);
@@ -1265,85 +1239,53 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1265
1239
  }
1266
1240
  return res.nup;
1267
1241
  }
1268
-
1269
- void scan_codes_range(
1270
- size_t ncode,
1271
- const uint8_t* codes,
1272
- const idx_t* ids,
1273
- float radius,
1274
- RangeQueryResult& rres) const override {
1275
- RangeSearchResults<C, use_sel> res = {
1276
- /* key */ this->key,
1277
- /* ids */ this->store_pairs ? nullptr : ids,
1278
- /* sel */ this->sel,
1279
- /* radius */ radius,
1280
- /* rres */ rres};
1281
-
1282
- if (this->polysemous_ht > 0) {
1283
- assert(precompute_mode == 2);
1284
- this->scan_list_polysemous(ncode, codes, res);
1285
- } else if (precompute_mode == 2) {
1286
- this->scan_list_with_table(ncode, codes, res);
1287
- } else if (precompute_mode == 1) {
1288
- this->scan_list_with_pointer(ncode, codes, res);
1289
- } else if (precompute_mode == 0) {
1290
- this->scan_on_the_fly_dist(ncode, codes, res);
1291
- } else {
1292
- FAISS_THROW_MSG("bad precomp mode");
1293
- }
1294
- }
1295
1242
  };
1296
1243
 
1297
- template <class PQDecoder, bool use_sel>
1298
- InvertedListScanner* get_InvertedListScanner1(
1299
- const IndexIVFPQ& index,
1300
- bool store_pairs,
1301
- const IDSelector* sel) {
1302
- if (index.metric_type == METRIC_INNER_PRODUCT) {
1303
- return new IVFPQScanner<
1304
- METRIC_INNER_PRODUCT,
1305
- CMin<float, idx_t>,
1306
- PQDecoder,
1307
- use_sel>(index, store_pairs, 2, sel);
1308
- } else if (index.metric_type == METRIC_L2) {
1309
- return new IVFPQScanner<
1310
- METRIC_L2,
1311
- CMax<float, idx_t>,
1312
- PQDecoder,
1313
- use_sel>(index, store_pairs, 2, sel);
1314
- }
1315
- return nullptr;
1316
- }
1317
-
1318
- template <bool use_sel>
1319
- InvertedListScanner* get_InvertedListScanner2(
1320
- const IndexIVFPQ& index,
1321
- bool store_pairs,
1322
- const IDSelector* sel) {
1323
- if (index.pq.nbits == 8) {
1324
- return get_InvertedListScanner1<PQDecoder8, use_sel>(
1325
- index, store_pairs, sel);
1326
- } else if (index.pq.nbits == 16) {
1327
- return get_InvertedListScanner1<PQDecoder16, use_sel>(
1328
- index, store_pairs, sel);
1329
- } else {
1330
- return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
1331
- index, store_pairs, sel);
1332
- }
1333
- }
1334
-
1335
1244
  } // anonymous namespace
1336
1245
 
1337
1246
  InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
1338
1247
  bool store_pairs,
1339
1248
  const IDSelector* sel,
1340
1249
  const IVFSearchParameters*) const {
1341
- if (sel) {
1342
- return get_InvertedListScanner2<true>(*this, store_pairs, sel);
1343
- } else {
1344
- return get_InvertedListScanner2<false>(*this, store_pairs, sel);
1345
- }
1346
- return nullptr;
1250
+ return with_simd_level([&]<SIMDLevel SL>() -> InvertedListScanner* {
1251
+ auto make =
1252
+ [&]<class PQCodeDist, bool use_sel>() -> InvertedListScanner* {
1253
+ if (metric_type == METRIC_INNER_PRODUCT) {
1254
+ return new IVFPQScanner<
1255
+ METRIC_INNER_PRODUCT,
1256
+ CMin<float, idx_t>,
1257
+ PQCodeDist,
1258
+ use_sel>(*this, store_pairs, 2, sel);
1259
+ } else if (metric_type == METRIC_L2) {
1260
+ return new IVFPQScanner<
1261
+ METRIC_L2,
1262
+ CMax<float, idx_t>,
1263
+ PQCodeDist,
1264
+ use_sel>(*this, store_pairs, 2, sel);
1265
+ } else {
1266
+ FAISS_THROW_MSG("unsupported metric type");
1267
+ }
1268
+ };
1269
+
1270
+ auto with_decoder = [&]<bool use_sel>() -> InvertedListScanner* {
1271
+ if (pq.nbits == 8) {
1272
+ return make.template
1273
+ operator()<PQCodeDistance<PQDecoder8, SL>, use_sel>();
1274
+ } else if (pq.nbits == 16) {
1275
+ return make.template
1276
+ operator()<PQCodeDistance<PQDecoder16, SL>, use_sel>();
1277
+ } else {
1278
+ return make.template
1279
+ operator()<PQCodeDistance<PQDecoderGeneric, SL>, use_sel>();
1280
+ }
1281
+ };
1282
+
1283
+ if (sel) {
1284
+ return with_decoder.template operator()<true>();
1285
+ } else {
1286
+ return with_decoder.template operator()<false>();
1287
+ }
1288
+ });
1347
1289
  }
1348
1290
 
1349
1291
  IndexIVFPQStats indexIVFPQ_stats;
@@ -17,6 +17,7 @@
17
17
  #include <faiss/impl/FaissAssert.h>
18
18
  #include <faiss/utils/Heap.h>
19
19
  #include <faiss/utils/distances.h>
20
+ #include <faiss/utils/extra_distances.h>
20
21
  #include <faiss/utils/simdlib.h>
21
22
 
22
23
  #include <faiss/invlists/BlockInvertedLists.h>
@@ -307,6 +308,7 @@ struct IVFPQFastScanScanner : InvertedListScanner {
307
308
  const IndexIVFPQFastScan& index;
308
309
  AlignedTable<uint8_t> dis_tables;
309
310
  AlignedTable<uint16_t> biases;
311
+ std::vector<float> residual;
310
312
  std::array<float, 2> normalizers{};
311
313
  const float* xi = nullptr;
312
314
 
@@ -316,6 +318,7 @@ struct IVFPQFastScanScanner : InvertedListScanner {
316
318
  const IDSelector* sel)
317
319
  : InvertedListScanner(store_pairs, sel), index(index) {
318
320
  this->keep_max = is_similarity_metric(index.metric_type);
321
+ residual.resize(index.d);
319
322
  }
320
323
 
321
324
  void set_query(const float* query) override {
@@ -332,12 +335,40 @@ struct IVFPQFastScanScanner : InvertedListScanner {
332
335
  FastScanDistancePostProcessing empty_context{};
333
336
  index.compute_LUT_uint8(
334
337
  1, xi, cq, dis_tables, biases, &normalizers[0], empty_context);
338
+ // used in distance_to_code
339
+ index.quantizer->compute_residual(
340
+ this->xi, residual.data(), this->list_no);
335
341
  }
336
342
 
337
- float distance_to_code(const uint8_t* /* code */) const override {
338
- // It's not really possible to implement a distance_to_code since codes
339
- // for 32 database vectors are intermixed.
340
- FAISS_THROW_MSG("not implemented");
343
+ float distance_to_code(const uint8_t* code) const override {
344
+ // directly use the PQ tables to compute the distance
345
+ const ProductQuantizer& pq = index.pq;
346
+ // when by_residual, codes are residuals so compare against query
347
+ // residual; otherwise codes are raw vectors so compare against raw
348
+ // query
349
+ const float* x = index.by_residual ? residual.data() : this->xi;
350
+ float accu = 0;
351
+ // implemented for all vector distances, although only L2 and IP are
352
+ // supported by FastScan
353
+ with_VectorDistance(pq.dsub, index.metric_type, 0.0, [&](auto vd) {
354
+ int m;
355
+ for (m = 0; m + 1 < pq.M; m += 2) {
356
+ const float* cent;
357
+ uint8_t c = *code++;
358
+ cent = pq.get_centroids(m, c & 15);
359
+ accu += vd(cent, x);
360
+ x += pq.dsub;
361
+ cent = pq.get_centroids(m + 1, c >> 4);
362
+ accu += vd(cent, x);
363
+ x += pq.dsub;
364
+ }
365
+ if (m < pq.M) { // leftover
366
+ uint8_t c = *code++;
367
+ const float* cent = pq.get_centroids(m, c & 15);
368
+ accu += vd(cent, x);
369
+ }
370
+ });
371
+ return accu;
341
372
  }
342
373
 
343
374
  // Based on IVFFastScan search_implem_10, since it also deals with 1 query
@@ -388,7 +419,8 @@ struct IVFPQFastScanScanner : InvertedListScanner {
388
419
  codes,
389
420
  LUT,
390
421
  *handler,
391
- nullptr);
422
+ nullptr,
423
+ index.get_block_stride());
392
424
 
393
425
  // The handler is for the results of this iteration.
394
426
  // Then we need a second heap to combine across iterations.
@@ -11,6 +11,7 @@
11
11
 
12
12
  #include <cinttypes>
13
13
 
14
+ #include <faiss/impl/simd_dispatch.h>
14
15
  #include <faiss/utils/Heap.h>
15
16
  #include <faiss/utils/distances.h>
16
17
  #include <faiss/utils/utils.h>
@@ -128,7 +129,7 @@ void IndexIVFPQR::search_preassigned(
128
129
  IndexIVFStats* stats) const {
129
130
  uint64_t t0;
130
131
  TIC;
131
- size_t k_coarse = long(k * k_factor);
132
+ size_t k_coarse = long((size_t)k * k_factor);
132
133
  std::unique_ptr<idx_t[]> coarse_labels(new idx_t[k_coarse * n]);
133
134
  {
134
135
  // query with quantizer levels 1 and 2.
@@ -15,10 +15,15 @@
15
15
  #include <vector>
16
16
 
17
17
  #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/RaBitQUtils.h>
18
19
  #include <faiss/impl/RaBitQuantizer.h>
20
+ #include <faiss/impl/ResultHandler.h>
21
+ #include <faiss/impl/expanded_scanners.h>
19
22
 
20
23
  namespace faiss {
21
24
 
25
+ using rabitq_utils::SignBitFactorsWithError;
26
+
22
27
  IndexIVFRaBitQ::IndexIVFRaBitQ(
23
28
  Index* quantizer,
24
29
  const size_t d,
@@ -44,7 +49,7 @@ IndexIVFRaBitQ::IndexIVFRaBitQ() {
44
49
  void IndexIVFRaBitQ::train_encoder(
45
50
  idx_t n,
46
51
  const float* x,
47
- const idx_t* assign) {
52
+ const idx_t* /*assign*/) {
48
53
  rabitq.train(n, x);
49
54
  }
50
55
 
@@ -182,7 +187,7 @@ struct RaBitInvertedListScanner : InvertedListScanner {
182
187
  }
183
188
 
184
189
  /// following codes come from this inverted list
185
- void set_list(idx_t list_no, float coarse_dis) override {
190
+ void set_list(idx_t list_no, float /*coarse_dis*/) override {
186
191
  this->list_no = list_no;
187
192
 
188
193
  reconstructed_centroid.resize(ivf_rabitq.d);
@@ -193,24 +198,31 @@ struct RaBitInvertedListScanner : InvertedListScanner {
193
198
  }
194
199
 
195
200
  /// compute a single query-to-code distance
196
- float distance_to_code(const uint8_t* code) const override {
201
+ float distance_to_code(const uint8_t* code) const final {
197
202
  return dc->distance_to_code(code);
198
203
  }
199
204
 
205
+ // redefining the scan_codes allows to inline the distance_to_code
206
+ // (this is unlikely to matter because it contains a virtual function call)
207
+ size_t scan_codes_1bit(
208
+ size_t list_size,
209
+ const uint8_t* codes,
210
+ const idx_t* ids,
211
+ ResultHandler& handler) const {
212
+ return run_scan_codes(*this, list_size, codes, ids, handler);
213
+ }
214
+
200
215
  /// Override scan_codes to implement adaptive filtering for multi-bit codes
201
216
  size_t scan_codes(
202
217
  size_t list_size,
203
218
  const uint8_t* codes,
204
219
  const idx_t* ids,
205
- float* simi,
206
- idx_t* idxi,
207
- size_t k) const override {
220
+ ResultHandler& handler) const override {
208
221
  size_t ex_bits = ivf_rabitq.rabitq.nb_bits - 1;
209
222
 
210
223
  // For 1-bit codes, use default implementation
211
224
  if (ex_bits == 0 || rabitq_dc == nullptr) {
212
- return InvertedListScanner::scan_codes(
213
- list_size, codes, ids, simi, idxi, k);
225
+ return scan_codes_1bit(list_size, codes, ids, handler);
214
226
  }
215
227
 
216
228
  // Multi-bit: Two-stage search with adaptive filtering
@@ -233,33 +245,33 @@ struct RaBitInvertedListScanner : InvertedListScanner {
233
245
 
234
246
  local_1bit_evaluations++;
235
247
 
236
- // Stage 1: Compute lower bound using 1-bit codes
237
- float lower_bound = rabitq_dc->lower_bound_distance(codes);
238
-
239
- // Stage 2: Adaptive filtering
240
- // L2 (min-heap): filter if lower_bound < simi[0]
241
- // IP (max-heap): filter if lower_bound > simi[0]
242
- // Note: Using simi[0] directly (not cached) enables more aggressive
243
- // filtering as the heap is updated with better candidates
244
- bool should_refine = keep_max ? (lower_bound > simi[0])
245
- : (lower_bound < simi[0]);
246
-
248
+ // Stage 1: Compute distance bound using 1-bit codes
249
+ // For L2 (min-heap): use lower_bound to safely skip if it's
250
+ // already worse than heap worst
251
+ // For IP (max-heap): use upper_bound because with a lower bound,
252
+ // we can't safely skip any candidate
253
+ float est_distance = rabitq_dc->distance_to_code_1bit(codes);
254
+
255
+ // Extract f_error and g_error for filtering
256
+ size_t code_size_base = (ivf_rabitq.d + 7) / 8;
257
+ const rabitq_utils::SignBitFactorsWithError* base_fac =
258
+ reinterpret_cast<
259
+ const rabitq_utils::SignBitFactorsWithError*>(
260
+ codes + code_size_base);
261
+
262
+ bool should_refine = rabitq_utils::should_refine_candidate(
263
+ est_distance,
264
+ base_fac->f_error,
265
+ rabitq_dc->g_error,
266
+ handler.threshold,
267
+ keep_max);
247
268
  if (should_refine) {
248
269
  local_multibit_evaluations++;
249
270
  // Lower bound is promising, compute full distance
250
271
  float dis = distance_to_code(codes);
272
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
251
273
 
252
- // Check if distance improves heap
253
- bool improves_heap =
254
- keep_max ? (dis > simi[0]) : (dis < simi[0]);
255
-
256
- if (improves_heap) {
257
- int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
258
- if (keep_max) {
259
- minheap_replace_top(k, simi, idxi, dis, id);
260
- } else {
261
- maxheap_replace_top(k, simi, idxi, dis, id);
262
- }
274
+ if (handler.add_result(dis, id)) {
263
275
  nup++;
264
276
  }
265
277
  }
@@ -19,7 +19,7 @@
19
19
  namespace faiss {
20
20
 
21
21
  struct IVFRaBitQSearchParameters : IVFSearchParameters {
22
- uint8_t qb = 0;
22
+ uint8_t qb = 4;
23
23
  bool centered = false;
24
24
  };
25
25
 
@@ -29,7 +29,7 @@ struct IndexIVFRaBitQ : IndexIVF {
29
29
 
30
30
  // the default number of bits to quantize a query with.
31
31
  // use '0' to disable quantization and use raw fp32 values.
32
- uint8_t qb = 0;
32
+ uint8_t qb = 4;
33
33
 
34
34
  IndexIVFRaBitQ(
35
35
  Index* quantizer,