faiss 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +8 -0
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/faiss/IVFlib.cpp +25 -49
  7. data/vendor/faiss/faiss/Index.cpp +11 -0
  8. data/vendor/faiss/faiss/Index.h +24 -1
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  10. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFastScan.h +3 -8
  13. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  14. data/vendor/faiss/faiss/IndexFlat.h +80 -0
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
  16. data/vendor/faiss/faiss/IndexHNSW.h +57 -1
  17. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
  18. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
  19. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
  20. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
  21. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
  22. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  23. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  24. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  25. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
  26. data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
  27. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
  28. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
  29. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  30. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  31. data/vendor/faiss/faiss/clone_index.cpp +2 -0
  32. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  33. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
  34. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  35. data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
  36. data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
  37. data/vendor/faiss/faiss/impl/HNSW.h +31 -2
  38. data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
  39. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  40. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  41. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  42. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  43. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
  44. data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
  45. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
  46. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
  47. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  48. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
  50. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
  51. data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
  52. data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
  53. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  54. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  55. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +182 -15
  57. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  58. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
  61. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  62. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  63. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  64. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  65. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  66. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  67. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  68. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  69. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  70. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  71. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  72. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  73. data/vendor/faiss/faiss/utils/utils.cpp +4 -0
  74. metadata +18 -1
