faiss 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +8 -0
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/faiss/IVFlib.cpp +25 -49
  7. data/vendor/faiss/faiss/Index.cpp +11 -0
  8. data/vendor/faiss/faiss/Index.h +24 -1
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  10. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFastScan.h +3 -8
  13. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  14. data/vendor/faiss/faiss/IndexFlat.h +80 -0
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
  16. data/vendor/faiss/faiss/IndexHNSW.h +57 -1
  17. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
  18. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
  19. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
  20. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
  21. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
  22. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  23. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  24. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  25. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
  26. data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
  27. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
  28. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
  29. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  30. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  31. data/vendor/faiss/faiss/clone_index.cpp +2 -0
  32. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  33. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
  34. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  35. data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
  36. data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
  37. data/vendor/faiss/faiss/impl/HNSW.h +31 -2
  38. data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
  39. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  40. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  41. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  42. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  43. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
  44. data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
  45. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
  46. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
  47. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  48. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
  50. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
  51. data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
  52. data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
  53. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  54. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  55. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +182 -15
  57. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  58. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
  61. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  62. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  63. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  64. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  65. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  66. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  67. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  68. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  69. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  70. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  71. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  72. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  73. data/vendor/faiss/faiss/utils/utils.cpp +4 -0
  74. metadata +18 -1
@@ -8,6 +8,7 @@
8
8
  #include <faiss/IndexRaBitQFastScan.h>
9
9
  #include <faiss/impl/FastScanDistancePostProcessing.h>
10
10
  #include <faiss/impl/RaBitQUtils.h>
11
+ #include <faiss/impl/RaBitQuantizerMultiBit.h>
11
12
  #include <faiss/impl/pq4_fast_scan.h>
12
13
  #include <faiss/utils/utils.h>
13
14
  #include <algorithm>
@@ -19,15 +20,35 @@ static inline size_t roundup(size_t a, size_t b) {
19
20
  return (a + b - 1) / b * b;
20
21
  }
21
22
 
23
+ size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
24
+ const size_t ex_bits = rabitq.nb_bits - 1;
25
+
26
+ if (ex_bits == 0) {
27
+ // 1-bit: only SignBitFactors
28
+ return sizeof(rabitq_utils::SignBitFactors);
29
+ } else {
30
+ // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors +
31
+ // mag-codes
32
+ return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
33
+ (d * ex_bits + 7) / 8;
34
+ }
35
+ }
36
+
22
37
  IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
23
38
 
24
- IndexRaBitQFastScan::IndexRaBitQFastScan(idx_t d, MetricType metric, int bbs)
25
- : rabitq(d, metric) {
39
+ IndexRaBitQFastScan::IndexRaBitQFastScan(
40
+ idx_t d,
41
+ MetricType metric,
42
+ int bbs,
43
+ uint8_t nb_bits)
44
+ : rabitq(d, metric, nb_bits) {
26
45
  // RaBitQ-specific validation
27
46
  FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
28
47
  FAISS_THROW_IF_NOT_MSG(
29
48
  metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
30
49
  "RaBitQ FastScan only supports L2 and Inner Product metrics");
50
+ FAISS_THROW_IF_NOT_MSG(
51
+ nb_bits >= 1 && nb_bits <= 9, "nb_bits must be between 1 and 9");
31
52
 
32
53
  // RaBitQ uses 1 bit per dimension packed into 4-bit FastScan sub-quantizers
33
54
  // Each FastScan sub-quantizer handles 4 RaBitQ dimensions
@@ -37,17 +58,15 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(idx_t d, MetricType metric, int bbs)
37
58
  // init_fastscan will validate bbs % 32 == 0 and nbits_fastscan == 4
38
59
  init_fastscan(static_cast<int>(d), M_fastscan, nbits_fastscan, metric, bbs);
39
60
 
40
- // Override code_size to include space for factors after bit patterns
41
- // RaBitQ stores 1 bit per dimension, requiring (d + 7) / 8 bytes
42
- const size_t bit_pattern_size = (d + 7) / 8;
43
- code_size = bit_pattern_size + sizeof(FactorsData);
61
+ // Compute code_size directly using RaBitQuantizer
62
+ code_size = rabitq.compute_code_size(d, nb_bits);
44
63
 
45
64
  // Set RaBitQ-specific parameters
46
65
  qb = 8;
47
66
  center.resize(d, 0.0f);
48
67
 
49
- // Pre-allocate storage vectors for efficiency
50
- factors_storage.clear();
68
+ // Initialize empty flat storage
69
+ flat_storage.clear();
51
70
  }
52
71
 
53
72
  IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
@@ -72,10 +91,7 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
72
91
  orig.metric_type,
73
92
  bbs);
