faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
@@ -37,11 +37,28 @@ struct RaBitQuantizer : Quantizer {
37
37
  // possible. Thus, a quantizer has to introduce a metric.
38
38
  MetricType metric_type = MetricType::METRIC_L2;
39
39
 
40
- RaBitQuantizer(size_t d = 0, MetricType metric = MetricType::METRIC_L2);
40
+ // Number of bits per dimension (1-9). Default is 1 for backward
41
+ // compatibility.
42
+ // - nb_bits = 1: standard 1-bit RaBitQ (sign bits only)
43
+ // - nb_bits = 2-9: multi-bit RaBitQ (1 sign bit + ex_bits extra bits)
44
+ size_t nb_bits = 1;
45
+
46
+ RaBitQuantizer(
47
+ size_t d = 0,
48
+ MetricType metric = MetricType::METRIC_L2,
49
+ size_t nb_bits = 1);
50
+
51
+ // Compute code size based on dimensionality and number of bits
52
+ // Returns: size in bytes for one encoded vector
53
+ // - nb_bits=1: (d+7)/8 + 8 bytes (1-bit codes + base factors)
54
+ // - nb_bits>1: (d+7)/8 + 8 + d*ex_bits/8 + 8 bytes
55
+ // (1-bit codes + base factors + ex-bit codes + ex factors)
56
+ size_t compute_code_size(size_t d, size_t num_bits) const;
41
57
 
42
58
  void train(size_t n, const float* x) override;
43
59
 
44
- // every vector is expected to take (d + 7) / 8 + sizeof(FactorsData) bytes,
60
+ // every vector is expected to take (d + 7) / 8 + sizeof(SignBitFactors)
61
+ // bytes,
45
62
  void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
46
63
 
47
64
  void compute_codes_core(
@@ -71,8 +88,59 @@ struct RaBitQuantizer : Quantizer {
71
88
  // specify qb = 0 to get an DC that does not quantize a query
72
89
  // specify qb > 0 to have SQ qb-bits query
73
90
  FlatCodesDistanceComputer* get_distance_computer(
74
- uint8_t qb,
75
- const float* centroid_in = nullptr) const;
91
+ uint8_t qb = 0,
92
+ const float* centroid = nullptr,
93
+ bool centered = false) const;
94
+ };
95
+
96
+ // RaBitQDistanceComputer: Base class for RaBitQ distance computers
97
+ //
98
+ // This intermediate class exists to provide a unified interface for
99
+ // two-stage multi-bit search. While most Faiss quantizers extend
100
+ // FlatCodesDistanceComputer directly, RaBitQ requires this additional
101
+ // abstraction layer due to its unique split encoding strategy
102
+ // (1 sign bit + magnitude bits) which enables:
103
+ //
104
+ // 1. distance_to_code_1bit() - Fast 1-bit filtering using only sign bits
105
+ // 2. distance_to_code_full() - Accurate multi-bit refinement using all bits
106
+ // 3. lower_bound_distance() - Error-bounded adaptive filtering
107
+ // (based on 1-bit estimator)
108
+ //
109
+ // These three methods implement RaBitQ's two-stage search pattern and are
110
+ // shared between the quantized (Q) and non-quantized (NotQ) query variants.
111
+ // The intermediate class allows two-stage search code to work with both
112
+ // variants via a single dynamic_cast.
113
+ struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
114
+ size_t d = 0;
115
+ const float* centroid = nullptr;
116
+ MetricType metric_type = MetricType::METRIC_L2;
117
+ size_t nb_bits = 1;
118
+
119
+ // Query norm for lower bound computation (g_error in rabitq-library)
120
+ // This is the L2 norm of the rotated query: ||query - centroid||
121
+ float g_error = 0.0f;
122
+
123
+ float symmetric_dis(idx_t /*i*/, idx_t /*j*/) override {
124
+ // Not used for RaBitQ
125
+ FAISS_THROW_MSG("Not implemented");
126
+ }
127
+
128
+ // Compute 1-bit distance estimate (fast)
129
+ virtual float distance_to_code_1bit(const uint8_t* code) = 0;
130
+
131
+ // Compute full multi-bit distance (accurate)
132
+ virtual float distance_to_code_full(const uint8_t* code) = 0;
133
+
134
+ // Compute lower bound of distance using error bounds
135
+ // Guarantees: actual_distance >= lower_bound_distance
136
+ // Used for adaptive filtering in two-stage search
137
+ virtual float lower_bound_distance(const uint8_t* code);
138
+
139
+ // Override from FlatCodesDistanceComputer
140
+ // Delegates to distance_to_code_full() for multi-bit distance computation
141
+ float distance_to_code(const uint8_t* code) final {
142
+ return distance_to_code_full(code);
143
+ }
76
144
  };
