faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -8,8 +8,11 @@
8
8
  #include <faiss/IndexIVFRaBitQFastScan.h>
9
9
 
10
10
  #include <algorithm>
11
+ #include <array>
11
12
  #include <cstdio>
13
+ #include <memory>
12
14
 
15
+ #include <faiss/impl/CodePackerRaBitQ.h>
13
16
  #include <faiss/impl/FaissAssert.h>
14
17
  #include <faiss/impl/FastScanDistancePostProcessing.h>
15
18
  #include <faiss/impl/RaBitQUtils.h>
@@ -79,8 +82,6 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
79
82
  if (own_invlists) {
80
83
  replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
81
84
  }
82
-
83
- flat_storage.clear();
84
85
  }
85
86
 
86
87
  // Constructor that converts an existing IndexIVFRaBitQ to FastScan format
@@ -97,41 +98,52 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
97
98
  rabitq(orig.rabitq) {}
98
99
 
99
100
  size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
100
- const size_t ex_bits = rabitq.nb_bits - 1;
101
+ return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
102
+ }
101
103
 
102
- if (ex_bits == 0) {
103
- // 1-bit: only SignBitFactors (8 bytes)
104
- return sizeof(SignBitFactors);
105
- } else {
106
- // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
107
- return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
108
- (d * ex_bits + 7) / 8;
109
- }
104
+ size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
105
+ // Use code_size as stride to skip embedded factor data during packing
106
+ return code_size;
110
107
  }
111
108
 
112
- void IndexIVFRaBitQFastScan::preprocess_code_metadata(
113
- idx_t n,
114
- const uint8_t* flat_codes,
115
- idx_t start_global_idx) {
116
- // Unified approach: always use flat_storage for both 1-bit and multi-bit
117
- const size_t storage_size = compute_per_vector_storage_size();
118
- flat_storage.resize((start_global_idx + n) * storage_size);
109
+ CodePacker* IndexIVFRaBitQFastScan::get_CodePacker() const {
110
+ return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
111
+ }
112
+
113
+ /*********************************************************
114
+ * postprocess_packed_codes: write auxiliary data into blocks
115
+ *********************************************************/
116
+
117
+ void IndexIVFRaBitQFastScan::postprocess_packed_codes(
118
+ idx_t list_no,
119
+ size_t list_offset,
120
+ size_t n_added,
121
+ const uint8_t* flat_codes) {
122
+ auto* bil = dynamic_cast<BlockInvertedLists*>(invlists);
123
+ FAISS_THROW_IF_NOT(bil);
119
124
 
120
- // Copy factors data directly to flat storage (no reordering needed)
125
+ uint8_t* block_data = bil->codes[list_no].data();
126
+ const size_t storage_size = compute_per_vector_storage_size();
121
127
  const size_t bit_pattern_size = (d + 7) / 8;
122
- for (idx_t i = 0; i < n; i++) {
123
- const uint8_t* code = flat_codes + i * code_size;
124
- const uint8_t* source_factors_ptr = code + bit_pattern_size;
125
- uint8_t* storage =
126
- flat_storage.data() + (start_global_idx + i) * storage_size;
127
- memcpy(storage, source_factors_ptr, storage_size);
128
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
129
+ const size_t full_block_size = get_block_stride();
130
+
131
+ for (size_t i = 0; i < n_added; i++) {
132
+ const uint8_t* src = flat_codes + i * code_size + bit_pattern_size;
133
+ uint8_t* dst = rabitq_utils::get_block_aux_ptr(
134
+ block_data,
135
+ list_offset + i,
136
+ bbs,
137
+ packed_block_size,
138
+ full_block_size,
139
+ storage_size);
140
+ memcpy(dst, src, storage_size);
128
141
  }
129
142
  }
130
143
 
131
- size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
132
- // Use code_size as stride to skip embedded factor data during packing
133
- return code_size;
134
- }
144
+ /*********************************************************
145
+ * train_encoder
146
+ *********************************************************/
135
147
 
136
148
  void IndexIVFRaBitQFastScan::train_encoder(
137
149
  idx_t n,
@@ -271,10 +283,11 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
271
283
  rotated_q,
272
284
  rotated_qq);
273
285
 
274
- // Override query norm for inner product if original query is provided
275
286
  if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
276
287
  original_query != nullptr) {
277
288
  query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
289
+ query_factors.q_dot_c = query_factors.qr_norm_L2sqr -
290
+ fvec_inner_product(original_query, residual, d);
278
291
  }