74
93
 
75
- // Override code_size to include space for factors after bit patterns
76
- // RaBitQ stores 1 bit per dimension, requiring (d + 7) / 8 bytes
77
- const size_t bit_pattern_size = (orig.d + 7) / 8;
78
- code_size = bit_pattern_size + sizeof(FactorsData);
94
+ code_size = rabitq.compute_code_size(d, rabitq.nb_bits);
79
95
 
80
96
  // Copy properties from original index
81
97
  ntotal = orig.ntotal;
@@ -88,23 +104,19 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
88
104
 
89
105
  // If the original index has data, extract factors and pack codes
90
106
  if (ntotal > 0) {
91
- // Allocate space for factors
92
- factors_storage.resize(ntotal);
93
-
94
- // Extract factors from original codes for each vector
95
- const float* centroid_data = center.data();
107
+ // Compute per-vector storage size for flat storage
108
+ const size_t storage_size = compute_per_vector_storage_size();
96
109
 
97
- // Use the original RaBitQ quantizer to decode and compute factors
98
- std::vector<float> decoded_vectors(ntotal * orig.d);
99
- orig.sa_decode(ntotal, orig.codes.data(), decoded_vectors.data());
110
+ // Allocate flat storage
111
+ flat_storage.resize(ntotal * storage_size);
100
112
 
113
+ // Copy factors directly from original codes
114
+ const size_t bit_pattern_size = (d + 7) / 8;
101
115
  for (idx_t i = 0; i < ntotal; i++) {
102
- FactorsData& fac = factors_storage[i];
103
- const float* x_row = decoded_vectors.data() + i * orig.d;
104
-
105
- // Use shared utilities for computing factors
106
- fac = rabitq_utils::compute_vector_factors(
107
- x_row, orig.d, centroid_data, orig.metric_type);
116
+ const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
117
+ const uint8_t* source_factors_ptr = orig_code + bit_pattern_size;
118
+ uint8_t* storage = flat_storage.data() + i * storage_size;
119
+ memcpy(storage, source_factors_ptr, storage_size);
108
120
  }
109
121
 
110
122
  // Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format
@@ -191,15 +203,19 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
191
203
  AlignedTable<uint8_t> tmp_codes(n * code_size);
192
204
  compute_codes(tmp_codes.get(), n, x);
193
205
 
194
- // Extract and store factors from embedded codes for handler access
206
+ const size_t storage_size = compute_per_vector_storage_size();
207
+ flat_storage.resize((ntotal + n) * storage_size);
208
+
209
+ // Populate flat storage (no sign bits copying needed!)
195
210
  const size_t bit_pattern_size = (d + 7) / 8;
196
- factors_storage.resize(ntotal + n);
197
211
  for (idx_t i = 0; i < n; i++) {
198
212
  const uint8_t* code = tmp_codes.get() + i * code_size;
199
- const uint8_t* factors_ptr = code + bit_pattern_size;
200
- const FactorsData& embedded_factors =
201
- *reinterpret_cast<const FactorsData*>(factors_ptr);
202
- factors_storage[ntotal + i] = embedded_factors;
213
+ const idx_t vec_idx = ntotal + i;
214
+
215
+ // Copy factors data directly to flat storage (no reordering needed)
216
+ const uint8_t* source_factors_ptr = code + bit_pattern_size;
217
+ uint8_t* storage = flat_storage.data() + vec_idx * storage_size;
218
+ memcpy(storage, source_factors_ptr, storage_size);
203
219
  }
204
220
 
205
221
  // Resize main storage (same logic as parent)
@@ -239,6 +255,8 @@ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
239
255
  // Hoist loop-invariant computations
240
256
  const float* centroid_data = center.data();
241
257
  const size_t bit_pattern_size = (d + 7) / 8;
258
+ const size_t ex_bits = rabitq.nb_bits - 1;
259
+ const size_t ex_code_size = (d * ex_bits + 7) / 8;
242
260
 
243
261
  memset(codes, 0, n * code_size);
244
262
 
@@ -247,25 +265,52 @@ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
247
265
  uint8_t* const code = codes + i * code_size;
248
266
  const float* const x_row = x + i * d;
249
267
 
250
- // Pack bits directly into FastScan format
268
+ // Compute residual once, reuse for both sign bits and ex-bits
269
+ std::vector<float> residual(d);
251
270
  for (size_t j = 0; j < d; j++) {
252
- const float x_val = x_row[j];
253
271
  const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
254
- const float or_minus_c = x_val - centroid_val;
255
- const bool xb = (or_minus_c > 0.0f);
272
+ residual[j] = x_row[j] - centroid_val;
273
+ }
256
274
 
257
- if (xb) {
275
+ // Pack sign bits directly into FastScan format using precomputed
276
+ // residual
277
+ for (size_t j = 0; j < d; j++) {
278
+ if (residual[j] > 0.0f) {
258
279
  rabitq_utils::set_bit_fastscan(code, j);
259
280
  }
260
281
  }
261
282
 
262
- // Calculate and append factors after the bit data
263
- FactorsData factors = rabitq_utils::compute_vector_factors(
264
- x_row, d, centroid_data, metric_type);
283
+ SignBitFactorsWithError factors = rabitq_utils::compute_vector_factors(
284
+ x_row, d, centroid_data, metric_type, ex_bits > 0);
265
285
 
266
- // Append factors at the end of the code
267
- uint8_t* factors_ptr = code + bit_pattern_size;
268
- *reinterpret_cast<FactorsData*>(factors_ptr) = factors;
286
+ if (ex_bits == 0) {
287
+ // 1-bit: store only SignBitFactors (8 bytes)
288
+ memcpy(code + bit_pattern_size, &factors, sizeof(SignBitFactors));
289
+ } else {
290
+ // Multi-bit: store full SignBitFactorsWithError (12 bytes)
291
+ memcpy(code + bit_pattern_size,
292
+ &factors,
293
+ sizeof(SignBitFactorsWithError));
294
+
295
+ // Add mag-codes and ExtraBitsFactors using precomputed
296
+ // residual
297
+ uint8_t* ex_code =
298
+ code + bit_pattern_size + sizeof(SignBitFactorsWithError);
299
+ ExtraBitsFactors ex_factors_temp;
300
+
301
+ rabitq_multibit::quantize_ex_bits(
302
+ residual.data(),
303
+ d,
304
+ rabitq.nb_bits,
305
+ ex_code,
306
+ ex_factors_temp,
307
+ metric_type,
308
+ centroid_data);
309
+
310
+ memcpy(ex_code + ex_code_size,
311
+ &ex_factors_temp,
312
+ sizeof(ExtraBitsFactors));
313
+ }
269
314
  }
270
315
  }
