faiss 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +8 -0
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/faiss/IVFlib.cpp +25 -49
  7. data/vendor/faiss/faiss/Index.cpp +11 -0
  8. data/vendor/faiss/faiss/Index.h +24 -1
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  10. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFastScan.h +3 -8
  13. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  14. data/vendor/faiss/faiss/IndexFlat.h +80 -0
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
  16. data/vendor/faiss/faiss/IndexHNSW.h +57 -1
  17. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
  18. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
  19. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
  20. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
  21. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
  22. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  23. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  24. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  25. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
  26. data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
  27. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
  28. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
  29. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  30. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  31. data/vendor/faiss/faiss/clone_index.cpp +2 -0
  32. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  33. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
  34. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  35. data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
  36. data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
  37. data/vendor/faiss/faiss/impl/HNSW.h +31 -2
  38. data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
  39. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  40. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  41. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  42. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  43. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
  44. data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
  45. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
  46. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
  47. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  48. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
  50. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
  51. data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
  52. data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
  53. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  54. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  55. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +182 -15
  57. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  58. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
  61. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  62. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  63. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  64. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  65. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  66. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  67. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  68. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  69. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  70. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  71. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  72. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  73. data/vendor/faiss/faiss/utils/utils.cpp +4 -0
  74. metadata +18 -1
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
11
  #include <faiss/impl/RaBitQUtils.h>
12
+ #include <faiss/impl/RaBitQuantizerMultiBit.h>
12
13
  #include <faiss/utils/distances.h>
13
14
  #include <faiss/utils/rabitq_simd.h>
14
15
  #include <algorithm>
@@ -20,15 +21,47 @@
20
21
  namespace faiss {
21
22
 
22
23
  // Import shared utilities from RaBitQUtils
23
- using rabitq_utils::FactorsData;
24
+ using rabitq_utils::ExtraBitsFactors;
24
25
  using rabitq_utils::QueryFactorsData;
25
-
26
- static size_t get_code_size(const size_t d) {
27
- return (d + 7) / 8 + sizeof(FactorsData);
26
+ using rabitq_utils::SignBitFactors;
27
+ using rabitq_utils::SignBitFactorsWithError;
28
+
29
+ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
30
+ : Quantizer(d, 0), // code_size will be set below
31
+ metric_type{metric},
32
+ nb_bits{nb_bits} {
33
+ // Validate nb_bits range
34
+ FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
35
+
36
+ // Set code_size using compute_code_size
37
+ code_size = compute_code_size(d, nb_bits);
28
38
  }
29
39
 
30
- RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric)
31
- : Quantizer(d, get_code_size(d)), metric_type{metric} {}
40
+ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
41
+ // Validate inputs
42
+ FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
43
+
44
+ size_t ex_bits = num_bits - 1;
45
+
46
+ // Base: 1-bit codes + base factors
47
+ // Layout for 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
48
+ // base_factors = or_minus_c_l2sqr (4) + dp_multiplier (4)
49
+ // Layout for multi-bit: [binary_code: (d+7)/8
50
+ // bytes][SignBitFactorsWithError: 12 bytes]
51
+ // factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
52
+ size_t base_size = (d + 7) / 8 +
53
+ (ex_bits == 0 ? sizeof(SignBitFactors)
54
+ : sizeof(SignBitFactorsWithError));
55
+
56
+ // Extra: ex-bit codes + ex factors (only if ex_bits > 0)
57
+ // Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
58
+ size_t ex_size = 0;
59
+ if (ex_bits > 0) {
60
+ ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
61
+ }
62
+
63
+ return base_size + ex_size;
64
+ }
32
65
 
33
66
  void RaBitQuantizer::train(size_t n, const float* x) {
34
67
  // does nothing
@@ -54,23 +87,49 @@ void RaBitQuantizer::compute_codes_core(
54
87
  return;
55
88
  }
56
89
 
57
- // compute codes
90
+ const size_t ex_bits = nb_bits - 1;
91
+
92
+ // Compute codes
58
93
  #pragma omp parallel for if (n > 1000)
59
94
  for (int64_t i = 0; i < n; i++) {
60
- // the code
95
+ // Pointer to this vector's code
61
96
  uint8_t* code = codes + i * code_size;
62
- FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8);
63
97
 
64
- // cleanup it
65
- if (code != nullptr) {
66
- memset(code, 0, code_size);
67
- }
98
+ // Clear code memory
99
+ memset(code, 0, code_size);
68
100
 
69
101
  const float* x_row = x + i * d;
70
102
 
103
+ // Pointer arithmetic for code layout:
104
+ // For 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
105
+ // For multi-bit: [binary_code: (d+7)/8 bytes][SignBitFactorsWithError:
106
+ // 12 bytes]
107
+ // [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
108
+ uint8_t* binary_code = code;
109
+
110
+ // Step 1: Compute 1-bit quantization and base factors
111
+ // Store residual for potential ex-bits quantization
112
+ std::vector<float> residual(d);
113
+
71
114
  // Use shared utilities for computing factors
72
- *fac = rabitq_utils::compute_vector_factors(
73
- x_row, d, centroid_in, metric_type);
115
+ SignBitFactorsWithError factors_data =
116
+ rabitq_utils::compute_vector_factors(
117
+ x_row, d, centroid_in, metric_type, ex_bits > 0);
118
+
119
+ // Write appropriate factors based on nb_bits
120
+ if (ex_bits == 0) {
121
+ // For 1-bit: write only SignBitFactors (8 bytes)
122
+ SignBitFactors* base_factors =
123
+ reinterpret_cast<SignBitFactors*>(code + (d + 7) / 8);
124
+ base_factors->or_minus_c_l2sqr = factors_data.or_minus_c_l2sqr;
125
+ base_factors->dp_multiplier = factors_data.dp_multiplier;
126
+ } else {
127
+ // For multi-bit: write full SignBitFactorsWithError (12 bytes)
128
+ SignBitFactorsWithError* full_factors =
129
+ reinterpret_cast<SignBitFactorsWithError*>(
130
+ code + (d + 7) / 8);
131
+ *full_factors = factors_data;
132
+ }
74
133
 
75
134
  // Pack bits into standard RaBitQ format
76
135
  for (size_t j = 0; j < d; j++) {
@@ -78,13 +137,35 @@ void RaBitQuantizer::compute_codes_core(
78
137
  const float centroid_val =
79
138
  (centroid_in == nullptr) ? 0.0f : centroid_in[j];
80
139
  const float or_minus_c = x_val - centroid_val;
140
+ residual[j] = or_minus_c;
141
+
81
142
  const bool xb = (or_minus_c > 0.0f);
82
143
 
83
- // store the output data
84
- if (code != nullptr && xb) {
85
- rabitq_utils::set_bit_standard(code, j);
144
+ // Store the 1-bit sign code
145
+ if (xb) {
146
+ rabitq_utils::set_bit_standard(binary_code, j);
86
147
  }
87
148
  }
149
+
150
+ // Step 2: Compute ex-bits quantization (if nb_bits > 1)
151
+ if (ex_bits > 0) {
152
+ // Pointer to ex-bit code section
153
+ uint8_t* ex_code =
154
+ code + (d + 7) / 8 + sizeof(SignBitFactorsWithError);
155
+ // Pointer to ex-factors section
156
+ ExtraBitsFactors* ex_factors = reinterpret_cast<ExtraBitsFactors*>(
157
+ ex_code + (d * ex_bits + 7) / 8);
158
+
159
+ // Quantize residual to ex-bits (pass centroid for IP metric)
160
+ rabitq_multibit::quantize_ex_bits(
161
+ residual.data(),
162
+ d,
163
+ nb_bits,
164
+ ex_code,
165
+ *ex_factors,
166
+ metric_type,
167
+ centroid_in);
168
+ }
88
169
  }
89
170
  }
90
171
 
@@ -101,6 +182,7 @@ void RaBitQuantizer::decode_core(
101
182
  FAISS_ASSERT(x != nullptr);
102
183
 
103
184
  const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
185
+ const size_t ex_bits = nb_bits - 1;
104
186
 
105
187
  #pragma omp parallel for if (n > 1000)
106
188
  for (int64_t i = 0; i < n; i++) {
@@ -108,10 +190,19 @@ void RaBitQuantizer::decode_core(
108
190
 
109
191
  // split the code into parts
110
192
  const uint8_t* binary_data = code;
111
- const FactorsData* fac =
112
- reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
113
193
 
194
+ // Cast to appropriate type based on nb_bits
195
+ // For 1-bit: use SignBitFactors (8 bytes)
196
+ // For multi-bit: use SignBitFactorsWithError (12 bytes, but only first
197
+ // 8 bytes used for decode)
198
+ const SignBitFactors* fac = (ex_bits == 0)
199
+ ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
200
+ : reinterpret_cast<const SignBitFactorsWithError*>(
201
+ code + (d + 7) / 8);
202
+
203
+ // this is the baseline code
114
204
  //
205
+ // compute <q,o> using floats
115
206
  for (size_t j = 0; j < d; j++) {
116
207
  // extract i-th bit
117
208
  const uint8_t masker = (1 << (j % 8));
@@ -124,51 +215,69 @@ void RaBitQuantizer::decode_core(
124
215
  }
125
216
  }
126
217
 
127
- struct RaBitDistanceComputer : FlatCodesDistanceComputer {
128
- // dimensionality
129
- size_t d = 0;
130
- // a centroid to use
131
- const float* centroid = nullptr;
218
+ // Implementation of RaBitQDistanceComputer (declared in header)
132
219
 
133
- // the metric
134
- MetricType metric_type = MetricType::METRIC_L2;
220
+ float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
221
+ FAISS_ASSERT(code != nullptr);
135
222
 
136
- RaBitDistanceComputer();
223
+ // Compute estimated distance using 1-bit codes
224
+ float est_distance = distance_to_code_1bit(code);
137
225
 
138
- float symmetric_dis(idx_t i, idx_t j) override;
139
- };
226
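+ // Extract f_error from the code. NOTE: only multi-bit (nb_bits > 1) codes
+ // store SignBitFactorsWithError; 1-bit codes carry no f_error field.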
227
+ size_t size = (d + 7) / 8;
228
+ const SignBitFactorsWithError* base_fac =
229
+ reinterpret_cast<const SignBitFactorsWithError*>(code + size);
230
+ float f_error = base_fac->f_error;
140
231
 
141
- RaBitDistanceComputer::RaBitDistanceComputer() = default;
232
+ // Compute proper lower bound using RaBitQ error formula:
233
+ // lower_bound = est_distance - f_error * g_error
234
+ // This guarantees: lower_bound ≤ true_distance
235
+ float lower_bound = est_distance - (f_error * g_error);
142
236
 
143
- float RaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) {
144
- FAISS_THROW_MSG("Not implemented");
237
+ // Distance cannot be negative
238
+ return std::max(0.0f, lower_bound);
145
239
  }
146
240
 
147
- struct RaBitDistanceComputerNotQ : RaBitDistanceComputer {
241
+ namespace {
242
+
243
+ struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
148
244
  // the rotated query (qr - c)
149
245
  std::vector<float> rotated_q;
150
246
  // some additional numbers for the query
151
247
  QueryFactorsData query_fac;
152
248
 
153
- RaBitDistanceComputerNotQ();
249
+ RaBitQDistanceComputerNotQ();
154
250
 
155
- float distance_to_code(const uint8_t* code) override;
251
+ // Compute distance using only 1-bit codes (fast)
252
+ float distance_to_code_1bit(const uint8_t* code) override;
253
+
254
+ // Compute full distance using 1-bit + ex-bits (accurate)
255
+ float distance_to_code_full(const uint8_t* code) override;
156
256
 
157
257
  void set_query(const float* x) override;
158
258
  };
159
259
 
160
- RaBitDistanceComputerNotQ::RaBitDistanceComputerNotQ() = default;
260
+ RaBitQDistanceComputerNotQ::RaBitQDistanceComputerNotQ() = default;
161
261
 
162
- float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
262
+ float RaBitQDistanceComputerNotQ::distance_to_code_1bit(const uint8_t* code) {
163
263
  FAISS_ASSERT(code != nullptr);
164
264
  FAISS_ASSERT(
165
265
  (metric_type == MetricType::METRIC_L2 ||
166
266
  metric_type == MetricType::METRIC_INNER_PRODUCT));
267
+ FAISS_ASSERT(rotated_q.size() == d);
167
268
 
168
269
  // split the code into parts
169
270
  const uint8_t* binary_data = code;
170
- const FactorsData* fac =
171
- reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
271
+
272
+ // Cast to appropriate type based on nb_bits
273
+ // For 1-bit: use SignBitFactors (8 bytes)
274
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
275
+ // f_error
276
+ size_t ex_bits = nb_bits - 1;
277
+ const SignBitFactors* base_fac = (ex_bits == 0)
278
+ ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
279
+ : reinterpret_cast<const SignBitFactorsWithError*>(
280
+ code + (d + 7) / 8);
172
281
 
173
282
  // this is the baseline code
174
283
  //
@@ -177,48 +286,70 @@ float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
177
286
  // It was a willful decision (after the discussion) to not to pre-cache
178
287
  // the sum of all bits, just in order to reduce the overhead per vector.
179
288
  uint64_t sum_q = 0;
180
- for (size_t i = 0; i < d; i++) {
181
- // extract i-th bit
182
- const uint8_t masker = (1 << (i % 8));
183
- const bool b_bit = ((binary_data[i / 8] & masker) == masker);
184
289
 
290
+ for (size_t i = 0; i < d; i++) {
291
+ // Extract i-th bit
292
+ bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
185
293
  // accumulate dp
186
- dot_qo += (b_bit) ? rotated_q[i] : 0;
294
+ dot_qo += bit ? rotated_q[i] : 0;
187
295
  // accumulate sum-of-bits
188
- sum_q += (b_bit) ? 1 : 0;
296
+ sum_q += bit ? 1 : 0;
189
297
  }
190
298
 
191
- float final_dot = 0;
192
- // dot-product itself
193
- final_dot += query_fac.c1 * dot_qo;
194
- // normalizer coefficients
195
- final_dot += query_fac.c2 * sum_q;
196
- // normalizer coefficients
197
- final_dot -= query_fac.c34;
198
-
199
- // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
200
- const float or_c_l2sqr = fac->or_minus_c_l2sqr;
299
+ // Apply query factors
300
+ float final_dot =
301
+ query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
201
302
 
202
303
  // pre_dist = ||or - c||^2 + ||qr - c||^2 -
203
304
  // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
204
- const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
205
- 2 * fac->dp_multiplier * final_dot;
305
+ float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
306
+ 2 * base_fac->dp_multiplier * final_dot;
206
307
 
207
308
  if (metric_type == MetricType::METRIC_L2) {
208
309
  // ||or - q||^ 2
209
310
  return pre_dist;
210
311
  } else {
211
312
  // metric == MetricType::METRIC_INNER_PRODUCT
313
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
314
+ }
315
+ }
212
316
 
213
- // this is ||q||^2
214
- const float query_norm_sqr = query_fac.qr_norm_L2sqr;
317
+ float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
318
+ FAISS_ASSERT(code != nullptr);
319
+ FAISS_ASSERT(
320
+ (metric_type == MetricType::METRIC_L2 ||
321
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
322
+ FAISS_ASSERT(rotated_q.size() == d);
215
323
 
216
- // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
217
- return -0.5f * (pre_dist - query_norm_sqr);
324
+ size_t ex_bits = nb_bits - 1;
325
+
326
+ if (ex_bits == 0) {
327
+ // No ex-bits, just return 1-bit distance
328
+ return distance_to_code_1bit(code);
218
329
  }
330
+
331
+ // Extract pointers to code sections
332
+ const uint8_t* binary_data = code;
333
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
334
+ const uint8_t* ex_code = code + offset;
335
+ const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
336
+ ex_code + (d * ex_bits + 7) / 8);
337
+
338
+ // Call shared utility directly with rotated_q pointer
339
+ return rabitq_utils::compute_full_multibit_distance(
340
+ binary_data,
341
+ ex_code,
342
+ *ex_fac,
343
+ rotated_q.data(),
344
+ query_fac.qr_to_c_L2sqr,
345
+ query_fac.qr_norm_L2sqr,
346
+ d,
347
+ ex_bits,
348
+ metric_type);
219
349
  }
220
350
 
221
- void RaBitDistanceComputerNotQ::set_query(const float* x) {
351
+ void RaBitQDistanceComputerNotQ::set_query(const float* x) {
352
+ q = x;
222
353
  FAISS_ASSERT(x != nullptr);
223
354
  FAISS_ASSERT(
224
355
  (metric_type == MetricType::METRIC_L2 ||
@@ -237,6 +368,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
237
368
  rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
238
369
  }
239
370
 
371
+ // Compute g_error (query norm for lower bound computation)
372
+ // g_error = ||qr - c|| (L2 norm of rotated query)
373
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
374
+
240
375
  // compute some numbers
241
376
  const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
242
377
 
@@ -257,8 +392,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
257
392
  }
258
393
 
259
394
  //
260
- struct RaBitDistanceComputerQ : RaBitDistanceComputer {
395
+ struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
261
396
  // the rotated, unquantized query (qr - c); input to distance_to_code_full()
397
+ std::vector<float> rotated_q;
398
+ // the rotated and quantized query (qr - c) for fast 1-bit computation
262
399
  std::vector<uint8_t> rotated_qq;
263
400
  // we're using the proposed relayout-ed scheme from 3.3 that allows
264
401
  // using popcounts for computing the distance.
@@ -272,16 +409,20 @@ struct RaBitDistanceComputerQ : RaBitDistanceComputer {
272
409
  // the smallest value divisible by 8 that is not smaller than dim
273
410
  size_t popcount_aligned_dim = 0;
274
411
 
275
- RaBitDistanceComputerQ();
412
+ RaBitQDistanceComputerQ();
276
413
 
277
- float distance_to_code(const uint8_t* code) override;
414
+ // Compute distance using only 1-bit codes (fast)
415
+ float distance_to_code_1bit(const uint8_t* code) override;
416
+
417
+ // Compute full distance using 1-bit + ex-bits (accurate)
418
+ float distance_to_code_full(const uint8_t* code) override;
278
419
 
279
420
  void set_query(const float* x) override;
280
421
  };
281
422
 
282
- RaBitDistanceComputerQ::RaBitDistanceComputerQ() = default;
423
+ RaBitQDistanceComputerQ::RaBitQDistanceComputerQ() = default;
283
424
 
284
- float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
425
+ float RaBitQDistanceComputerQ::distance_to_code_1bit(const uint8_t* code) {
285
426
  FAISS_ASSERT(code != nullptr);
286
427
  FAISS_ASSERT(
287
428
  (metric_type == MetricType::METRIC_L2 ||
@@ -290,21 +431,28 @@ float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
290
431
  // split the code into parts
291
432
  size_t size = (d + 7) / 8;
292
433
  const uint8_t* binary_data = code;
293
- const FactorsData* fac = reinterpret_cast<const FactorsData*>(code + size);
434
+
435
+ // Cast to appropriate type based on nb_bits
436
+ // For 1-bit: use SignBitFactors (8 bytes)
437
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
438
+ // f_error
439
+ size_t ex_bits = nb_bits - 1;
440
+ const SignBitFactors* base_fac = (ex_bits == 0)
441
+ ? reinterpret_cast<const SignBitFactors*>(code + size)
442
+ : reinterpret_cast<const SignBitFactorsWithError*>(code + size);
294
443
 
295
444
  // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
296
445
  float final_dot = 0;
297
446
  if (centered) {
298
447
  int64_t int_dot = ((1 << qb) - 1) * d;
448
+ // See RaBitQDistanceComputerNotQ::distance_to_code_1bit() for baseline code.
299
449
  int_dot -= 2 *
300
450
  rabitq::bitwise_xor_dot_product(
301
451
  rearranged_rotated_qq.data(), binary_data, size, qb);
302
452
  final_dot += int_dot * query_fac.int_dot_scale;
303
453
  } else {
304
- // See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
305
454
  auto dot_qo = rabitq::bitwise_and_dot_product(
306
455
  rearranged_rotated_qq.data(), binary_data, size, qb);
307
-
308
456
  // It was a willful decision (after the discussion) to not to pre-cache
309
457
  // the sum of all bits, just in order to reduce the overhead per vector.
310
458
  // process 64-bit popcounts
@@ -317,32 +465,60 @@ float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
317
465
  final_dot -= query_fac.c34;
318
466
  }
319
467
 
320
- // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
321
- const float or_c_l2sqr = fac->or_minus_c_l2sqr;
322
-
323
468
  // pre_dist = ||or - c||^2 + ||qr - c||^2 -
324
469
  // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
325
- const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
326
- 2 * fac->dp_multiplier * final_dot;
470
+ const float pre_dist = base_fac->or_minus_c_l2sqr +
471
+ query_fac.qr_to_c_L2sqr - 2 * base_fac->dp_multiplier * final_dot;
327
472
 
328
473
  if (metric_type == MetricType::METRIC_L2) {
329
474
  // ||or - q||^ 2
330
475
  return pre_dist;
331
476
  } else {
332
477
  // metric == MetricType::METRIC_INNER_PRODUCT
478
+ // -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
479
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
480
+ }
481
+ }
482
+
483
+ float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
484
+ FAISS_ASSERT(code != nullptr);
485
+ FAISS_ASSERT(
486
+ (metric_type == MetricType::METRIC_L2 ||
487
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
488
+ FAISS_ASSERT(rotated_q.size() == d);
333
489
 
334
- // this is ||q||^2
335
- const float query_norm_sqr = query_fac.qr_norm_L2sqr;
490
+ size_t ex_bits = nb_bits - 1;
336
491
 
337
- // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
338
- return -0.5f * (pre_dist - query_norm_sqr);
492
+ if (ex_bits == 0) {
493
+ // No ex-bits, just return 1-bit distance
494
+ return distance_to_code_1bit(code);
339
495
  }
496
+
497
+ // Extract pointers to code sections
498
+ const uint8_t* binary_data = code;
499
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
500
+ const uint8_t* ex_code = code + offset;
501
+ const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
502
+ ex_code + (d * ex_bits + 7) / 8);
503
+
504
+ // Call shared utility directly with rotated_q pointer
505
+ return rabitq_utils::compute_full_multibit_distance(
506
+ binary_data,
507
+ ex_code,
508
+ *ex_fac,
509
+ rotated_q.data(),
510
+ query_fac.qr_to_c_L2sqr,
511
+ query_fac.qr_norm_L2sqr,
512
+ d,
513
+ ex_bits,
514
+ metric_type);
340
515
  }
341
516
 
342
517
  // Use shared constant from RaBitQUtils
343
518
  using rabitq_utils::Z_MAX_BY_QB;
344
519
 
345
- void RaBitDistanceComputerQ::set_query(const float* x) {
520
+ void RaBitQDistanceComputerQ::set_query(const float* x) {
521
+ q = x;
346
522
  FAISS_ASSERT(x != nullptr);
347
523
  FAISS_ASSERT(
348
524
  (metric_type == MetricType::METRIC_L2 ||
@@ -351,10 +527,15 @@ void RaBitDistanceComputerQ::set_query(const float* x) {
351
527
  FAISS_THROW_IF_NOT(qb > 0);
352
528
 
353
529
  // Use shared utilities for core query factor computation
354
- std::vector<float> rotated_q;
530
+ // rotated_q is populated directly by compute_query_factors as an output
531
+ // parameter
355
532
  query_fac = rabitq_utils::compute_query_factors(
356
533
  x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
357
534
 
535
+ // Compute g_error (query norm for lower bound computation)
536
+ // g_error = ||qr - c|| (L2 norm of rotated query)
537
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
538
+
358
539
  // Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
359
540
  popcount_aligned_dim = ((d + 7) / 8) * 8;
360
541
  size_t offset = (d + 7) / 8;
@@ -371,24 +552,28 @@ void RaBitDistanceComputerQ::set_query(const float* x) {
371
552
  }
372
553
  }
373
554
 
555
+ } // anonymous namespace
556
+
374
557
  FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
375
558
  uint8_t qb,
376
559
  const float* centroid_in,
377
560
  bool centered) const {
378
561
  if (qb == 0) {
379
- auto dc = std::make_unique<RaBitDistanceComputerNotQ>();
562
+ auto dc = std::make_unique<RaBitQDistanceComputerNotQ>();
380
563
  dc->metric_type = metric_type;
381
564
  dc->d = d;
382
565
  dc->centroid = centroid_in;
566
+ dc->nb_bits = nb_bits;
383
567
 
384
568
  return dc.release();
385
569
  } else {
386
- auto dc = std::make_unique<RaBitDistanceComputerQ>();
570
+ auto dc = std::make_unique<RaBitQDistanceComputerQ>();
387
571
  dc->metric_type = metric_type;
388
572
  dc->d = d;
389
573
  dc->centroid = centroid_in;
390
574
  dc->qb = qb;
391
575
  dc->centered = centered;
576
+ dc->nb_bits = nb_bits;
392
577
 
393
578
  return dc.release();
394
579
  }
@@ -37,11 +37,28 @@ struct RaBitQuantizer : Quantizer {
37
37
  // possible. Thus, a quantizer has to introduce a metric.
38
38
  MetricType metric_type = MetricType::METRIC_L2;
39
39
 
40
- RaBitQuantizer(size_t d = 0, MetricType metric = MetricType::METRIC_L2);
40
+ // Number of bits per dimension (1-9). Default is 1 for backward
41
+ // compatibility.
42
+ // - nb_bits = 1: standard 1-bit RaBitQ (sign bits only)
43
+ // - nb_bits = 2-9: multi-bit RaBitQ (1 sign bit + ex_bits extra bits)
44
+ size_t nb_bits = 1;
45
+
46
+ RaBitQuantizer(
47
+ size_t d = 0,
48
+ MetricType metric = MetricType::METRIC_L2,
49
+ size_t nb_bits = 1);
50
+
51
+ // Compute code size based on dimensionality and number of bits
52
+ // Returns: size in bytes for one encoded vector
53
+ // - nb_bits=1: (d+7)/8 + 8 bytes (1-bit codes + base factors)
54
+ // - nb_bits>1: (d+7)/8 + 12 + (d*ex_bits+7)/8 + 8 bytes, ex_bits = nb_bits-1
55
+ // (1-bit codes + base factors with f_error + ex-bit codes + ex factors)
56
+ size_t compute_code_size(size_t d, size_t num_bits) const;
41
57
 
42
58
  void train(size_t n, const float* x) override;
43
59
 
44
- // every vector is expected to take (d + 7) / 8 + sizeof(FactorsData) bytes,
60
+ // every vector is expected to take code_size bytes; for nb_bits == 1 that is
61
+ // (d + 7) / 8 + sizeof(SignBitFactors), see compute_code_size() otherwise,
45
62
  void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
46
63
 
47
64
  void compute_codes_core(
@@ -71,9 +88,59 @@ struct RaBitQuantizer : Quantizer {
71
88
  // specify qb = 0 to get an DC that does not quantize a query
72
89
  // specify qb > 0 to have SQ qb-bits query
73
90
  FlatCodesDistanceComputer* get_distance_computer(
74
- uint8_t qb,
75
- const float* centroid_in = nullptr,
91
+ uint8_t qb = 0,
92
+ const float* centroid = nullptr,
76
93
  bool centered = false) const;
77
94
  };
78
95
 
96
+ // RaBitQDistanceComputer: Base class for RaBitQ distance computers
97
+ //
98
+ // This intermediate class exists to provide a unified interface for
99
+ // two-stage multi-bit search. While most Faiss quantizers extend
100
+ // FlatCodesDistanceComputer directly, RaBitQ requires this additional
101
+ // abstraction layer due to its unique split encoding strategy
102
+ // (1 sign bit + magnitude bits) which enables:
103
+ //
104
+ // 1. distance_to_code_1bit() - Fast 1-bit filtering using only sign bits
105
+ // 2. distance_to_code_full() - Accurate multi-bit refinement using all bits
106
+ // 3. lower_bound_distance() - Error-bounded adaptive filtering
107
+ // (based on 1-bit estimator)
108
+ //
109
+ // These three methods implement RaBitQ's two-stage search pattern and are
110
+ // shared between the quantized (Q) and non-quantized (NotQ) query variants.
111
+ // The intermediate class allows two-stage search code to work with both
112
+ // variants via a single dynamic_cast.
113
+ struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
114
+ size_t d = 0;
115
+ const float* centroid = nullptr;
116
+ MetricType metric_type = MetricType::METRIC_L2;
117
+ size_t nb_bits = 1;
118
+
119
+ // Query norm for lower bound computation (g_error in rabitq-library)
120
+ // This is the L2 norm of the rotated query: ||query - centroid||
121
+ float g_error = 0.0f;
122
+
123
+ float symmetric_dis(idx_t /*i*/, idx_t /*j*/) override {
124
+ // Not used for RaBitQ
125
+ FAISS_THROW_MSG("Not implemented");
126
+ }
127
+
128
+ // Compute 1-bit distance estimate (fast)
129
+ virtual float distance_to_code_1bit(const uint8_t* code) = 0;
130
+
131
+ // Compute full multi-bit distance (accurate)
132
+ virtual float distance_to_code_full(const uint8_t* code) = 0;
133
+
134
+ // Compute lower bound of distance using error bounds
135
+ // Guarantees: actual_distance >= lower_bound_distance
136
+ // Used for adaptive filtering in two-stage search
137
+ virtual float lower_bound_distance(const uint8_t* code);
138
+
139
+ // Override from FlatCodesDistanceComputer
140
+ // Delegates to distance_to_code_full() for multi-bit distance computation
141
+ float distance_to_code(const uint8_t* code) final {
142
+ return distance_to_code_full(code);
143
+ }
144
+ };
145
+
79
146
  } // namespace faiss