279
292
 
280
293
  const size_t ex_bits = rabitq.nb_bits - 1;
@@ -441,23 +454,22 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
441
454
  }
442
455
  }
443
456
 
444
- // Get dp_multiplier directly from flat_storage
445
- InvertedLists::ScopedIds list_ids(invlists, list_no);
446
- idx_t global_id = list_ids[offset];
447
-
448
- float dp_multiplier = 1.0f;
449
- if (global_id >= 0) {
450
- const size_t storage_size = compute_per_vector_storage_size();
451
- const size_t storage_capacity = flat_storage.size() / storage_size;
452
-
453
- if (static_cast<size_t>(global_id) < storage_capacity) {
454
- const uint8_t* base_ptr =
455
- flat_storage.data() + global_id * storage_size;
456
- const auto& base_factors =
457
- *reinterpret_cast<const SignBitFactors*>(base_ptr);
458
- dp_multiplier = base_factors.dp_multiplier;
459
- }
460
- }
457
+ const size_t storage_size = compute_per_vector_storage_size();
458
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
459
+ const size_t full_block_size = get_block_stride();
460
+
461
+ InvertedLists::ScopedCodes list_block_codes(invlists, list_no);
462
+ const uint8_t* aux_ptr = rabitq_utils::get_block_aux_ptr(
463
+ list_block_codes.get(),
464
+ offset,
465
+ bbs,
466
+ packed_block_size,
467
+ full_block_size,
468
+ storage_size);
469
+
470
+ const auto& base_factors =
471
+ *reinterpret_cast<const SignBitFactors*>(aux_ptr);
472
+ const float dp_multiplier = base_factors.dp_multiplier;
461
473
 
462
474
  // Decode residual directly using dp_multiplier
463
475
  std::vector<float> residual(d);
@@ -573,7 +585,11 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
573
585
  nq(nq_val),
574
586
  k(k_val),
575
587
  context(ctx),
576
- is_multibit(multibit) {
588
+ is_multibit(multibit),
589
+ storage_size(idx->compute_per_vector_storage_size()),
590
+ packed_block_size(((idx->M2 + 1) / 2) * idx->bbs),
591
+ full_block_size(idx->get_block_stride()),
592
+ packer(idx->get_CodePacker()) {
577
593
  current_list_no = 0;
578
594
  probe_indices.clear();
579
595
 
@@ -649,10 +665,13 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
649
665
 
650
666
  const float normalized_distance = d32tab[j] * one_a + bias;
651
667
 
652
- // Get database factors from flat_storage
653
- const size_t storage_size = index->compute_per_vector_storage_size();
654
- const uint8_t* base_ptr =
655
- index->flat_storage.data() + result_id * storage_size;
668
+ const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
669
+ list_codes_ptr,
670
+ idx_base + j,
671
+ index->bbs,
672
+ packed_block_size,
673
+ full_block_size,
674
+ storage_size);
656
675
 
657
676
  if (is_multibit) {
658
677
  // Track candidates actually considered for two-stage filtering
@@ -671,17 +690,18 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
671
690
  index->qb,
672
691
  index->d);
673
692
 
674
- // Compute lower bound using error bound
675
- float lower_bound =
676
- compute_lower_bound(dist_1bit, result_id, local_q, q);
677
-
678
693
  // Adaptive filtering: decide whether to compute full distance
679
694
  const bool is_similarity =
680
695
  index->metric_type == MetricType::METRIC_INNER_PRODUCT;
681
- bool should_refine = is_similarity
682
- ? (lower_bound > heap_dis[0]) // IP: keep if better
683
- : (lower_bound < heap_dis[0]); // L2: keep if better
684
696
 
697
+ float g_error = query_factors.g_error;
698
+
699
+ bool should_refine = rabitq_utils::should_refine_candidate(
700
+ dist_1bit,
701
+ full_factors.f_error,
702
+ g_error,
703
+ heap_dis[0],
704
+ is_similarity);
685
705
  if (should_refine) {
686
706
  local_multibit_evaluations++;
687
707
 
@@ -696,6 +716,7 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
696
716
  if (Cfloat::cmp(heap_dis[0], dist_full)) {
697
717
  heap_replace_top<Cfloat>(
698
718
  k, heap_dis, heap_ids, dist_full, result_id);
719
+ nup++;
699
720
  }
700
721
  }
701
722
  } else {
@@ -715,6 +736,7 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
715
736
  if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
716
737
  heap_replace_top<Cfloat>(
717
738
  k, heap_dis, heap_ids, adjusted_distance, result_id);
739
+ nup++;
718
740
  }
