faiss 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +8 -0
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/faiss/IVFlib.cpp +25 -49
  7. data/vendor/faiss/faiss/Index.cpp +11 -0
  8. data/vendor/faiss/faiss/Index.h +24 -1
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  10. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFastScan.h +3 -8
  13. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  14. data/vendor/faiss/faiss/IndexFlat.h +80 -0
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
  16. data/vendor/faiss/faiss/IndexHNSW.h +57 -1
  17. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
  18. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
  19. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
  20. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
  21. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
  22. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  23. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  24. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  25. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
  26. data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
  27. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
  28. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
  29. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  30. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  31. data/vendor/faiss/faiss/clone_index.cpp +2 -0
  32. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  33. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
  34. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  35. data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
  36. data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
  37. data/vendor/faiss/faiss/impl/HNSW.h +31 -2
  38. data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
  39. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  40. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  41. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  42. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  43. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
  44. data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
  45. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
  46. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
  47. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  48. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
  50. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
  51. data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
  52. data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
  53. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  54. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  55. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +182 -15
  57. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  58. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
  61. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  62. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  63. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  64. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  65. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  66. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  67. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  68. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  69. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  70. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  71. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  72. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  73. data/vendor/faiss/faiss/utils/utils.cpp +4 -0
  74. metadata +18 -1
@@ -13,6 +13,7 @@
13
13
  #include <faiss/impl/FaissAssert.h>
14
14
  #include <faiss/impl/FastScanDistancePostProcessing.h>
15
15
  #include <faiss/impl/RaBitQUtils.h>
16
+ #include <faiss/impl/RaBitQuantizerMultiBit.h>
16
17
  #include <faiss/impl/pq4_fast_scan.h>
17
18
  #include <faiss/impl/simd_result_handlers.h>
18
19
  #include <faiss/invlists/BlockInvertedLists.h>
@@ -22,8 +23,10 @@
22
23
  namespace faiss {
23
24
 
24
25
  // Import shared utilities from RaBitQUtils
25
- using rabitq_utils::FactorsData;
26
+ using rabitq_utils::ExtraBitsFactors;
26
27
  using rabitq_utils::QueryFactorsData;
28
+ using rabitq_utils::SignBitFactors;
29
+ using rabitq_utils::SignBitFactorsWithError;
27
30
 
28
31
  inline size_t roundup(size_t a, size_t b) {
29
32
  return (a + b - 1) / b * b;
@@ -41,9 +44,10 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
41
44
  size_t nlist,
42
45
  MetricType metric,
43
46
  int bbs,
44
- bool own_invlists)
47
+ bool own_invlists,
48
+ uint8_t nb_bits)
45
49
  : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
46
- rabitq(d, metric) {
50
+ rabitq(d, metric, nb_bits) {
47
51
  FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
48
52
  FAISS_THROW_IF_NOT_MSG(
49
53
  metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
@@ -66,9 +70,9 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
66
70
  this->ksub = (1 << nbits_fastscan);
67
71
  this->M2 = roundup(M_fastscan, 2);
68
72
 
69
- // Override code_size to include space for factors after bit patterns
73
+ // Compute code_size: bit_pattern + per-vector storage (factors/ex-codes)
70
74
  const size_t bit_pattern_size = (d + 7) / 8;
71
- this->code_size = bit_pattern_size + sizeof(FactorsData);
75
+ this->code_size = bit_pattern_size + compute_per_vector_storage_size();
72
76
 
73
77
  is_trained = false;
74
78
 
@@ -76,7 +80,7 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
76
80
  replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
77
81
  }
78
82
 
79
- factors_storage.clear();
83
+ flat_storage.clear();
80
84
  }
81
85
 
82
86
  // Constructor that converts an existing IndexIVFRaBitQ to FastScan format
@@ -92,20 +96,35 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
92
96
  false),
93
97
  rabitq(orig.rabitq) {}
94
98
 