271
316
 
@@ -300,7 +345,8 @@ void IndexRaBitQFastScan::compute_float_LUT(
300
345
  rotated_qq);
301
346
 
302
347
  // Store query factors in context array if provided
303
- if (context.query_factors) {
348
+ if (context.query_factors != nullptr) {
349
+ query_factors_data.rotated_q = rotated_q;
304
350
  context.query_factors[i] = query_factors_data;
305
351
  }
306
352
 
@@ -397,8 +443,9 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
397
443
 
398
444
  // Extract factors directly from embedded codes
399
445
  const uint8_t* factors_ptr = code + bit_pattern_size;
400
- const FactorsData& fac =
401
- *reinterpret_cast<const FactorsData*>(factors_ptr);
446
+ const rabitq_utils::SignBitFactors* fac =
447
+ reinterpret_cast<const rabitq_utils::SignBitFactors*>(
448
+ factors_ptr);
402
449
 
403
450
  for (size_t j = 0; j < d; j++) {
404
451
  // Use RaBitQUtils for consistent bit extraction
@@ -406,7 +453,7 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
406
453
  float bit = bit_value ? 1.0f : 0.0f;
407
454
 
408
455
  // Compute the output using RaBitQ reconstruction formula
409
- x[i * d + j] = (bit - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt +
456
+ x[i * d + j] = (bit - 0.5f) * fac->dp_multiplier * 2 * inv_d_sqrt +
410
457
  ((centroid_in == nullptr) ? 0 : centroid_in[j]);
411
458
  }
412
459
  }
@@ -446,14 +493,16 @@ RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
446
493
  float* distances,
447
494
  int64_t* labels,
448
495
  const IDSelector* sel_in,
449
- const FastScanDistancePostProcessing& ctx)
496
+ const FastScanDistancePostProcessing& ctx,
497
+ bool multi_bit)
450
498
  : RHC(nq_val, index->ntotal, sel_in),