719
741
  }
720
742
  }
@@ -732,6 +754,7 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
732
754
  const std::vector<int>& probe_map) {
733
755
  current_list_no = list_no;
734
756
  probe_indices = probe_map;
757
+ list_codes_ptr = index->invlists->get_codes(list_no);
735
758
  }
736
759
 
737
760
  template <class C>
@@ -750,49 +773,23 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
750
773
  }
751
774
  }
752
775
 
753
- template <class C>
754
- float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::compute_lower_bound(
755
- float dist_1bit,
756
- size_t db_idx,
757
- size_t local_q,
758
- size_t global_q) const {
759
- // Access f_error from SignBitFactorsWithError in flat storage
760
- const size_t storage_size = index->compute_per_vector_storage_size();
761
- const uint8_t* base_ptr =
762
- index->flat_storage.data() + db_idx * storage_size;
763
- const SignBitFactorsWithError& db_factors =
764
- *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
765
- float f_error = db_factors.f_error;
766
-
767
- // Get g_error from query factors
768
- // Use local_q to access probe_indices (batch-local), global_q for storage
769
- float g_error = 0.0f;
770
- if (context && context->query_factors) {
771
- size_t probe_rank = probe_indices[local_q];
772
- size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
773
- size_t storage_idx = global_q * nprobe + probe_rank;
774
- g_error = context->query_factors[storage_idx].g_error;
775
- }
776
-
777
- // Compute error adjustment: f_error * g_error
778
- float error_adjustment = f_error * g_error;
779
-
780
- return dist_1bit - error_adjustment;
781
- }
782
-
783
776
  template <class C>
784
777
  float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
785
778
  compute_full_multibit_distance(
786
- size_t db_idx,
779
+ size_t /*db_idx*/,
787
780
  size_t local_q,
788
781
  size_t global_q,
789
782
  size_t local_offset) const {
790
783
  const size_t ex_bits = index->rabitq.nb_bits - 1;
791
784
  const size_t dim = index->d;
792
785
 
793
- const size_t storage_size = index->compute_per_vector_storage_size();
794
- const uint8_t* base_ptr =
795
- index->flat_storage.data() + db_idx * storage_size;
786
+ const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
787
+ list_codes_ptr,
788
+ local_offset,
789
+ index->bbs,
790
+ packed_block_size,
791
+ full_block_size,
792
+ storage_size);
796
793
 
797
794
  const size_t ex_code_size = (dim * ex_bits + 7) / 8;
798
795
  const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
@@ -809,8 +806,7 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
809
806
  InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
810
807
 
811
808
  std::vector<uint8_t> unpacked_code(index->code_size);
812
- CodePackerPQ4 packer(index->M2, index->bbs);
813
- packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
809
+ packer->unpack_1(list_codes.get(), local_offset, unpacked_code.data());
814
810
  const uint8_t* sign_bits = unpacked_code.data();
815
811
 