99
+ size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
100
+ const size_t ex_bits = rabitq.nb_bits - 1;
101
+
102
+ if (ex_bits == 0) {
103
+ // 1-bit: only SignBitFactors (8 bytes)
104
+ return sizeof(SignBitFactors);
105
+ } else {
106
+ // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
107
+ return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
108
+ (d * ex_bits + 7) / 8;
109
+ }
110
+ }
111
+
95
112
  void IndexIVFRaBitQFastScan::preprocess_code_metadata(
96
113
  idx_t n,
97
114
  const uint8_t* flat_codes,
98
115
  idx_t start_global_idx) {
99
- // Extract and store factors from codes for use during search
100
- const size_t bit_pattern_size = (d + 7) / 8;
101
- factors_storage.resize(start_global_idx + n);
116
+ // Unified approach: always use flat_storage for both 1-bit and multi-bit
117
+ const size_t storage_size = compute_per_vector_storage_size();
118
+ flat_storage.resize((start_global_idx + n) * storage_size);
102
119
 
120
+ // Copy factors data directly to flat storage (no reordering needed)
121
+ const size_t bit_pattern_size = (d + 7) / 8;
103
122
  for (idx_t i = 0; i < n; i++) {
104
123
  const uint8_t* code = flat_codes + i * code_size;
105
- const uint8_t* factors_ptr = code + bit_pattern_size;
106
- const FactorsData& embedded_factors =
107
- *reinterpret_cast<const FactorsData*>(factors_ptr);
108
- factors_storage[start_global_idx + i] = embedded_factors;
124
+ const uint8_t* source_factors_ptr = code + bit_pattern_size;
125
+ uint8_t* storage =
126
+ flat_storage.data() + (start_global_idx + i) * storage_size;
127
+ memcpy(storage, source_factors_ptr, storage_size);
109
128
  }
110
129
  }
111
130
 
@@ -143,7 +162,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
143
162
  size_t total_code_size = code_size + coarse_size;
144
163
  memset(codes, 0, total_code_size * n);
145
164
 
146
- const size_t bit_pattern_size = (d + 7) / 8;
165
+ const size_t ex_bits = rabitq.nb_bits - 1;
147
166
 
148
167
  #pragma omp parallel if (n > 1000)
149
168
  {
@@ -161,16 +180,61 @@ void IndexIVFRaBitQFastScan::encode_vectors(
161
180
  // Reconstruct centroid for residual computation
162
181
  quantizer->reconstruct(list_no, centroid.data());
163
182
 
164
- // Encode vector to FastScan format (bit pattern only)
165
- encode_vector_to_fastscan(xi, centroid.data(), fastscan_code);
183
+ const size_t bit_pattern_size = (d + 7) / 8;
166
184
 
167
- // Compute and embed factors after the bit pattern
168
- // Pass original vector and centroid (same as old add_with_ids)
169
- FactorsData factors = rabitq_utils::compute_vector_factors(
170
- xi, d, centroid.data(), rabitq.metric_type);
185
+ // Pack sign bits directly into FastScan format (inline)
186
+ for (size_t j = 0; j < d; j++) {
187
+ const float or_minus_c = xi[j] - centroid[j];
188
+ if (or_minus_c > 0.0f) {
189
+ rabitq_utils::set_bit_fastscan(fastscan_code, j);
190
+ }
191
+ }
192
+
193
+ // Compute factors (with or without f_error depending on mode)
194
+ SignBitFactorsWithError factors =
195
+ rabitq_utils::compute_vector_factors(
196
+ xi,
197
+ d,
198
+ centroid.data(),
199
+ rabitq.metric_type,
200
+ ex_bits > 0);
201
+
202
+ if (ex_bits == 0) {
203
+ // 1-bit: store only SignBitFactors (8 bytes)
204
+ memcpy(fastscan_code + bit_pattern_size,
205
+ &factors,
206
+ sizeof(SignBitFactors));
207
+ } else {
208
+ // Multi-bit: store full SignBitFactorsWithError (12 bytes)
209
+ memcpy(fastscan_code + bit_pattern_size,
210
+ &factors,
211
+ sizeof(SignBitFactorsWithError));
212
+
213
+ // Compute residual (needed for quantize_ex_bits)
214
+ std::vector<float> residual(d);
215
+ for (size_t j = 0; j < d; j++) {
216
+ residual[j] = xi[j] - centroid[j];
217
+ }
171
218
 
172
- uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
173
- *reinterpret_cast<FactorsData*>(factors_ptr) = factors;
219
+ // Quantize ex-bits
220
+ const size_t ex_code_size = (d * ex_bits + 7) / 8;
221
+ uint8_t* ex_code = fastscan_code + bit_pattern_size +
222
+ sizeof(SignBitFactorsWithError);
223
+ ExtraBitsFactors ex_factors_temp;
224
+
225
+ rabitq_multibit::quantize_ex_bits(
226
+ residual.data(),
227
+ d,
228
+ rabitq.nb_bits,
229
+ ex_code,
230
+ ex_factors_temp,
231
+ rabitq.metric_type,
232
+ centroid.data());
233
+
234
+ memcpy(ex_code + ex_code_size,
235
+ &ex_factors_temp,
236
+ sizeof(ExtraBitsFactors));
237
+ }
174
238
 
175
239
  // Include coarse codes if requested
176
240
  if (include_listnos) {
@@ -181,24 +245,6 @@ void IndexIVFRaBitQFastScan::encode_vectors(
181
245
  }
182
246
  }
183
247
 
184
- void IndexIVFRaBitQFastScan::encode_vector_to_fastscan(
185
- const float* xi,
186
- const float* centroid,
187
- uint8_t* fastscan_code) const {
188
- memset(fastscan_code, 0, code_size);
189
-
190
- for (size_t j = 0; j < d; j++) {
191
- const float x_val = xi[j];
192
- const float centroid_val = (centroid != nullptr) ? centroid[j] : 0.0f;
193
- const float or_minus_c = x_val - centroid_val;
194
- const bool xb = (or_minus_c > 0.0f);
195
-
196
- if (xb) {
197
- rabitq_utils::set_bit_fastscan(fastscan_code, j);
198
- }
199
- }
200
- }
201
-
202
248
  bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
203
249
  return true;
204
250
  }
@@ -231,6 +277,11 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
231
277
  query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
232
278
  }