@@ -0,0 +1,362 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // NOTE: Parts of this implementation are adapted from:
9
+ // RaBitQ-Library/include/rabitqlib/quantization/rabitq_impl.hpp
10
+ // https://github.com/VectorDB-NTU/RaBitQ-Library
11
+
12
+ #include <faiss/impl/FaissAssert.h>
13
+ #include <faiss/impl/RaBitQUtils.h>
14
+ #include <faiss/utils/distances.h>
15
+
16
+ #include <algorithm>
17
+ #include <cmath>
18
+ #include <cstring>
19
+ #include <queue>
20
+ #include <vector>
21
+
22
+ namespace faiss {
23
+ namespace rabitq_multibit {
24
+
25
+ using rabitq_utils::ExtraBitsFactors;
26
+ using rabitq_utils::SignBitFactorsWithError;
27
+
28
+ constexpr float kTightStart[9] =
29
+ {0.0f, 0.15f, 0.20f, 0.52f, 0.59f, 0.71f, 0.75f, 0.77f, 0.81f};
30
+
31
+ constexpr double kEps = 1e-5;
32
+
33
+ /**
34
+ * Compute optimal scaling factor for ex-bits quantization using priority
35
+ * queue-based search.
36
+ *
37
+ * This function finds the optimal scaling factor 't' that maximizes the
38
+ * inner product between the normalized quantized vector and the normalized
39
+ * absolute residual. The algorithm uses a priority queue to efficiently
40
+ * explore different quantization levels.
41
+ *
42
+ *
43
+ * @param o_abs Normalized absolute residual vector (must be positive, length
44
+ * d)
45
+ * @param d Dimensionality of the vector
46
+ * @param nb_bits Number of bits per dimension (2-9)
47
+ * @return Optimal scaling factor 't'
48
+ */
49
+ float compute_optimal_scaling_factor(
50
+ const float* o_abs,
51
+ size_t d,
52
+ size_t nb_bits) {
53
+ const size_t ex_bits = nb_bits - 1;
54
+ FAISS_THROW_IF_NOT_MSG(
55
+ ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
56
+
57
+ const int kNEnum = 10;
58
+ const int max_code = (1 << ex_bits) - 1;
59
+
60
+ float max_o = *std::max_element(o_abs, o_abs + d);
61
+
62
+ // Determine search range [t_start, t_end]
63
+ float t_end = static_cast<float>(max_code + kNEnum) / max_o;
64
+ float t_start = t_end * kTightStart[ex_bits];
65
+
66
+ std::vector<float> inv_o_abs(d);
67
+ for (size_t i = 0; i < d; ++i) {
68
+ inv_o_abs[i] = 1.0f / o_abs[i];
69
+ }
70
+
71
+ std::vector<int> cur_o_bar(d);
72
+ float sqr_denominator = static_cast<float>(d) * 0.25f;
73
+ float numerator = 0.0f;
74
+
75
+ for (size_t i = 0; i < d; ++i) {
76
+ int cur = static_cast<int>((t_start * o_abs[i]) + kEps);
77
+ cur_o_bar[i] = cur;
78
+ sqr_denominator += static_cast<float>(cur * cur + cur);
79
+ numerator += (cur + 0.5f) * o_abs[i];
80
+ }
81
+
82
+ float inv_sqrt_denom = 1.0f / std::sqrt(sqr_denominator);
83
+
84
+ // Pair: (next_t, dimension_index)
85
+ // Maximum size is d (one entry per dimension), so reserve exactly d
86
+ std::vector<std::pair<float, size_t>> pq_storage;
87
+ pq_storage.reserve(d);
88
+ std::priority_queue<
89
+ std::pair<float, size_t>,
90
+ std::vector<std::pair<float, size_t>>,
91
+ std::greater<>>
92
+ next_t(std::greater<>(), std::move(pq_storage));
93
+
94
+ // Initialize queue with next quantization level for each dimension
95
+ for (size_t i = 0; i < d; ++i) {
96
+ float t_next = static_cast<float>(cur_o_bar[i] + 1) * inv_o_abs[i];
97
+ if (t_next < t_end) {
98
+ next_t.emplace(t_next, i);
99
+ }
100
+ }
101
+
102
+ float max_ip = 0.0f;
103
+ float t = 0.0f;
104
+
105
+ while (!next_t.empty()) {
106
+ float cur_t = next_t.top().first;
107
+ size_t update_id = next_t.top().second;
108
+ next_t.pop();
109
+
110
+ cur_o_bar[update_id]++;
111
+ int update_o_bar = cur_o_bar[update_id];
112
+
113
+ float delta = 2.0f * update_o_bar;
114
+ sqr_denominator += delta;
115
+ numerator += o_abs[update_id];
116
+
117
+ float old_denom = sqr_denominator - delta;
118
+ inv_sqrt_denom = inv_sqrt_denom *
119
+ (1.0f - 0.5f * delta / (old_denom + delta * 0.5f));
120
+
121
+ float cur_ip = numerator * inv_sqrt_denom;
122
+
123
+ if (cur_ip > max_ip) {
124
+ max_ip = cur_ip;
125
+ t = cur_t;
126
+ }
127
+
128
+ if (update_o_bar < max_code) {
129
+ float t_next =
130
+ static_cast<float>(update_o_bar + 1) * inv_o_abs[update_id];
131
+ if (t_next < t_end) {
132
+ next_t.emplace(t_next, update_id);
133
+ }
134
+ }
135
+ }
136
+
137
+ return t;
138
+ }
139
+
140
+ /**
141
+ * Pack multi-bit codes from integer array to byte array.
142
+ *
143
+ * @param tmp_code Integer codes (length d), each value in [0, 2^ex_bits - 1]
144
+ * @param ex_code Output packed byte array
145
+ * @param d Dimensionality
146
+ * @param nb_bits Number of bits per dimension (2-9)
147
+ */
148
+ void pack_multibit_codes(
149
+ const int* tmp_code,
150
+ uint8_t* ex_code,
151
+ size_t d,
152
+ size_t nb_bits) {
153
+ const size_t ex_bits = nb_bits - 1;
154
+ FAISS_THROW_IF_NOT_MSG(
155
+ ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
156
+
157
+ size_t total_bits = d * ex_bits;
158
+ size_t output_size = (total_bits + 7) / 8;
159
+ memset(ex_code, 0, output_size);
160
+
161
+ size_t bit_pos = 0;
162
+ for (size_t i = 0; i < d; i++) {
163
+ int code_value = tmp_code[i];
164
+
165
+ for (size_t bit = 0; bit < ex_bits; bit++) {
166
+ size_t byte_idx = bit_pos / 8;
167
+ size_t bit_idx = bit_pos % 8;
168
+
169
+ if (code_value & (1 << bit)) {
170
+ ex_code[byte_idx] |= (1 << bit_idx);
171
+ }
172
+
173
+ bit_pos++;
174
+ }
175
+ }
176
+ }
177
+
178
+ /**
179
+ * Compute ex-bits factors for distance computation.
180
+ *
181
+ * @param residual Original residual vector (data - centroid)
182
+ * @param centroid Centroid vector (can be nullptr for zero centroid)
183
+ * @param tmp_code Quantized ex-bit codes (before packing, after bit flipping)
184
+ * @param d Dimensionality
185
+ * @param ex_bits Number of extra bits
186
+ * @param norm L2 norm of residual
187
+ * @param ipnorm Unnormalized inner product between quantized and normalized
188
+ * residual
189
+ * @param ex_factors Output factors structure
190
+ * @param metric_type Distance metric (L2 or Inner Product)
191
+ */
192
+ void compute_ex_factors(
193
+ const float* residual,
194
+ const float* centroid,
195
+ const int* tmp_code,
196
+ size_t d,
197
+ size_t ex_bits,
198
+ float norm,
199
+ double ipnorm,
200
+ ExtraBitsFactors& ex_factors,
201
+ MetricType metric_type) {
202
+ FAISS_THROW_IF_NOT_MSG(
203
+ metric_type == MetricType::METRIC_L2 ||
204
+ metric_type == MetricType::METRIC_INNER_PRODUCT,
205
+ "Unsupported metric type");
206
+
207
+ // Compute ipnorm_inv = 1 / ipnorm
208
+ float ipnorm_inv = static_cast<float>(1.0 / ipnorm);
209
+ if (!std::isnormal(ipnorm_inv)) {
210
+ ipnorm_inv = 1.0f;
211
+ }
212
+
213
+ // Reconstruct xu_cb from total_code
214
+ // total_code was formed from: total_code[i] = (sign << ex_bits) +
215
+ // ex_code[i] Reconstruction: xu_cb[i] = total_code[i] + cb
216
+ const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
217
+ std::vector<float> xu_cb(d);
218
+ for (size_t i = 0; i < d; i++) {
219
+ xu_cb[i] = static_cast<float>(tmp_code[i]) + cb;
220
+ }
221
+
222
+ // Compute inner products needed for factors
223
+ float l2_sqr = norm * norm;
224
+ float ip_resi_xucb = fvec_inner_product(residual, xu_cb.data(), d);
225
+
226
+ // Compute factors
227
+ if (metric_type == MetricType::METRIC_L2) {
228
+ // For L2, no centroid correction needed in IVF setting
229
+ // because residual = x - centroid, distance computed in residual space
230
+ ex_factors.f_add_ex = l2_sqr;
231
+ ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
232
+ } else {
233
+ // For IP, centroid correction is needed
234
+ float ip_resi_cent = 0;
235
+ if (centroid != nullptr) {
236
+ ip_resi_cent = fvec_inner_product(residual, centroid, d);
237
+ }
238
+
239
+ float ip_cent_xucb = 0;
240
+ if (centroid != nullptr) {
241
+ ip_cent_xucb = fvec_inner_product(centroid, xu_cb.data(), d);
242
+ }
243
+
244
+ // When ip_resi_xucb is zero, the correction term should be zero
245
+ float correction_term = 0.0f;
246
+ if (ip_resi_xucb != 0.0f) {
247
+ correction_term = l2_sqr * ip_cent_xucb / ip_resi_xucb;
248
+ }
249
+
250
+ ex_factors.f_add_ex = 1 - ip_resi_cent + correction_term;
251
+ ex_factors.f_rescale_ex = ipnorm_inv * -norm;
252
+ }
253
+ }
254
+
255
+ /**
256
+ * Quantize residual vector to ex-bits.
257
+ *
258
+ * This is the main quantization function that:
259
+ * 1. Normalizes the residual
260
+ * 2. Takes absolute value
261
+ * 3. Finds optimal scaling factor
262
+ * 4. Quantizes to ex_bits
263
+ * 5. Handles negative dimensions by flipping bits
264
+ * 6. Packs codes into byte array
265
+ * 7. Computes factors for distance computation
266
+ *
267
+ * @param residual Input residual vector (data - centroid), length d
268
+ * @param d Dimensionality
269
+ * @param nb_bits Number of bits per dimension (2-9)
270
+ * @param ex_code Output packed ex-bit codes
271
+ * @param ex_factors Output ex-bits factors
272
+ * @param metric_type Distance metric (L2 or Inner Product)
273
+ * @param centroid Optional centroid vector (needed for IP metric)
274
+ */
275
+ void quantize_ex_bits(
276
+ const float* residual,
277
+ size_t d,
278
+ size_t nb_bits,
279
+ uint8_t* ex_code,
280
+ ExtraBitsFactors& ex_factors,
281
+ MetricType metric_type,
282
+ const float* centroid) {
283
+ const size_t ex_bits = nb_bits - 1;
284
+ FAISS_THROW_IF_NOT_MSG(
285
+ ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
286
+ FAISS_THROW_IF_NOT_MSG(residual != nullptr, "residual cannot be null");
287
+ FAISS_THROW_IF_NOT_MSG(ex_code != nullptr, "ex_code cannot be null");
288
+
289
+ // Step 1: Compute L2 norm of residual
290
+ float norm_sqr = fvec_norm_L2sqr(residual, d);
291
+ float norm = std::sqrt(norm_sqr);
292
+
293
+ // Handle degenerate case
294
+ if (norm < 1e-10f) {
295
+ size_t code_size = (d * ex_bits + 7) / 8;
296
+ memset(ex_code, 0, code_size);
297
+ ex_factors.f_add_ex = 0.0f;
298
+ ex_factors.f_rescale_ex = 1.0f;
299
+ return;
300
+ }
301
+
302
+ // Step 2: Normalize residual
303
+ std::vector<float> normalized_residual(d);
304
+ for (size_t i = 0; i < d; i++) {
305
+ normalized_residual[i] = residual[i] / norm;
306
+ }
307
+
308
+ // Step 3: Take absolute value
309
+ std::vector<float> o_abs(d);
310
+ for (size_t i = 0; i < d; i++) {
311
+ o_abs[i] = std::abs(normalized_residual[i]);
312
+ }
313
+
314
+ // Step 4: Find optimal scaling factor
315
+ float t = compute_optimal_scaling_factor(o_abs.data(), d, nb_bits);
316
+
317
+ // Step 5: Quantize to ex_bits
318
+ std::vector<int> tmp_code(d);
319
+ double ipnorm = 0;
320
+ int max_code = (1 << ex_bits) - 1;
321
+
322
+ for (size_t i = 0; i < d; i++) {
323
+ tmp_code[i] = std::min(static_cast<int>(t * o_abs[i] + kEps), max_code);
324
+ // Compute unnormalized inner product
325
+ ipnorm += (tmp_code[i] + 0.5) * o_abs[i];
326
+ }
327
+
328
+ // Step 6: Handle negative dimensions (flip bits)
329
+ // For negative residuals, flip all bits: code' = ~code & max_code
330
+ for (size_t i = 0; i < d; i++) {
331
+ if (residual[i] < 0) {
332
+ tmp_code[i] = (~tmp_code[i]) & max_code;
333
+ }
334
+ }
335
+
336
+ // Step 7: Pack codes into byte array
337
+ pack_multibit_codes(tmp_code.data(), ex_code, d, nb_bits);
338
+
339
+ // Step 8: Compute factors for distance computation
340
+ // Reconstruct total_code for factor computation
341
+ std::vector<int> total_code(d);
342
+ for (size_t i = 0; i < d; i++) {
343
+ // Form total_code = (sign << ex_bits) + ex_code
344
+ bool sign_bit = (residual[i] >= 0);
345
+ total_code[i] = tmp_code[i] + ((sign_bit ? 1 : 0) << ex_bits);
346
+ }
347
+
348
+ // Compute ex-factors; centroid is needed for IP metric correction
349
+ compute_ex_factors(
350
+ residual,
351
+ centroid, // Pass centroid for IP metric factor computation
352
+ total_code.data(),
353
+ d,
354
+ ex_bits,
355
+ norm,
356
+ ipnorm,
357
+ ex_factors,
358
+ metric_type);
359
+ }
360
+
361
+ } // namespace rabitq_multibit
362
+ } // namespace faiss
@@ -0,0 +1,112 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // Reference:
9
+ // "Practical and asymptotically optimal quantization of high-dimensional
10
+ // vectors in euclidean space for approximate nearest neighbor search"
11
+ // Jianyang Gao, Yutong Gou, Yuexuan Xu, Yongyi Yang, Cheng Long, Raymond
12
+ // Chi-Wing Wong https://dl.acm.org/doi/pdf/10.1145/3725413
13
+ //
14
+ // Reference implementation: https://github.com/VectorDB-NTU/RaBitQ-Library
15
+ // NOTE: Parts of this implementation are adapted from
16
+ // rabitqlib/quantization/rabitq_impl.hpp in the above repository.
17
+
18
+ #pragma once
19
+
20
+ #include <faiss/MetricType.h>
21
+ #include <faiss/impl/RaBitQUtils.h>
22
+ #include <cstddef>
23
+ #include <cstdint>
24
+
25
+ namespace faiss {
26
+ namespace rabitq_multibit {
27
+
28
+ /**
29
+ * Compute optimal scaling factor for ex-bits quantization.
30
+ *
31
+ * Uses priority queue-based search to find the scaling factor that
32
+ * maximizes the inner product between quantized and original vectors.
33
+ *
34
+ * @param o_abs Normalized absolute residual vector (positive values)
35
+ * @param d Dimensionality
36
+ * @param nb_bits Number of bits per dimension (2-9)
37
+ * @return Optimal scaling factor 't'
38
+ */
39
+ float compute_optimal_scaling_factor(
40
+ const float* o_abs,
41
+ size_t d,
42
+ size_t nb_bits);
43
+
44
+ /**
45
+ * Pack multi-bit codes from integer array to byte array.
46
+ *
47
+ * @param tmp_code Integer codes (length d), values in [0, 2^ex_bits - 1]
48
+ * @param ex_code Output packed byte array
49
+ * @param d Dimensionality
50
+ * @param nb_bits Number of bits per dimension (2-9)
51
+ */
52
+ void pack_multibit_codes(
53
+ const int* tmp_code,
54
+ uint8_t* ex_code,
55
+ size_t d,
56
+ size_t nb_bits);
57
+
58
+ /**
59
+ * Compute ex-bits factors for distance computation.
60
+ *
61
+ * @param residual Original residual vector (data - centroid)
62
+ * @param centroid Centroid vector (can be nullptr for zero centroid)
63
+ * @param tmp_code Quantized ex-bit codes (unpacked integers)
64
+ * @param d Dimensionality
65
+ * @param ex_bits Number of extra bits
66
+ * @param norm L2 norm of residual
67
+ * @param ipnorm Unnormalized inner product
68
+ * @param ex_factors Output factors structure
69
+ * @param metric_type Distance metric (L2 or IP)
70
+ */
71
+ void compute_ex_factors(
72
+ const float* residual,
73
+ const float* centroid,
74
+ const int* tmp_code,
75
+ size_t d,
76
+ size_t ex_bits,
77
+ float norm,
78
+ double ipnorm,
79
+ rabitq_utils::ExtraBitsFactors& ex_factors,
80
+ MetricType metric_type);
81
+
82
+ /**
83
+ * Main quantization function: quantize residual vector to ex-bits.
84
+ *
85
+ * Performs the complete multi-bit quantization pipeline:
86
+ * 1. Normalize residual
87
+ * 2. Take absolute value
88
+ * 3. Find optimal scaling factor
89
+ * 4. Quantize to ex_bits
90
+ * 5. Handle negative dimensions by bit flipping
91
+ * 6. Pack codes into byte array
92
+ * 7. Compute factors for distance computation
93
+ *
94
+ * @param residual Input residual vector (data - centroid), length d
95
+ * @param d Dimensionality
96
+ * @param nb_bits Number of bits per dimension (2-9)
97
+ * @param ex_code Output packed ex-bit codes
98
+ * @param ex_factors Output ex-bits factors
99
+ * @param metric_type Distance metric (L2 or Inner Product)
100
+ * @param centroid Optional centroid vector (needed for IP metric)
101
+ */
102
+ void quantize_ex_bits(
103
+ const float* residual,
104
+ size_t d,
105
+ size_t nb_bits,
106
+ uint8_t* ex_code,
107
+ rabitq_utils::ExtraBitsFactors& ex_factors,
108
+ MetricType metric_type,
109
+ const float* centroid = nullptr);
110
+
111
+ } // namespace rabitq_multibit
112
+ } // namespace faiss
@@ -1009,16 +1009,13 @@ void train_Uniform(
1009
1009
  } else if (rs == ScalarQuantizer::RS_quantiles) {
1010
1010
  std::vector<float> x_copy(n);
1011
1011
  memcpy(x_copy.data(), x, n * sizeof(*x));
1012
- // TODO just do a quickselect
1013
- std::sort(x_copy.begin(), x_copy.end());
1014
- int o = int(rs_arg * n);
1015
- if (o < 0) {
1016
- o = 0;
1017
- }
1018
- if (o > n - o) {
1019
- o = n / 2;
1020
- }
1012
+ int temp = int(rs_arg * n);
1013
+ int o = temp < 0 ? 0 : (temp > n / 2 ? n / 2 : temp);
1014
+
1015
+ std::nth_element(x_copy.begin(), x_copy.begin() + o, x_copy.end());
1021
1016
  vmin = x_copy[o];
1017
+ std::nth_element(
1018
+ x_copy.begin(), x_copy.begin() + (n - 1 - o), x_copy.end());
1022
1019
  vmax = x_copy[n - 1 - o];
1023
1020
 
1024
1021
  } else if (rs == ScalarQuantizer::RS_optim) {
@@ -98,9 +98,7 @@ struct ScalarQuantizer : Quantizer {
98
98
  SQuantizer* select_quantizer() const;
99
99
 
100
100
  struct SQDistanceComputer : FlatCodesDistanceComputer {
101
- const float* q;
102
-
103
- SQDistanceComputer() : q(nullptr) {}
101
+ SQDistanceComputer() : FlatCodesDistanceComputer(nullptr) {}
104
102
 
105
103
  virtual float query_to_code(const uint8_t* code) const = 0;
106
104