816
812
  return rabitq_utils::compute_full_multibit_distance(
@@ -818,11 +814,164 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
818
814
  ex_code,
819
815
  ex_fac,
820
816
  query_factors.rotated_q.data(),
821
- query_factors.qr_to_c_L2sqr,
822
- query_factors.qr_norm_L2sqr,
817
+ (index->metric_type == MetricType::METRIC_INNER_PRODUCT)
818
+ ? query_factors.q_dot_c
819
+ : query_factors.qr_to_c_L2sqr,
823
820
  dim,
824
821
  ex_bits,
825
822
  index->metric_type);
826
823
  }
827
824
 
825
+ /*********************************************************
826
+ * IVFRaBitQFastScanScanner implementation
827
+ *********************************************************/
828
+
829
+ namespace {
830
+
831
+ /// Provides IVF scanner interface using FastScan's SIMD batch processing.
832
+ struct IVFRaBitQFastScanScanner : InvertedListScanner {
833
+ static constexpr int impl = 10;
834
+ static constexpr size_t nq = 1;
835
+
836
+ const IndexIVFRaBitQFastScan& index;
837
+
838
+ AlignedTable<uint8_t> dis_tables;
839
+ AlignedTable<uint16_t> biases;
840
+ /// [scale, offset] for converting uint16 to float
841
+ std::array<float, 2> normalizers{};
842
+
843
+ const float* xi = nullptr;
844
+
845
+ QueryFactorsData query_factors;
846
+ FastScanDistancePostProcessing context;
847
+
848
+ std::unique_ptr<FlatCodesDistanceComputer> dc;
849
+ std::vector<float> centroid;
850
+
851
+ IVFRaBitQFastScanScanner(
852
+ const IndexIVFRaBitQFastScan& index,
853
+ bool store_pairs,
854
+ const IDSelector* sel)
855
+ : InvertedListScanner(store_pairs, sel), index(index) {
856
+ this->keep_max = is_similarity_metric(index.metric_type);
857
+ }
858
+
859
+ void set_query(const float* query) override {
860
+ this->xi = query;
861
+ }
862
+
863
+ void set_list(idx_t list_no, float coarse_dis) override {
864
+ this->list_no = list_no;
865
+
866
+ IndexIVFFastScan::CoarseQuantized cq{
867
+ .nprobe = 1,
868
+ .dis = &coarse_dis,
869
+ .ids = &list_no,
870
+ };
871
+
872
+ // Set up context for use in scan_codes
873
+ context = FastScanDistancePostProcessing{};
874
+ context.query_factors = &query_factors;
875
+ context.nprobe = 1;
876
+
877
+ index.compute_LUT_uint8(
878
+ 1, xi, cq, dis_tables, biases, &normalizers[0], context);
879
+
880
+ // Set up distance computer for distance_to_code
881
+ centroid.resize(index.d);
882
+ index.quantizer->reconstruct(list_no, centroid.data());
883
+ dc.reset(index.rabitq.get_distance_computer(
884
+ index.qb, centroid.data(), index.centered));
885
+ dc->set_query(xi);
886
+ }
887
+
888
+ float distance_to_code(const uint8_t* code) const override {
889
+ FAISS_THROW_IF_NOT_MSG(
890
+ dc,
891
+ "set_query and set_list must be called before distance_to_code");
892
+ return dc->distance_to_code(code);
893
+ }
894
+
895
+ public:
896
+ size_t scan_codes(
897
+ size_t ntotal,
898
+ const uint8_t* codes,
899
+ const idx_t* ids,
900
+ float* distances,
901
+ idx_t* labels,
902
+ size_t k) const override {
903
+ // initialize the current iteration heap to the worst possible value of
904
+ // the prior loop
905
+ std::vector<float> curr_dists(k, distances[0]);
906
+ std::vector<idx_t> curr_labels(k, labels[0]);
907
+
908
+ std::unique_ptr<SIMDResultHandlerToFloat> handler(
909
+ index.make_knn_handler(
910
+ !keep_max,
911
+ impl,
912
+ nq,
913
+ k,
914
+ curr_dists.data(),
915
+ curr_labels.data(),
916
+ sel,
917
+ context,
918
+ &normalizers[0]));
919
+
920
+ int qmap1[1] = {0};
921
+ handler->q_map = qmap1;
922
+ handler->begin(&normalizers[0]);
923
+
924
+ const uint8_t* LUT = dis_tables.get();
925
+ handler->dbias = biases.get();
926
+ handler->ntotal = ntotal;
927
+ handler->id_map = ids;
928
+
929
+ // RaBitQ needs list context for factor lookup
930
+ std::vector<int> probe_map = {0};
931
+ handler->set_list_context(list_no, probe_map);
932
+
933
+ pq4_accumulate_loop(
934
+ 1,
935
+ roundup(ntotal, index.bbs),
936
+ index.bbs,
937
+ static_cast<int>(index.M2),
938
+ codes,
939
+ LUT,
940
+ *handler,
941
+ nullptr,
942
+ index.get_block_stride());
943
+
944
+ // Combine results across iterations
945
+ handler->end();
946
+ if (keep_max) {
947
+ minheap_addn(
948
+ k,
949
+ distances,
950
+ labels,
951
+ curr_dists.data(),
952
+ curr_labels.data(),
953
+ k);
954
+ } else {
955
+ maxheap_addn(
956
+ k,
957
+ distances,
958
+ labels,
959
+ curr_dists.data(),
960
+ curr_labels.data(),
961
+ k);
962
+ }
963
+
964
+ return handler->num_updates();
965
+ }
966
+ };
967
+
968
+ } // anonymous namespace
969
+
970
+ InvertedListScanner* IndexIVFRaBitQFastScan::get_InvertedListScanner(
971
+ bool store_pairs,
972
+ const IDSelector* sel,
973
+ const IVFSearchParameters*) const {
974
+ return new IVFRaBitQFastScanScanner(*this, store_pairs, sel);
975
+ }
976
+
828
977
  } // namespace faiss