233
279
 
280
+ const size_t ex_bits = rabitq.nb_bits - 1;
281
+ if (ex_bits > 0) {
282
+ query_factors.rotated_q = rotated_q;
283
+ }
284
+
234
285
  if (centered) {
235
286
  const float max_code_value = (1 << qb) - 1;
236
287
 
@@ -352,7 +403,7 @@ void IndexIVFRaBitQFastScan::compute_LUT(
352
403
  x + i * d);
353
404
 
354
405
  // Store query factors using compact indexing (ij directly)
355
- if (context.query_factors) {
406
+ if (context.query_factors != nullptr) {
356
407
  context.query_factors[ij] = query_factors_data;
357
408
  }
358
409
 
@@ -367,52 +418,56 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
367
418
  int64_t list_no,
368
419
  int64_t offset,
369
420
  float* recons) const {
370
- // Unpack codes from packed format
371
- size_t coarse_size = coarse_code_size();
421
+ // Get centroid for this list
422
+ std::vector<float> centroid(d);
423
+ quantizer->reconstruct(list_no, centroid.data());
424
+
425
+ // Unpack bit pattern from packed format
372
426
  const size_t bit_pattern_size = (d + 7) / 8;
373
- std::vector<uint8_t> code(
374
- coarse_size + bit_pattern_size + sizeof(FactorsData), 0);
427
+ std::vector<uint8_t> fastscan_code(bit_pattern_size, 0);
375
428
 
376
- encode_listno(list_no, code.data());
377
429
  InvertedLists::ScopedCodes list_codes(invlists, list_no);
378
-
379
- // Unpack the bit pattern from packed format to FastScan layout
380
- uint8_t* fastscan_code = code.data() + coarse_size;
381
430
  for (size_t m = 0; m < M; m++) {
382
431
  uint8_t c =
383
432
  pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
384
433
 
385
- // Write the 4-bit code value to FastScan format
386
- // Each byte stores two 4-bit codes (lower and upper nibbles)
387
434
  size_t byte_idx = m / 2;
388
435
  if (m % 2 == 0) {
389
- // Even m: write to lower 4 bits
390
436
  fastscan_code[byte_idx] =
391
437
  (fastscan_code[byte_idx] & 0xF0) | (c & 0x0F);
392
438
  } else {
393
- // Odd m: write to upper 4 bits
394
439
  fastscan_code[byte_idx] =
395
440
  (fastscan_code[byte_idx] & 0x0F) | ((c & 0x0F) << 4);
396
441
  }
397
442
  }
398
443
 
399
- // Get the global index to retrieve factors
400
- // Need to look up the ID from inverted lists
444
+ // Get dp_multiplier directly from flat_storage
401
445
  InvertedLists::ScopedIds list_ids(invlists, list_no);
402
446
  idx_t global_id = list_ids[offset];
403
447
 
404
- // Get factors from factors_storage using global ID
405
- if (global_id >= 0 &&
406
- static_cast<size_t>(global_id) < factors_storage.size()) {
407
- const FactorsData& factors = factors_storage[global_id];
408
-
409
- // Embed factors into the unpacked code
410
- uint8_t* factors_ptr = code.data() + coarse_size + bit_pattern_size;
411
- *reinterpret_cast<FactorsData*>(factors_ptr) = factors;
448
+ float dp_multiplier = 1.0f;
449
+ if (global_id >= 0) {
450
+ const size_t storage_size = compute_per_vector_storage_size();
451
+ const size_t storage_capacity = flat_storage.size() / storage_size;
452
+
453
+ if (static_cast<size_t>(global_id) < storage_capacity) {
454
+ const uint8_t* base_ptr =
455
+ flat_storage.data() + global_id * storage_size;
456
+ const auto& base_factors =
457
+ *reinterpret_cast<const SignBitFactors*>(base_ptr);
458
+ dp_multiplier = base_factors.dp_multiplier;
459
+ }
412
460
  }
413
461
 
414
- // Now use sa_decode which expects unpacked codes with embedded factors
415
- sa_decode(1, code.data(), recons);
462
+ // Decode residual directly using dp_multiplier
463
+ std::vector<float> residual(d);
464
+ decode_fastscan_to_residual(
465
+ fastscan_code.data(), residual.data(), dp_multiplier);
466
+
467
+ // Reconstruct: x = centroid + residual
468
+ for (size_t j = 0; j < d; j++) {
469
+ recons[j] = centroid[j] + residual[j];
470
+ }
416
471
  }
417
472
 
418
473
  void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
@@ -426,6 +481,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
426
481
  size_t total_code_size = code_size + coarse_size;
427
482
  std::vector<float> centroid(d);
428
483
  std::vector<float> residual(d);
484
+ const size_t bit_pattern_size = (d + 7) / 8;
429
485
 
430
486
  #pragma omp parallel for if (n > 1000)
431
487
  for (idx_t i = 0; i < n; i++) {
@@ -439,7 +495,12 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
439
495
 
440
496
  const uint8_t* fastscan_code = code_i + coarse_size;
441
497
 
442
- decode_fastscan_to_residual(fastscan_code, residual.data());
498
+ const uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
499
+ const auto& base_factors =
500
+ *reinterpret_cast<const SignBitFactors*>(factors_ptr);
501
+
502
+ decode_fastscan_to_residual(
503
+ fastscan_code, residual.data(), base_factors.dp_multiplier);
443
504
 
444
505
  for (size_t j = 0; j < d; j++) {
445
506
  x_i[j] = centroid[j] + residual[j];
@@ -452,23 +513,17 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
452
513
 
453
514
  void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
454
515
  const uint8_t* fastscan_code,
455
- float* residual) const {
516
+ float* residual,
517
+ float dp_multiplier) const {
456
518
  memset(residual, 0, sizeof(float) * d);
457
519
 
458
520
  const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
459
- const size_t bit_pattern_size = (d + 7) / 8;
460
-
461
- // Extract factors directly from embedded codes
462
- const uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
463
- const FactorsData& fac = *reinterpret_cast<const FactorsData*>(factors_ptr);
464
521
 
465
522
  for (size_t j = 0; j < d; j++) {
466
- // Use RaBitQUtils for consistent bit extraction
467
523
  bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
468
524
 
469
525
  float bit_as_float = bit_value ? 1.0f : 0.0f;
470
- residual[j] =
471
- (bit_as_float - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt;
526
+ residual[j] = (bit_as_float - 0.5f) * dp_multiplier * 2 * inv_d_sqrt;
472
527
  }
473
528
  }
474
529
 
@@ -483,12 +538,15 @@ SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
483
538
  const IDSelector* /* sel */,
484
539
  const FastScanDistancePostProcessing& context,
485
540
  const float* /* normalizers */) const {
541
+ const size_t ex_bits = rabitq.nb_bits - 1;
542
+ const bool is_multibit = ex_bits > 0;
543
+
486
544
  if (is_max) {
487
545
  return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
488
- this, n, k, distances, labels, &context);
546
+ this, n, k, distances, labels, &context, is_multibit);
489
547
  } else {
490
548
  return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
491
- this, n, k, distances, labels, &context);
549
+ this, n, k, distances, labels, &context, is_multibit);
492
550
  }