451
499
  rabitq_index(index),
452
500
  heap_distances(distances),
453
501
  heap_labels(labels),
454
502
  nq(nq_val),
455
503
  k(k_val),
456
- context(ctx) {
504
+ context(ctx),
505
+ is_multi_bit(multi_bit) {
457
506
  // Initialize heaps for all queries in constructor
458
507
  // This allows us to support direct normalizer assignment
459
508
  #pragma omp parallel for if (nq > 100)
@@ -480,7 +529,7 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
480
529
 
481
530
  // Access query factors from query_factors pointer
482
531
  rabitq_utils::QueryFactorsData query_factors_data = {};
483
- if (context.query_factors) {
532
+ if (context.query_factors != nullptr) {
484
533
  query_factors_data = context.query_factors[q];
485
534
  }
486
535
 
@@ -494,6 +543,15 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
494
543
  ? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
495
544
  : 0;
496
545
 
546
+ // Get storage size once
547
+ const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
548
+
549
+ // Stats tracking for multi-bit two-stage search only
550
+ // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
551
+ // n_multibit_evaluations: candidates requiring full multi-bit distance
552
+ size_t local_1bit_evaluations = 0;
553
+ size_t local_multibit_evaluations = 0;
554
+
497
555
  // Process distances in batch
498
556
  for (size_t i = 0; i < max_vectors; i++) {
499
557
  const size_t db_idx = base_db_idx + i;
@@ -501,43 +559,70 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
501
559
  // Normalize distance from LUT lookup
502
560
  const float normalized_distance = d32tab[i] * one_a + bias;
503
561
 
504
- // Access factors from storage (populated from embedded codes during
505
- // add())
506
- const auto& db_factors = rabitq_index->factors_storage[db_idx];
507
-
508
- float adjusted_distance;
509
-
510
- if (rabitq_index->centered) {
511
- // For centered mode: normalized_distance contains the raw XOR
512
- // contribution. Apply the signed odd integer quantization formula:
513
- // int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
514
- int64_t int_dot = ((1 << rabitq_index->qb) - 1) * rabitq_index->d;
515
- int_dot -= 2 * static_cast<int64_t>(normalized_distance);
516
-
517
- adjusted_distance = query_factors_data.qr_to_c_L2sqr +
518
- db_factors.or_minus_c_l2sqr -
519
- 2 * db_factors.dp_multiplier * int_dot *
520
- query_factors_data.int_dot_scale;
562
+ // Access factors from flat storage
563
+ const uint8_t* base_ptr =
564
+ rabitq_index->flat_storage.data() + db_idx * storage_size;
565
+
566
+ if (is_multi_bit) {
567
+ // Track candidates actually considered for two-stage filtering
568
+ local_1bit_evaluations++;
569
+
570
+ const SignBitFactorsWithError& full_factors =
571
+ *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
572
+
573
+ float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
574
+ normalized_distance,
575
+ full_factors,
576
+ query_factors_data,
577
+ rabitq_index->centered,
578
+ rabitq_index->qb,
579
+ rabitq_index->d);
580
+
581
+ float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
582
+
583
+ // Adaptive filtering: decide whether to compute full distance
584
+ const bool is_similarity = rabitq_index->metric_type ==
585
+ MetricType::METRIC_INNER_PRODUCT;
586
+ bool should_refine = is_similarity
587
+ ? (lower_bound > heap_dis[0]) // IP: keep if better
588
+ : (lower_bound < heap_dis[0]); // L2: keep if better
589
+
590
+ if (should_refine) {
591
+ local_multibit_evaluations++;
592
+ float dist_full = compute_full_multibit_distance(db_idx, q);
593
+
594
+ if (Cfloat::cmp(heap_dis[0], dist_full)) {
595
+ heap_replace_top<Cfloat>(
596
+ k, heap_dis, heap_ids, dist_full, db_idx);
597
+ }
598
+ }
521
599
  } else {
522
- // For non-centered quantization: use traditional formula
523
- float final_dot = normalized_distance - query_factors_data.c34;
524
- adjusted_distance = db_factors.or_minus_c_l2sqr +
525
- query_factors_data.qr_to_c_L2sqr -
526
- 2 * db_factors.dp_multiplier * final_dot;
527
- }
528
-
529
- // Apply inner product correction if needed
530
- if (query_factors_data.qr_norm_L2sqr != 0.0f) {
531
- adjusted_distance = -0.5f *
532
- (adjusted_distance - query_factors_data.qr_norm_L2sqr);
533
- }
534
-
535
- // Add to heap if better than current worst
536
- if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
537
- heap_replace_top<Cfloat>(
538
- k, heap_dis, heap_ids, adjusted_distance, db_idx);
600
+ const rabitq_utils::SignBitFactors& db_factors =
601
+ *reinterpret_cast<const rabitq_utils::SignBitFactors*>(
602
+ base_ptr);
603
+
604
+ float adjusted_distance =
605
+ rabitq_utils::compute_1bit_adjusted_distance(
606
+ normalized_distance,
607
+ db_factors,
608
+ query_factors_data,
609
+ rabitq_index->centered,
610
+ rabitq_index->qb,
611
+ rabitq_index->d);
612
+
613
+ // Add to heap if better than current worst
614
+ if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
615
+ heap_replace_top<Cfloat>(
616
+ k, heap_dis, heap_ids, adjusted_distance, db_idx);
617
+ }
539
618
  }
540
619
  }
620
+
621
+ // Update global stats atomically
622
+ #pragma omp atomic
623
+ rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
624
+ #pragma omp atomic
625
+ rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
541
626
  }