@@ -7,6 +7,7 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <memory>
10
11
  #include <vector>
11
12
 
12
13
  #include <faiss/IndexIVFFastScan.h>
@@ -55,17 +56,6 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
55
56
  /// Use zero-centered scalar quantizer for queries
56
57
  bool centered = false;
57
58
 
58
- /// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
59
- ///
60
- /// 1-bit codes (sign bits) are stored in the inherited `codes` array from
61
- /// IndexFastScan in packed FastScan format for SIMD processing.
62
- ///
63
- /// This flat_storage holds per-vector factors and refinement-bit codes:
64
- /// Layout for 1-bit: [SignBitFactors (8 bytes)]
65
- /// Layout for multi-bit: [SignBitFactorsWithError
66
- /// (12B)][ref_codes][ExtraBitsFactors (8B)]
67
- std::vector<uint8_t> flat_storage;
68
-
69
59
  // Constructors
70
60
 
71
61
  IndexIVFRaBitQFastScan();
@@ -94,16 +84,20 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
94
84
  bool include_listnos = false) const override;
95
85
 
96
86
  protected:
97
- /// Extract and store RaBitQ factors from encoded vectors
98
- void preprocess_code_metadata(
99
- idx_t n,
100
- const uint8_t* flat_codes,
101
- idx_t start_global_idx) override;
102
-
103
87
  /// Return code_size as stride to skip embedded factor data during packing
104
88
  size_t code_packing_stride() const override;
105
89
 
106
90
  public:
91
+ /// Return CodePackerRaBitQ with enlarged block size
92
+ CodePacker* get_CodePacker() const override;
93
+
94
+ /// Write per-vector auxiliary data into block auxiliary region
95
+ void postprocess_packed_codes(
96
+ idx_t list_no,
97
+ size_t list_offset,
98
+ size_t n_added,
99
+ const uint8_t* flat_codes) override;
100
+
107
101
  /// Reconstruct a single vector from an inverted list
108
102
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
109
103
  const override;
@@ -111,7 +105,7 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
111
105
  /// Override sa_decode to handle RaBitQ reconstruction