493
551
  }
494
552
 
@@ -503,7 +561,8 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
503
561
  size_t k_val,
504
562
  float* distances,
505
563
  int64_t* labels,
506
- const FastScanDistancePostProcessing* ctx)
564
+ const FastScanDistancePostProcessing* ctx,
565
+ bool multibit)
507
566
  : simd_result_handlers::ResultHandlerCompare<C, true>(
508
567
  nq_val,
509
568
  0,
@@ -513,7 +572,8 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
513
572
  heap_labels(labels),
514
573
  nq(nq_val),
515
574
  k(k_val),
516
- context(ctx) {
575
+ context(ctx),
576
+ is_multibit(multibit) {
517
577
  current_list_no = 0;
518
578
  probe_indices.clear();
519
579
 
@@ -572,8 +632,15 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
572
632
  }
573
633
 
574
634
  size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
635
+
636
+ // Stats tracking for two-stage search
637
+ // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
638
+ // n_multibit_evaluations: candidates requiring full multi-bit distance
639
+ size_t local_1bit_evaluations = 0;
640
+ size_t local_multibit_evaluations = 0;
641
+
575
642
  // Process each candidate vector in the SIMD batch
576
- for (int j = 0; j < static_cast<int>(max_positions); j++) {
643
+ for (size_t j = 0; j < max_positions; j++) {
577
644
  const int64_t result_id = this->adjust_id(b, j);
578
645
 
579
646
  if (result_id < 0) {
@@ -582,39 +649,81 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
582
649
 
583
650
  const float normalized_distance = d32tab[j] * one_a + bias;
584
651
 
585
- // Get database factors using global index (factors are stored by global
586
- // index)
587
- const auto& db_factors = index->factors_storage[result_id];
588
- float adjusted_distance;
589
-
590
- // Distance computation depends on quantization mode
591
- if (index->centered) {
592
- int64_t int_dot = ((1 << index->qb) - 1) * index->d;
593
- int_dot -= 2 * static_cast<int64_t>(normalized_distance);
594
-
595
- adjusted_distance = query_factors.qr_to_c_L2sqr +
596
- db_factors.or_minus_c_l2sqr -
597
- 2 * db_factors.dp_multiplier * int_dot *
598
- query_factors.int_dot_scale;
599
-
652
+ // Get database factors from flat_storage
653
+ const size_t storage_size = index->compute_per_vector_storage_size();
654
+ const uint8_t* base_ptr =
655
+ index->flat_storage.data() + result_id * storage_size;
656
+
657
+ if (is_multibit) {
658
+ // Track candidates actually considered for two-stage filtering
659
+ local_1bit_evaluations++;
660
+
661
+ // Multi-bit: use SignBitFactorsWithError and two-stage search
662
+ const SignBitFactorsWithError& full_factors =
663
+ *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
664
+
665
+ // Compute 1-bit adjusted distance using shared helper
666
+ float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
667
+ normalized_distance,
668
+ full_factors,
669
+ query_factors,
670
+ index->centered,
671
+ index->qb,
672
+ index->d);
673
+
674
+ // Compute lower bound using error bound
675
+ float lower_bound =
676
+ compute_lower_bound(dist_1bit, result_id, local_q, q);
677
+
678
+ // Adaptive filtering: decide whether to compute full distance
679
+ const bool is_similarity =
680
+ index->metric_type == MetricType::METRIC_INNER_PRODUCT;
681
+ bool should_refine = is_similarity
682
+ ? (lower_bound > heap_dis[0]) // IP: keep if better
683
+ : (lower_bound < heap_dis[0]); // L2: keep if better
684
+
685
+ if (should_refine) {
686
+ local_multibit_evaluations++;
687
+
688
+ // Compute local_offset: position within current inverted list
689
+ size_t local_offset = this->j0 + b * 32 + j;
690
+
691
+ // Compute full multi-bit distance
692
+ float dist_full = compute_full_multibit_distance(
693
+ result_id, local_q, q, local_offset);
694
+
695
+ // Update heap if this distance is better
696
+ if (Cfloat::cmp(heap_dis[0], dist_full)) {
697
+ heap_replace_top<Cfloat>(
698
+ k, heap_dis, heap_ids, dist_full, result_id);
699
+ }
700
+ }
600
701
  } else {
601
- float final_dot = normalized_distance - query_factors.c34;
602
- adjusted_distance = db_factors.or_minus_c_l2sqr +
603
- query_factors.qr_to_c_L2sqr -
604
- 2 * db_factors.dp_multiplier * final_dot;
605
- }
606
-
607
- // Convert L2 to inner product if needed
608
- if (query_factors.qr_norm_L2sqr != 0.0f) {
609
- adjusted_distance =
610
- -0.5f * (adjusted_distance - query_factors.qr_norm_L2sqr);
611
- }
612
-
613
- if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
614
- heap_replace_top<Cfloat>(
615
- k, heap_dis, heap_ids, adjusted_distance, result_id);
702
+ const auto& db_factors =
703
+ *reinterpret_cast<const SignBitFactors*>(base_ptr);
704
+
705
+ // Compute adjusted distance using shared helper
706
+ float adjusted_distance =
707
+ rabitq_utils::compute_1bit_adjusted_distance(
708
+ normalized_distance,
709
+ db_factors,
710
+ query_factors,
711
+ index->centered,
712
+ index->qb,
713
+ index->d);
714
+
715
+ if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
716
+ heap_replace_top<Cfloat>(
717
+ k, heap_dis, heap_ids, adjusted_distance, result_id);
718
+ }
616
719
  }
617
720
  }
721
+
722
+ // Update global stats atomically
723
+ #pragma omp atomic
724
+ rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
725
+ #pragma omp atomic
726
+ rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
618
727
  }
619
728
 
620
729
  template <class C>
@@ -641,10 +750,79 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
641
750
  }