542
627
 
543
628
  template <class C, bool with_id_map>
@@ -557,8 +642,71 @@ void RaBitQHeapHandler<C, with_id_map>::end() {
557
642
  }
558
643
  }
559
644
 
645
+ template <class C, bool with_id_map>
646
+ float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
647
+ float dist_1bit,
648
+ size_t db_idx,
649
+ size_t q) const {
650
+ // Access f_error directly from SignBitFactorsWithError in flat storage
651
+ const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
652
+ const uint8_t* base_ptr =
653
+ rabitq_index->flat_storage.data() + db_idx * storage_size;
654
+ const SignBitFactorsWithError& db_factors =
655
+ *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
656
+ float f_error = db_factors.f_error;
657
+
658
+ // Get g_error from query factors (query-dependent error term)
659
+ float g_error = 0.0f;
660
+ if (context.query_factors != nullptr) {
661
+ g_error = context.query_factors[q].g_error;
662
+ }
663
+
664
+ // Compute error adjustment: f_error * g_error
665
+ float error_adjustment = f_error * g_error;
666
+
667
+ return dist_1bit - error_adjustment;
668
+ }
669
+
670
+ template <class C, bool with_id_map>
671
+ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
672
+ size_t db_idx,
673
+ size_t q) const {
674
+ const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
675
+ const size_t dim = rabitq_index->d;
676
+
677
+ const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
678
+ const uint8_t* base_ptr =
679
+ rabitq_index->flat_storage.data() + db_idx * storage_size;
680
+
681
+ const size_t ex_code_size = (dim * ex_bits + 7) / 8;
682
+ const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
683
+ const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
684
+ base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
685
+
686
+ // Get query factors reference (avoid copying)
687
+ const rabitq_utils::QueryFactorsData& query_factors =
688
+ context.query_factors[q];
689
+
690
+ // Get sign bits from FastScan packed format
691
+ std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
692
+ CodePackerPQ4 packer(rabitq_index->M2, rabitq_index->bbs);
693
+ packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
694
+ const uint8_t* sign_bits = unpacked_code.data();
695
+
696
+ return rabitq_utils::compute_full_multibit_distance(
697
+ sign_bits,
698
+ ex_code,
699
+ ex_fac,
700
+ query_factors.rotated_q.data(),
701
+ query_factors.qr_to_c_L2sqr,
702
+ query_factors.qr_norm_L2sqr,
703
+ dim,
704
+ ex_bits,
705
+ rabitq_index->metric_type);
706
+ }
707
+
560
708
  // Implementation of virtual make_knn_handler method