77
145
 
78
146
  } // namespace faiss
@@ -0,0 +1,362 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // NOTE: Parts of this implementation are adapted from:
9
+ // RaBitQ-Library/include/rabitqlib/quantization/rabitq_impl.hpp
10
+ // https://github.com/VectorDB-NTU/RaBitQ-Library
11
+
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/impl/RaBitQUtils.h>
14
+ #include <faiss/utils/distances.h>
15
+
16
+ #include <algorithm>
17
+ #include <cmath>
18
+ #include <cstring>
19
+ #include <queue>
20
+ #include <vector>
21
+
22
+ namespace faiss {
23
+ namespace rabitq_multibit {
24
+
25
+ using rabitq_utils::ExtraBitsFactors;
26
+ using rabitq_utils::SignBitFactorsWithError;
27
+
28
+ constexpr float kTightStart[9] =
29
+ {0.0f, 0.15f, 0.20f, 0.52f, 0.59f, 0.71f, 0.75f, 0.77f, 0.81f};
30
+
31
+ constexpr double kEps = 1e-5;
32
+
33
+ /**
34
+ * Compute optimal scaling factor for ex-bits quantization using priority
35
+ * queue-based search.
36
+ *
37
+ * This function finds the optimal scaling factor 't' that maximizes the
38
+ * inner product between the normalized quantized vector and the normalized
39
+ * absolute residual. The algorithm uses a priority queue to efficiently
40
+ * explore different quantization levels.
41
+ *
42
+ *
43
+ * @param o_abs Normalized absolute residual vector (must be positive, length
44
+ * d)
45
+ * @param d Dimensionality of the vector
46
+ * @param nb_bits Number of bits per dimension (2-9)
47
+ * @return Optimal scaling factor 't'
48
+ */
49
+ float compute_optimal_scaling_factor(
50
+ const float* o_abs,
51
+ size_t d,
52
+ size_t nb_bits) {
53
+ const size_t ex_bits = nb_bits - 1;
54
+ FAISS_THROW_IF_NOT_MSG(
55
+ ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
56
+
57
+ const int kNEnum = 10;
58
+ const int max_code = (1 << ex_bits) - 1;
59
+
60
+ float max_o = *std::max_element(o_abs, o_abs + d);
61
+
62
+ // Determine search range [t_start, t_end]
63
+ float t_end = static_cast<float>(max_code + kNEnum) / max_o;
64
+ float t_start = t_end * kTightStart[ex_bits];
65
+
66
+ std::vector<float> inv_o_abs(d);
67
+ for (size_t i = 0; i < d; ++i) {
68
+ inv_o_abs[i] = 1.0f / o_abs[i];
69
+ }
70
+
71
+ std::vector<int> cur_o_bar(d);
72
+ float sqr_denominator = static_cast<float>(d) * 0.25f;
73
+ float numerator = 0.0f;
74
+
75
+ for (size_t i = 0; i < d; ++i) {
76
+ int cur = static_cast<int>((t_start * o_abs[i]) + kEps);
77
+ cur_o_bar[i] = cur;
78
+ sqr_denominator += static_cast<float>(cur * cur + cur);
79
+ numerator += (cur + 0.5f) * o_abs[i];
80
+ }
81
+
82
+ float inv_sqrt_denom = 1.0f / std::sqrt(sqr_denominator);
83
+
84
+ // Pair: (next_t, dimension_index)
85
+ // Maximum size is d (one entry per dimension), so reserve exactly d
86
+ std::vector<std::pair<float, size_t>> pq_storage;
87
+ pq_storage.reserve(d);
88
+ std::priority_queue<
89
+ std::pair<float, size_t>,
90
+ std::vector<std::pair<float, size_t>>,
91
+ std::greater<>>
92
+ next_t(std::greater<>(), std::move(pq_storage));
93
+
94
+ // Initialize queue with next quantization level for each dimension
95
+ for (size_t i = 0; i < d; ++i) {
96
+ float t_next = static_cast<float>(cur_o_bar[i] + 1) * inv_o_abs[i];
97
+ if (t_next < t_end) {
98
+ next_t.emplace(t_next, i);
99
+ }
100
+ }
101
+
102
+ float max_ip = 0.0f;
103
+ float t = 0.0f;
104
+
105
+ while (!next_t.empty()) {
106
+ float cur_t = next_t.top().first;
107
+ size_t update_id = next_t.top().second;
108
+ next_t.pop();
109
+
110
+ cur_o_bar[update_id]++;
111
+ int update_o_bar = cur_o_bar[update_id];
112
+
113
+ float delta = 2.0f * update_o_bar;
114
+ sqr_denominator += delta;
115
+ numerator += o_abs[update_id];
116
+
117
+ float old_denom = sqr_denominator - delta;
118
+ inv_sqrt_denom = inv_sqrt_denom *
119
+ (1.0f - 0.5f * delta / (old_denom + delta * 0.5f));
120
+
121
+ float cur_ip = numerator * inv_sqrt_denom;
122
+
123
+ if (cur_ip > max_ip) {
124
+ max_ip = cur_ip;
125
+ t = cur_t;
126
+ }
127
+
128
+ if (update_o_bar < max_code) {
129
+ float t_next =
130
+ static_cast<float>(update_o_bar + 1) * inv_o_abs[update_id];
131
+ if (t_next < t_end) {
132
+ next_t.emplace(t_next, update_id);
133
+ }
134
+ }
135
+ }
136
+
137
+ return t;
138
+ }
139
+
140
+ /**
141
+ * Pack multi-bit codes from integer array to byte array.
142
+ *
143
+ * @param tmp_code Integer codes (length d), each value in [0, 2^ex_bits - 1]
144
+ * @param ex_code Output packed byte array
145
+ * @param d Dimensionality
146
+ * @param nb_bits Number of bits per dimension (2-9)
147
+ */
148
+ void pack_multibit_codes(
149
+ const int* tmp_code,
150
+ uint8_t* ex_code,
151
+ size_t d,
152
+ size_t nb_bits) {
153
+ const size_t ex_bits = nb_bits - 1;
154
+ FAISS_THROW_IF_NOT_MSG(
155
+ ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
156
+
157
+ size_t total_bits = d * ex_bits;
158
+ size_t output_size = (total_bits + 7) / 8;
159
+ memset(ex_code, 0, output_size);
160
+
161
+ size_t bit_pos = 0;
162
+ for (size_t i = 0; i < d; i++) {
163
+ int code_value = tmp_code[i];
164
+
165
+ for (size_t bit = 0; bit < ex_bits; bit++) {
166
+ size_t byte_idx = bit_pos / 8;
167
+ size_t bit_idx = bit_pos % 8;
168
+
169
+ if (code_value & (1 << bit)) {
170
+ ex_code[byte_idx] |= (1 << bit_idx);
171
+ }
172
+
173
+ bit_pos++;
174
+ }
175
+ }
176
+ }
177
+
178
+ /**
179
+ * Compute ex-bits factors for distance computation.
180
+ *
181
+ * @param residual Original residual vector (data - centroid)
182
+ * @param centroid Centroid vector (can be nullptr for zero centroid)
183
+ * @param tmp_code Quantized ex-bit codes (before packing, after bit flipping)
184
+ * @param d Dimensionality
185
+ * @param ex_bits Number of extra bits
186
+ * @param norm L2 norm of residual
187
+ * @param ipnorm Unnormalized inner product between quantized and normalized
188
+ * residual
189
+ * @param ex_factors Output factors structure
190
+ * @param metric_type Distance metric (L2 or Inner Product)
191
+ */
192
+ void compute_ex_factors(
193
+ const float* residual,
194
+ const float* centroid,
195
+ const int* tmp_code,
196
+ size_t d,
197
+ size_t ex_bits,
198
+ float norm,
199
+ double ipnorm,
200
+ ExtraBitsFactors& ex_factors,
201
+ MetricType metric_type) {
202
+ FAISS_THROW_IF_NOT_MSG(
203
+ metric_type == MetricType::METRIC_L2 ||
204
+ metric_type == MetricType::METRIC_INNER_PRODUCT,
205
+ "Unsupported metric type");
206
+
207
+ // Compute ipnorm_inv = 1 / ipnorm
208
+ float ipnorm_inv = static_cast<float>(1.0 / ipnorm);
209
+ if (!std::isnormal(ipnorm_inv)) {
210
+ ipnorm_inv = 1.0f;
211
+ }
212
+
213
+ // Reconstruct xu_cb from total_code
214
+ // total_code was formed from: total_code[i] = (sign << ex_bits) +
215
+ // ex_code[i] Reconstruction: xu_cb[i] = total_code[i] + cb
216
+ const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
217
+ std::vector<float> xu_cb(d);
218
+ for (size_t i = 0; i < d; i++) {
219
+ xu_cb[i] = static_cast<float>(tmp_code[i]) + cb;
220
+ }
221
+
222
+ // Compute inner products needed for factors
223
+ float l2_sqr = norm * norm;
224
+ float ip_resi_xucb = fvec_inner_product(residual, xu_cb.data(), d);
225
+
226
+ // Compute factors
227
+ if (metric_type == MetricType::METRIC_L2) {
228
+ // For L2, no centroid correction needed in IVF setting
229
+ // because residual = x - centroid, distance computed in residual space
230
+ ex_factors.f_add_ex = l2_sqr;
231
+ ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
232
+ } else {
233
+ // For IP, centroid correction is needed
234
+ float ip_resi_cent = 0;
235
+ if (centroid != nullptr) {
236
+ ip_resi_cent = fvec_inner_product(residual, centroid, d);
237
+ }
238
+
239
+ float ip_cent_xucb = 0;
240
+ if (centroid != nullptr) {
241
+ ip_cent_xucb = fvec_inner_product(centroid, xu_cb.data(), d);
242
+ }
243
+
244
+ // When ip_resi_xucb is zero, the correction term should be zero
245
+ float correction_term = 0.0f;
246
+ if (ip_resi_xucb != 0.0f) {
247
+ correction_term = l2_sqr * ip_cent_xucb / ip_resi_xucb;
248
+ }
249
+
250
+ ex_factors.f_add_ex = 1 - ip_resi_cent + correction_term;
251
+ ex_factors.f_rescale_ex = ipnorm_inv * -norm;
252
+ }
253
+ }
254
+
255
+ /**
256
+ * Quantize residual vector to ex-bits.
257
+ *
258
+ * This is the main quantization function that:
259
+ * 1. Normalizes the residual
260
+ * 2. Takes absolute value
261
+ * 3. Finds optimal scaling factor
262
+ * 4. Quantizes to ex_bits
263
+ * 5. Handles negative dimensions by flipping bits
264
+ * 6. Packs codes into byte array
265
+ * 7. Computes factors for distance computation
266
+ *
267
+ * @param residual Input residual vector (data - centroid), length d
268
+ * @param d Dimensionality
269
+ * @param nb_bits Number of bits per dimension (2-9)
270
+ * @param ex_code Output packed ex-bit codes
271
+ * @param ex_factors Output ex-bits factors
272
+ * @param metric_type Distance metric (L2 or Inner Product)
273
+ * @param centroid Optional centroid vector (needed for IP metric)
274
+ */
275
+ void quantize_ex_bits(
276
+ const float* residual,
277
+ size_t d,
278
+ size_t nb_bits,
279
+ uint8_t* ex_code,
280
+ ExtraBitsFactors& ex_factors,
281
+ MetricType metric_type,
282
+ const float* centroid) {
283
+ const size_t ex_bits = nb_bits - 1;
284
+ FAISS_THROW_IF_NOT_MSG(
285
+ ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
286
+ FAISS_THROW_IF_NOT_MSG(residual != nullptr, "residual cannot be null");
287
+ FAISS_THROW_IF_NOT_MSG(ex_code != nullptr, "ex_code cannot be null");
288
+
289
+ // Step 1: Compute L2 norm of residual
290
+ float norm_sqr = fvec_norm_L2sqr(residual, d);
291
+ float norm = std::sqrt(norm_sqr);
292
+
293
+ // Handle degenerate case
294
+ if (norm < 1e-10f) {
295
+ size_t code_size = (d * ex_bits + 7) / 8;
296
+ memset(ex_code, 0, code_size);
297
+ ex_factors.f_add_ex = 0.0f;
298
+ ex_factors.f_rescale_ex = 1.0f;
299
+ return;
300
+ }
301
+
302
+ // Step 2: Normalize residual
303
+ std::vector<float> normalized_residual(d);
304
+ for (size_t i = 0; i < d; i++) {
305
+ normalized_residual[i] = residual[i] / norm;
306
+ }
307
+
308
+ // Step 3: Take absolute value
309
+ std::vector<float> o_abs(d);
310
+ for (size_t i = 0; i < d; i++) {
311
+ o_abs[i] = std::abs(normalized_residual[i]);
312
+ }
313
+
314
+ // Step 4: Find optimal scaling factor
315
+ float t = compute_optimal_scaling_factor(o_abs.data(), d, nb_bits);
316
+
317
+ // Step 5: Quantize to ex_bits
318
+ std::vector<int> tmp_code(d);
319
+ double ipnorm = 0;
320
+ int max_code = (1 << ex_bits) - 1;
321
+
322
+ for (size_t i = 0; i < d; i++) {
323
+ tmp_code[i] = std::min(static_cast<int>(t * o_abs[i] + kEps), max_code);
324
+ // Compute unnormalized inner product
325
+ ipnorm += (tmp_code[i] + 0.5) * o_abs[i];
326
+ }
327
+
328
+ // Step 6: Handle negative dimensions (flip bits)
329
+ // For negative residuals, flip all bits: code' = ~code & max_code
330
+ for (size_t i = 0; i < d; i++) {
331
+ if (residual[i] < 0) {
332
+ tmp_code[i] = (~tmp_code[i]) & max_code;
333
+ }
334
+ }
335
+
336
+ // Step 7: Pack codes into byte array
337
+ pack_multibit_codes(tmp_code.data(), ex_code, d, nb_bits);
338
+
339
+ // Step 8: Compute factors for distance computation
340
+ // Reconstruct total_code for factor computation
341
+ std::vector<int> total_code(d);
342
+ for (size_t i = 0; i < d; i++) {
343
+ // Form total_code = (sign << ex_bits) + ex_code
344
+ bool sign_bit = (residual[i] >= 0);
345
+ total_code[i] = tmp_code[i] + ((sign_bit ? 1 : 0) << ex_bits);
346
+ }
347
+
348
+ // Compute ex-factors; centroid is needed for IP metric correction
349
+ compute_ex_factors(
350
+ residual,
351
+ centroid, // Pass centroid for IP metric factor computation
352
+ total_code.data(),
353
+ d,
354
+ ex_bits,
355
+ norm,
356
+ ipnorm,
357
+ ex_factors,
358
+ metric_type);
359
+ }
360
+
361
+ } // namespace rabitq_multibit
362
+ } // namespace faiss
@@ -0,0 +1,112 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // Reference:
9
+ // "Practical and asymptotically optimal quantization of high-dimensional
10
+ // vectors in euclidean space for approximate nearest neighbor search"
11
+ // Jianyang Gao, Yutong Gou, Yuexuan Xu, Yongyi Yang, Cheng Long, Raymond
12
+ // Chi-Wing Wong https://dl.acm.org/doi/pdf/10.1145/3725413
13
+ //
14
+ // Reference implementation: https://github.com/VectorDB-NTU/RaBitQ-Library
15
+ // NOTE: Parts of this implementation are adapted from
16
+ // rabitqlib/quantization/rabitq_impl.hpp in the above repository.
17
+
18
+ #pragma once
19
+
20
+ #include <faiss/MetricType.h>
21
+ #include <faiss/impl/RaBitQUtils.h>
22
+ #include <cstddef>
23
+ #include <cstdint>
24
+
25
+ namespace faiss {
26
+ namespace rabitq_multibit {
27
+
28
+ /**
29
+ * Compute optimal scaling factor for ex-bits quantization.
30
+ *
31
+ * Uses priority queue-based search to find the scaling factor that
32
+ * maximizes the inner product between quantized and original vectors.
33
+ *
34
+ * @param o_abs Normalized absolute residual vector (positive values)
35
+ * @param d Dimensionality
36
+ * @param nb_bits Number of bits per dimension (2-9)
37
+ * @return Optimal scaling factor 't'
38
+ */
39
+ float compute_optimal_scaling_factor(
40
+ const float* o_abs,
41
+ size_t d,
42
+ size_t nb_bits);
43
+
44
+ /**
45
+ * Pack multi-bit codes from integer array to byte array.
46
+ *
47
+ * @param tmp_code Integer codes (length d), values in [0, 2^ex_bits - 1]
48
+ * @param ex_code Output packed byte array
49
+ * @param d Dimensionality
50
+ * @param nb_bits Number of bits per dimension (2-9)
51
+ */
52
+ void pack_multibit_codes(
53
+ const int* tmp_code,
54
+ uint8_t* ex_code,
55
+ size_t d,
56
+ size_t nb_bits);
57
+
58
+ /**
59
+ * Compute ex-bits factors for distance computation.
60
+ *
61
+ * @param residual Original residual vector (data - centroid)
62
+ * @param centroid Centroid vector (can be nullptr for zero centroid)
63
+ * @param tmp_code Unpacked per-dimension total codes: (sign << ex_bits) + ex-bit code
64
+ * @param d Dimensionality
65
+ * @param ex_bits Number of extra bits
66
+ * @param norm L2 norm of residual
67
+ * @param ipnorm Unnormalized inner product
68
+ * @param ex_factors Output factors structure
69
+ * @param metric_type Distance metric (L2 or IP)
70
+ */
71
+ void compute_ex_factors(
72
+ const float* residual,
73
+ const float* centroid,
74
+ const int* tmp_code,
75
+ size_t d,
76
+ size_t ex_bits,
77
+ float norm,
78
+ double ipnorm,
79
+ rabitq_utils::ExtraBitsFactors& ex_factors,
80
+ MetricType metric_type);
81
+
82
+ /**
83
+ * Main quantization function: quantize residual vector to ex-bits.
84
+ *
85
+ * Performs the complete multi-bit quantization pipeline:
86
+ * 1. Normalize residual
87
+ * 2. Take absolute value
88
+ * 3. Find optimal scaling factor
89
+ * 4. Quantize to ex_bits
90
+ * 5. Handle negative dimensions by bit flipping
91
+ * 6. Pack codes into byte array
92
+ * 7. Compute factors for distance computation
93
+ *
94
+ * @param residual Input residual vector (data - centroid), length d
95
+ * @param d Dimensionality
96
+ * @param nb_bits Number of bits per dimension (2-9)
97
+ * @param ex_code Output packed ex-bit codes
98
+ * @param ex_factors Output ex-bits factors
99
+ * @param metric_type Distance metric (L2 or Inner Product)
100
+ * @param centroid Optional centroid vector (needed for IP metric)
101
+ */
102
+ void quantize_ex_bits(
103
+ const float* residual,
104
+ size_t d,
105
+ size_t nb_bits,
106
+ uint8_t* ex_code,
107
+ rabitq_utils::ExtraBitsFactors& ex_factors,
108
+ MetricType metric_type,
109
+ const float* centroid = nullptr);
110
+
111
+ } // namespace rabitq_multibit
112
+ } // namespace faiss
@@ -49,7 +49,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
49
49
  * first element of the beam (faster but less accurate) */
50
50
  static const int Train_top_beam = 1024;
51
51
 
52
- /** set this bit to *not* autmatically compute the codebook tables
52
+ /** set this bit to *not* automatically compute the codebook tables
53
53
  * after training */
54
54
  static const int Skip_codebook_tables = 2048;
55
55
 
@@ -26,11 +26,11 @@ namespace faiss {
26
26
  * The classes below are intended to be used as template arguments
27
27
  * they handle results for batches of queries (size nq).
28
28
  * They can be called in two ways:
29
- * - by instanciating a SingleResultHandler that tracks results for a single
29
+ * - by instantiating a SingleResultHandler that tracks results for a single
30
30
  * query
31
31
  * - with begin_multiple/add_results/end_multiple calls where a whole block of
32
32
  * results is submitted
33
- * All classes are templated on C which to define wheter the min or the max of
33
+ * All classes are templated on C which to define whether the min or the max of
34
34
  * results is to be kept, and on sel, so that the codepaths for with / without
35
35
  * selector can be separated at compile time.
36
36
  *****************************************************************/
@@ -306,7 +306,7 @@ struct HeapBlockResultHandler : TopkBlockResultHandler<C, use_sel> {
306
306
  *
307
307
  * A reservoir is a result array of size capacity > n (number of requested
308
308
  * results) all results below a threshold are stored in an arbitrary order.
309
- *When the capacity is reached, a new threshold is chosen by partitionning
309
+ *When the capacity is reached, a new threshold is chosen by partitioning
310
310
  *the distance array.
311
311
  *****************************************************************/
312
312
 
@@ -572,7 +572,7 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
572
572
  RangeSearchPartialResult* pres;
573
573
  // there is one RangeSearchPartialResult structure per j0
574
574
  // (= block of columns of the large distance matrix)
575
- // it is a bit tricky to find the poper PartialResult structure
575
+ // it is a bit tricky to find the proper PartialResult structure
576
576
  // because the inner loop is on db not on queries.
577
577
 
578
578
  if (pr < j0s.size() && j0 == j0s[pr]) {
@@ -321,7 +321,7 @@ struct Codec6bit {
321
321
  static FAISS_ALWAYS_INLINE __m256
322
322
  decode_8_components(const uint8_t* code, int i) {
323
323
  // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
324
- // // for the reference, maybe, it becomes used oned day.
324
+ // // for the reference, maybe, it becomes used one day.
325
325
  // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
326
326
  // const uint32_t* data32 = (const uint32_t*)data16;
327
327
  // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
@@ -1009,16 +1009,13 @@ void train_Uniform(
1009
1009
  } else if (rs == ScalarQuantizer::RS_quantiles) {
1010
1010
  std::vector<float> x_copy(n);
1011
1011
  memcpy(x_copy.data(), x, n * sizeof(*x));
1012
- // TODO just do a quickselect
1013
- std::sort(x_copy.begin(), x_copy.end());
1014
- int o = int(rs_arg * n);
1015
- if (o < 0) {
1016
- o = 0;
1017
- }
1018
- if (o > n - o) {
1019
- o = n / 2;
1020
- }
1012
+ int temp = int(rs_arg * n);
1013
+ int o = temp < 0 ? 0 : (temp > n / 2 ? n / 2 : temp);
1014
+
1015
+ std::nth_element(x_copy.begin(), x_copy.begin() + o, x_copy.end());
1021
1016
  vmin = x_copy[o];
1017
+ std::nth_element(
1018
+ x_copy.begin(), x_copy.begin() + (n - 1 - o), x_copy.end());
1022
1019
  vmax = x_copy[n - 1 - o];
1023
1020
 
1024
1021
  } else if (rs == ScalarQuantizer::RS_optim) {
@@ -40,7 +40,7 @@ struct ScalarQuantizer : Quantizer {
40
40
  QuantizerType qtype = QT_8bit;
41
41
 
42
42
  /** The uniform encoder can estimate the range of representable
43
- * values of the unform encoder using different statistics. Here
43
+ * values of the uniform encoder using different statistics. Here
44
44
  * rs = rangestat_arg */
45
45
 
46
46
  // rangestat_arg.
@@ -98,9 +98,7 @@ struct ScalarQuantizer : Quantizer {
98
98
  SQuantizer* select_quantizer() const;
99
99
 
100
100
  struct SQDistanceComputer : FlatCodesDistanceComputer {
101
- const float* q;
102
-
103
- SQDistanceComputer() : q(nullptr) {}
101
+ SQDistanceComputer() : FlatCodesDistanceComputer(nullptr) {}
104
102
 
105
103
  virtual float query_to_code(const uint8_t* code) const = 0;
106
104
 
@@ -5,6 +5,8 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
+ #pragma once
9
+
8
10
  #include <faiss/impl/FaissAssert.h>
9
11
  #include <exception>
10
12
  #include <iostream>
@@ -75,10 +77,11 @@ void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
75
77
  }
76
78
  }
77
79
 
78
- indices_.emplace_back(std::make_pair(
79
- index,
80
- std::unique_ptr<WorkerThread>(
81
- isThreaded_ ? new WorkerThread : nullptr)));
80
+ indices_.emplace_back(
81
+ std::make_pair(
82
+ index,
83
+ std::unique_ptr<WorkerThread>(
84
+ isThreaded_ ? new WorkerThread : nullptr)));
82
85
 
83
86
  onAfterAddIndex(index);
84
87
  }