112
106
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
113
107
 
114
- /// Compute storage size per vector in flat_storage based on nb_bits
108
+ /// Compute per-vector auxiliary storage size based on nb_bits
115
109
  size_t compute_per_vector_storage_size() const;
116
110
 
117
111
  private:
@@ -166,6 +160,13 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
166
160
  const FastScanDistancePostProcessing& context,
167
161
  const float* normalizers = nullptr) const override;
168
162
 
163
+ /// Get an InvertedListScanner for single-query scanning.
164
+ /// This provides compatibility with the standard IVF search interface
165
+ InvertedListScanner* get_InvertedListScanner(
166
+ bool store_pairs = false,
167
+ const IDSelector* sel = nullptr,
168
+ const IVFSearchParameters* params = nullptr) const override;
169
+
169
170
  /** SIMD result handler for IndexIVFRaBitQFastScan that applies
170
171
  * RaBitQ-specific distance corrections during batch processing.
171
172
  *
@@ -192,11 +193,19 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
192
193
  int64_t* heap_labels; // [nq * k]
193
194
  const size_t nq, k;
194
195
  size_t current_list_no = 0;
196
+ const uint8_t* list_codes_ptr = nullptr; // raw block data for list
195
197
  std::vector<int>
196
198
  probe_indices; // probe index for each query in current batch
197
199
  const FastScanDistancePostProcessing*
198
200
  context; // Processing context with query factors
199
201
  const bool is_multibit; // Whether to use multi-bit two-stage search
202
+ size_t nup = 0; // Number of heap updates
203
+
204
+ // Cached block-layout constants (invariant for handler lifetime)
205
+ const size_t storage_size;
206
+ const size_t packed_block_size;
207
+ const size_t full_block_size;
208
+ std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
200
209
 
201
210
  // Use float-based comparator for heap operations
202
211
  using Cfloat = typename std::conditional<
@@ -224,6 +233,10 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
224
233
 
225
234
  void end() override;
226
235
 
236
+ size_t num_updates() override {
237
+ return nup;
238
+ }
239
+
227
240
  private:
228
241
  /// Compute full multi-bit distance for a candidate vector (multi-bit
229
242
  /// only)
@@ -232,20 +245,10 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
232
245
  /// @param global_q Global query index (for storage indexing)
233
246
  /// @param local_offset Offset within the current inverted list
234
247
  float compute_full_multibit_distance(
235
- size_t db_idx,
248
+ size_t /*db_idx*/,
236
249
  size_t local_q,
237
250
  size_t global_q,
238
251
  size_t local_offset) const;
239
-
240
- /// Compute lower bound using 1-bit distance and error bound (multi-bit
241
- /// only)
242
- /// @param local_q Batch-local query index (for probe_indices access)
243
- /// @param global_q Global query index (for storage indexing)
244
- float compute_lower_bound(
245
- float dist_1bit,
246
- size_t db_idx,
247
- size_t local_q,
248
- size_t global_q) const;
249
252
  };
250
253
  };
251
254
 
@@ -86,12 +86,14 @@ void IndexLSH::train(idx_t n, const float* x) {
86
86
 
87
87
  for (idx_t i = 0; i < nbits; i++) {
88
88
  float* xi = transposed_x.get() + i * n;
89
- // std::nth_element
90
- std::sort(xi, xi + n);
91
- if (n % 2 == 1)
92
- thresholds[i] = xi[n / 2];
93
- else
94
- thresholds[i] = (xi[n / 2 - 1] + xi[n / 2]) / 2;
89
+ // Use nth_element (O(n)) instead of sort (O(n log n))
90
+ std::nth_element(xi, xi + n / 2, xi + n);
91
+ float median = xi[n / 2];
92
+ if (n % 2 == 0) {
93
+ std::nth_element(xi, xi + n / 2 - 1, xi + n);
94
+ median = (median + xi[n / 2 - 1]) / 2;
95
+ }
96
+ thresholds[i] = median;
95
97
  }
96
98
  }
97
99
  is_trained = true;