561
- void* IndexRaBitQFastScan::make_knn_handler(
709
+ SIMDResultHandlerToFloat* IndexRaBitQFastScan::make_knn_handler(
562
710
  bool is_max,
563
711
  int /*impl*/,
564
712
  idx_t n,
@@ -568,19 +716,16 @@ void* IndexRaBitQFastScan::make_knn_handler(
568
716
  idx_t* labels,
569
717
  const IDSelector* sel,
570
718
  const FastScanDistancePostProcessing& context) const {
719
+ // Use runtime boolean for multi-bit mode
720
+ const bool multi_bit = rabitq.nb_bits > 1;
721
+
571
722
  if (is_max) {
572
723
  return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
573
- this, n, k, distances, labels, sel, context);
724
+ this, n, k, distances, labels, sel, context, multi_bit);
574
725
  } else {
575
726
  return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
576
- this, n, k, distances, labels, sel, context);
727
+ this, n, k, distances, labels, sel, context, multi_bit);
577
728
  }
578
729
  }
579
730
 
580
- // Explicit template instantiations for the required comparator types
581
- template struct RaBitQHeapHandler<CMin<uint16_t, int>, false>;
582
- template struct RaBitQHeapHandler<CMax<uint16_t, int>, false>;
583
- template struct RaBitQHeapHandler<CMin<uint16_t, int>, true>;
584
- template struct RaBitQHeapHandler<CMax<uint16_t, int>, true>;
585
-
586
731
  } // namespace faiss
@@ -11,6 +11,7 @@
11
11
 
12
12
  #include <faiss/IndexFastScan.h>
13
13
  #include <faiss/IndexRaBitQ.h>
14
+ #include <faiss/impl/RaBitQStats.h>
14
15
  #include <faiss/impl/RaBitQUtils.h>
15
16
  #include <faiss/impl/RaBitQuantizer.h>
16
17
  #include <faiss/impl/simd_result_handlers.h>