642
751
  }
643
752
 
644
- // Explicit template instantiations
645
- template struct IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<
646
- CMin<uint16_t, int64_t>>;
647
- template struct IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<
648
- CMax<uint16_t, int64_t>>;
753
+ template <class C>
754
+ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::compute_lower_bound(
755
+ float dist_1bit,
756
+ size_t db_idx,
757
+ size_t local_q,
758
+ size_t global_q) const {
759
+ // Access f_error from SignBitFactorsWithError in flat storage
760
+ const size_t storage_size = index->compute_per_vector_storage_size();
761
+ const uint8_t* base_ptr =
762
+ index->flat_storage.data() + db_idx * storage_size;
763
+ const SignBitFactorsWithError& db_factors =
764
+ *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
765
+ float f_error = db_factors.f_error;
766
+
767
+ // Get g_error from query factors
768
+ // Use local_q to access probe_indices (batch-local), global_q for storage
769
+ float g_error = 0.0f;
770
+ if (context && context->query_factors) {
771
+ size_t probe_rank = probe_indices[local_q];
772
+ size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
773
+ size_t storage_idx = global_q * nprobe + probe_rank;
774
+ g_error = context->query_factors[storage_idx].g_error;
775
+ }
776
+
777
+ // Compute error adjustment: f_error * g_error
778
+ float error_adjustment = f_error * g_error;
779
+
780
+ return dist_1bit - error_adjustment;
781
+ }
782
+
783
+ template <class C>
784
+ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
785
+ compute_full_multibit_distance(
786
+ size_t db_idx,
787
+ size_t local_q,
788
+ size_t global_q,
789
+ size_t local_offset) const {
790
+ const size_t ex_bits = index->rabitq.nb_bits - 1;
791
+ const size_t dim = index->d;
792
+
793
+ const size_t storage_size = index->compute_per_vector_storage_size();
794
+ const uint8_t* base_ptr =
795
+ index->flat_storage.data() + db_idx * storage_size;
796
+
797
+ const size_t ex_code_size = (dim * ex_bits + 7) / 8;
798
+ const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
799
+ const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
800
+ base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
801
+
802
+ // Use local_q to access probe_indices (batch-local), global_q for storage
803
+ size_t probe_rank = probe_indices[local_q];
804
+ size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
805
+ size_t storage_idx = global_q * nprobe + probe_rank;
806
+ const auto& query_factors = context->query_factors[storage_idx];
807
+
808
+ size_t list_no = current_list_no;
809
+ InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
810
+
811
+ std::vector<uint8_t> unpacked_code(index->code_size);
812
+ CodePackerPQ4 packer(index->M2, index->bbs);
813
+ packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
814
+ const uint8_t* sign_bits = unpacked_code.data();
815
+
816
+ return rabitq_utils::compute_full_multibit_distance(
817
+ sign_bits,
818
+ ex_code,
819
+ ex_fac,
820
+ query_factors.rotated_q.data(),
821
+ query_factors.qr_to_c_L2sqr,
822
+ query_factors.qr_norm_L2sqr,
823
+ dim,
824
+ ex_bits,
825
+ index->metric_type);
826
+ }
649
827
 
650
828
  } // namespace faiss