@@ -20,8 +21,10 @@
20
21
  namespace faiss {
21
22
 
22
23
  // Import shared utilities from RaBitQUtils
23
- using rabitq_utils::FactorsData;
24
+ using rabitq_utils::ExtraBitsFactors;
24
25
  using rabitq_utils::QueryFactorsData;
26
+ using rabitq_utils::SignBitFactors;
27
+ using rabitq_utils::SignBitFactorsWithError;
25
28
 
26
29
  /** Fast-scan version of RaBitQ index that processes 32 database vectors at a
27
30
  * time using SIMD operations. Similar to IndexPQFastScan but adapted for
@@ -40,9 +43,16 @@ struct IndexRaBitQFastScan : IndexFastScan {
40
43
  /// Center of all points (same as IndexRaBitQ)
41
44
  std::vector<float> center;
42
45
 
43
- /// Extracted factors storage for batch processing
44
- /// Size: ntotal, stores factors separately from packed codes
45
- std::vector<FactorsData> factors_storage;
46
+ /// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
47
+ ///
48
+ /// 1-bit codes (sign bits) are stored in the inherited `codes` array from
49
+ /// IndexFastScan in packed FastScan format for SIMD processing.
50
+ ///
51
+ /// This flat_storage holds per-vector factors and refinement-bit codes:
52
+ /// Layout for 1-bit: [SignBitFactors (8 bytes)]
53
+ /// Layout for multi-bit: [SignBitFactorsWithError
54
+ /// (12B)][ref_codes][ExtraBitsFactors (8B)]
55
+ std::vector<uint8_t> flat_storage;
46
56
 
47
57
  /// Default number of bits to quantize a query with
48
58
  uint8_t qb = 8;
@@ -55,7 +65,8 @@ struct IndexRaBitQFastScan : IndexFastScan {
55
65
  explicit IndexRaBitQFastScan(
56
66
  idx_t d,
57
67
  MetricType metric = METRIC_L2,
58
- int bbs = 32);
68
+ int bbs = 32,
69
+ uint8_t nb_bits = 1);
59
70
 
60
71
  /// build from an existing IndexRaBitQ
61
72
  explicit IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs = 32);
@@ -66,6 +77,9 @@ struct IndexRaBitQFastScan : IndexFastScan {
66
77
 
67
78
  void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
68
79
 
80
+ /// Compute storage size per vector in flat_storage
81
+ size_t compute_per_vector_storage_size() const;
82
+
69
83
  void compute_float_LUT(
70
84
  float* lut,
71
85
  idx_t n,
@@ -83,7 +97,7 @@ struct IndexRaBitQFastScan : IndexFastScan {
83
97
  const SearchParameters* params = nullptr) const override;
84
98
 
85
99
  /// Override to create RaBitQ-specific handlers
86
- void* make_knn_handler(
100
+ SIMDResultHandlerToFloat* make_knn_handler(
87
101
  bool is_max,
88
102
  int /*impl*/,
89
103
  idx_t n,
@@ -108,6 +122,8 @@ struct IndexRaBitQFastScan : IndexFastScan {
108
122
  * - Direct heap integration (no intermediate result storage)
109
123
  * - Batch-level computation of normalizers and query factors
110
124
  * - Preserves exact mathematical equivalence to original RaBitQ distances
125
+ * - Runtime boolean for multi-bit support
126
+ *
111
127
  * @tparam C Comparator type (CMin/CMax) for heap operations
112
128
  * @tparam with_id_map Whether to use id mapping (similar to HeapHandler)
113
129
  */
@@ -122,7 +138,8 @@ struct RaBitQHeapHandler
122
138
  int64_t* heap_labels; // [nq * k]
123
139
  const size_t nq, k;
124
140
  const FastScanDistancePostProcessing&
125
- context; // Processing context with query offset
141
+ context; // Processing context with query offset
142
+ const bool is_multi_bit; // Runtime flag for multi-bit mode
126
143
 
127
144
  // Use float-based comparator for heap operations
128
145
  using Cfloat = typename std::conditional<
@@ -137,13 +154,22 @@ struct RaBitQHeapHandler
137
154
  float* distances,
138
155
  int64_t* labels,
139
156
  const IDSelector* sel_in,
140
- const FastScanDistancePostProcessing& context);
157
+ const FastScanDistancePostProcessing& context,
158
+ bool multi_bit);
141
159
 
142
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final;
160
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) override;
143
161
 
144
162
  void begin(const float* norms);
145
163
 
146
164
  void end();
165
+
166
+ private:
167
+ /// Compute full multi-bit distance for a candidate vector (multi-bit only)
168
+ float compute_full_multibit_distance(size_t db_idx, size_t q) const;
169
+
170
+ /// Compute lower bound using 1-bit distance and error bound (multi-bit
171
+ /// only)
172
+ float compute_lower_bound(float dist_1bit, size_t db_idx, size_t q) const;
147
173
  };
148
174
 
149
175
  } // namespace faiss
@@ -341,4 +341,53 @@ void IndexRefineFlat::search(
341
341
  }
342
342
  }
343
343
 
344
+ /***************************************************
345
+ * IndexRefinePanorama
346
+ ***************************************************/
347
+
348
+ void IndexRefinePanorama::search(
349
+ idx_t n,
350
+ const float* x,
351
+ idx_t k,
352
+ float* distances,
353
+ idx_t* labels,
354
+ const SearchParameters* params_in) const {
355
+ const IndexRefineSearchParameters* params = nullptr;
356
+ if (params_in) {
357
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
358
+ FAISS_THROW_IF_NOT_MSG(
359
+ params, "IndexRefineFlat params have incorrect type");
360
+ }
361
+
362
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
363
+ : idx_t(k * k_factor);
364
+ SearchParameters* base_index_params =
365
+ (params != nullptr) ? params->base_index_params : nullptr;
366
+
367
+ FAISS_THROW_IF_NOT(k_base >= k);
368
+
369
+ FAISS_THROW_IF_NOT(base_index);
370
+ FAISS_THROW_IF_NOT(refine_index);
371
+
372
+ FAISS_THROW_IF_NOT(k > 0);
373
+ FAISS_THROW_IF_NOT(is_trained);
374
+
375
+ std::unique_ptr<idx_t[]> del1;
376
+ std::unique_ptr<float[]> del2;
377
+ idx_t* base_labels = new idx_t[n * k_base];
378
+ float* base_distances = new float[n * k_base];
379
+ del1.reset(base_labels);
380
+ del2.reset(base_distances);
381
+
382
+ base_index->search(
383
+ n, x, k_base, base_distances, base_labels, base_index_params);
384
+
385
+ for (int i = 0; i < n * k_base; i++) {
386
+ assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
387
+ }
388
+
389
+ refine_index->search_subset(
390
+ n, x, k_base, base_labels, k, distances, labels);
391
+ }
392
+
344
393
  } // namespace faiss
@@ -95,4 +95,21 @@ struct IndexRefineFlat : IndexRefine {
95
95
  const SearchParameters* params = nullptr) const override;
96
96
  };
97
97
 
98
+ /** Version where the search calls search_subset, allowing for Panorama
99
+ * refinement. */
100
+ struct IndexRefinePanorama : IndexRefine {
101
+ explicit IndexRefinePanorama(Index* base_index, Index* refine_index)
102
+ : IndexRefine(base_index, refine_index) {}
103
+
104
+ IndexRefinePanorama() : IndexRefine() {}
105
+
106
+ void search(
107
+ idx_t n,
108
+ const float* x,
109
+ idx_t k,
110
+ float* distances,
111
+ idx_t* labels,
112
+ const SearchParameters* params = nullptr) const override;
113
+ };
114
+
98
115
  } // namespace faiss