faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
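
The headline addition in this release is the RaBitQ FastScan family (entries 52-53 and 67-68 above) together with its shared helpers (RaBitQUtils, RaBitQuantizerMultiBit, RaBitQStats, FastScanDistancePostProcessing). The full source of the new IndexIVFRaBitQFastScan.cpp is reproduced below. As orientation, here is a minimal C++ usage sketch against the vendored library; it assumes only the constructor signature visible in the new file (quantizer, d, nlist, metric, bbs, own_invlists, nb_bits) plus the standard faiss train/add/search flow, so defaults and exact headers may differ:

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFRaBitQFastScan.h>
    #include <cstddef>
    #include <vector>

    int main() {
        const std::size_t d = 64, nlist = 100, nb = 10000;
        faiss::IndexFlatL2 coarse(d); // coarse quantizer over nlist centroids
        // bbs must be a multiple of 32; nb_bits = 1 selects the pure
        // sign-bit path (no extra-bit codes or error bounds).
        faiss::IndexIVFRaBitQFastScan index(
                &coarse, d, nlist, faiss::METRIC_L2, 32, true, 1);
        std::vector<float> xb(nb * d); // fill with real vectors first
        index.train(nb, xb.data());    // trains coarse centroids + RaBitQ
        index.add(nb, xb.data());
        index.nprobe = 8;
        std::vector<float> D(5);
        std::vector<faiss::idx_t> I(5);
        index.search(1, xb.data(), 5, D.data(), I.data()); // top-5, 1 query
        return 0;
    }
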
data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp (new file)
@@ -0,0 +1,828 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/IndexIVFRaBitQFastScan.h>
+
+#include <algorithm>
+#include <cstdio>
+
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/FastScanDistancePostProcessing.h>
+#include <faiss/impl/RaBitQUtils.h>
+#include <faiss/impl/RaBitQuantizerMultiBit.h>
+#include <faiss/impl/pq4_fast_scan.h>
+#include <faiss/impl/simd_result_handlers.h>
+#include <faiss/invlists/BlockInvertedLists.h>
+#include <faiss/utils/distances.h>
+#include <faiss/utils/utils.h>
+
+namespace faiss {
+
+// Import shared utilities from RaBitQUtils
+using rabitq_utils::ExtraBitsFactors;
+using rabitq_utils::QueryFactorsData;
+using rabitq_utils::SignBitFactors;
+using rabitq_utils::SignBitFactorsWithError;
+
+inline size_t roundup(size_t a, size_t b) {
+    return (a + b - 1) / b * b;
+}
+
+/*********************************************************
+ * IndexIVFRaBitQFastScan implementation
+ *********************************************************/
+
+IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan() = default;
+
+IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
+        Index* quantizer,
+        size_t d,
+        size_t nlist,
+        MetricType metric,
+        int bbs,
+        bool own_invlists,
+        uint8_t nb_bits)
+        : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
+          rabitq(d, metric, nb_bits) {
+    FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
+    FAISS_THROW_IF_NOT_MSG(
+            metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
+            "RaBitQ only supports L2 and Inner Product metrics");
+    FAISS_THROW_IF_NOT_MSG(bbs % 32 == 0, "Batch size must be multiple of 32");
+    FAISS_THROW_IF_NOT_MSG(quantizer != nullptr, "Quantizer cannot be null");
+
+    by_residual = true;
+    qb = 8; // RaBitQ quantization bits
+    centered = false;
+
+    // FastScan-specific parameters: 4 bits per sub-quantizer
+    const size_t M_fastscan = (d + 3) / 4;
+    constexpr size_t nbits_fastscan = 4;
+
+    this->bbs = bbs;
+    this->fine_quantizer = &rabitq;
+    this->M = M_fastscan;
+    this->nbits = nbits_fastscan;
+    this->ksub = (1 << nbits_fastscan);
+    this->M2 = roundup(M_fastscan, 2);
+
+    // Compute code_size: bit_pattern + per-vector storage (factors/ex-codes)
+    const size_t bit_pattern_size = (d + 7) / 8;
+    this->code_size = bit_pattern_size + compute_per_vector_storage_size();
+
+    is_trained = false;
+
+    if (own_invlists) {
+        replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
+    }
+
+    flat_storage.clear();
+}
+
+// Constructor that converts an existing IndexIVFRaBitQ to FastScan format
+IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
+        const IndexIVFRaBitQ& orig,
+        int /* bbs */)
+        : IndexIVFFastScan(
+                  orig.quantizer,
+                  orig.d,
+                  orig.nlist,
+                  0,
+                  orig.metric_type,
+                  false),
+          rabitq(orig.rabitq) {}
+
+size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
+    const size_t ex_bits = rabitq.nb_bits - 1;
+
+    if (ex_bits == 0) {
+        // 1-bit: only SignBitFactors (8 bytes)
+        return sizeof(SignBitFactors);
+    } else {
+        // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
+        return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
+                (d * ex_bits + 7) / 8;
+    }
+}
+
+void IndexIVFRaBitQFastScan::preprocess_code_metadata(
+        idx_t n,
+        const uint8_t* flat_codes,
+        idx_t start_global_idx) {
+    // Unified approach: always use flat_storage for both 1-bit and multi-bit
+    const size_t storage_size = compute_per_vector_storage_size();
+    flat_storage.resize((start_global_idx + n) * storage_size);
+
+    // Copy factors data directly to flat storage (no reordering needed)
+    const size_t bit_pattern_size = (d + 7) / 8;
+    for (idx_t i = 0; i < n; i++) {
+        const uint8_t* code = flat_codes + i * code_size;
+        const uint8_t* source_factors_ptr = code + bit_pattern_size;
+        uint8_t* storage =
+                flat_storage.data() + (start_global_idx + i) * storage_size;
+        memcpy(storage, source_factors_ptr, storage_size);
+    }
+}
+
+size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
+    // Use code_size as stride to skip embedded factor data during packing
+    return code_size;
+}
+
+void IndexIVFRaBitQFastScan::train_encoder(
+        idx_t n,
+        const float* x,
+        const idx_t* assign) {
+    FAISS_THROW_IF_NOT(n > 0);
+    FAISS_THROW_IF_NOT(x != nullptr);
+    FAISS_THROW_IF_NOT(assign != nullptr || !by_residual);
+
+    rabitq.train(n, x);
+    is_trained = true;
+    init_code_packer();
+}
+
+void IndexIVFRaBitQFastScan::encode_vectors(
+        idx_t n,
+        const float* x,
+        const idx_t* list_nos,
+        uint8_t* codes,
+        bool include_listnos) const {
+    FAISS_THROW_IF_NOT(n > 0);
+    FAISS_THROW_IF_NOT(x != nullptr);
+    FAISS_THROW_IF_NOT(list_nos != nullptr);
+    FAISS_THROW_IF_NOT(codes != nullptr);
+    FAISS_THROW_IF_NOT(is_trained);
+
+    size_t coarse_size = include_listnos ? coarse_code_size() : 0;
+    size_t total_code_size = code_size + coarse_size;
+    memset(codes, 0, total_code_size * n);
+
+    const size_t ex_bits = rabitq.nb_bits - 1;
+
+#pragma omp parallel if (n > 1000)
+    {
+        std::vector<float> centroid(d);
+
+#pragma omp for
+        for (idx_t i = 0; i < n; i++) {
+            int64_t list_no = list_nos[i];
+
+            if (list_no >= 0) {
+                const float* xi = x + i * d;
+                uint8_t* code_out = codes + i * total_code_size;
+                uint8_t* fastscan_code = code_out + coarse_size;
+
+                // Reconstruct centroid for residual computation
+                quantizer->reconstruct(list_no, centroid.data());
+
+                const size_t bit_pattern_size = (d + 7) / 8;
+
+                // Pack sign bits directly into FastScan format (inline)
+                for (size_t j = 0; j < d; j++) {
+                    const float or_minus_c = xi[j] - centroid[j];
+                    if (or_minus_c > 0.0f) {
+                        rabitq_utils::set_bit_fastscan(fastscan_code, j);
+                    }
+                }
+
+                // Compute factors (with or without f_error depending on mode)
+                SignBitFactorsWithError factors =
+                        rabitq_utils::compute_vector_factors(
+                                xi,
+                                d,
+                                centroid.data(),
+                                rabitq.metric_type,
+                                ex_bits > 0);
+
+                if (ex_bits == 0) {
+                    // 1-bit: store only SignBitFactors (8 bytes)
+                    memcpy(fastscan_code + bit_pattern_size,
+                           &factors,
+                           sizeof(SignBitFactors));
+                } else {
+                    // Multi-bit: store full SignBitFactorsWithError (12 bytes)
+                    memcpy(fastscan_code + bit_pattern_size,
+                           &factors,
+                           sizeof(SignBitFactorsWithError));
+
+                    // Compute residual (needed for quantize_ex_bits)
+                    std::vector<float> residual(d);
+                    for (size_t j = 0; j < d; j++) {
+                        residual[j] = xi[j] - centroid[j];
+                    }
+
+                    // Quantize ex-bits
+                    const size_t ex_code_size = (d * ex_bits + 7) / 8;
+                    uint8_t* ex_code = fastscan_code + bit_pattern_size +
+                            sizeof(SignBitFactorsWithError);
+                    ExtraBitsFactors ex_factors_temp;
+
+                    rabitq_multibit::quantize_ex_bits(
+                            residual.data(),
+                            d,
+                            rabitq.nb_bits,
+                            ex_code,
+                            ex_factors_temp,
+                            rabitq.metric_type,
+                            centroid.data());
+
+                    memcpy(ex_code + ex_code_size,
+                           &ex_factors_temp,
+                           sizeof(ExtraBitsFactors));
+                }
+
+                // Include coarse codes if requested
+                if (include_listnos) {
+                    encode_listno(list_no, code_out);
+                }
+            }
+        }
+    }
+}
+
+bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
+    return true;
+}
+
+// Computes lookup table for residual vectors in RaBitQ FastScan format
+void IndexIVFRaBitQFastScan::compute_residual_LUT(
+        const float* residual,
+        QueryFactorsData& query_factors,
+        float* lut_out,
+        const float* original_query) const {
+    FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
+
+    std::vector<float> rotated_q(d);
+    std::vector<uint8_t> rotated_qq(d);
+
+    // Use RaBitQUtils to compute query factors - eliminates code duplication
+    query_factors = rabitq_utils::compute_query_factors(
+            residual,
+            d,
+            nullptr,
+            qb,
+            centered,
+            metric_type,
+            rotated_q,
+            rotated_qq);
+
+    // Override query norm for inner product if original query is provided
+    if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
+        original_query != nullptr) {
+        query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
+    }
+
+    const size_t ex_bits = rabitq.nb_bits - 1;
+    if (ex_bits > 0) {
+        query_factors.rotated_q = rotated_q;
+    }
+
+    if (centered) {
+        const float max_code_value = (1 << qb) - 1;
+
+        for (size_t m = 0; m < M; m++) {
+            const size_t dim_start = m * 4;
+
+            for (int code_val = 0; code_val < 16; code_val++) {
+                float xor_contribution = 0.0f;
+
+                for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                    const size_t dim_idx = dim_start + dim_offset;
+
+                    if (dim_idx < d) {
+                        const bool db_bit = (code_val >> dim_offset) & 1;
+                        const float query_value = rotated_qq[dim_idx];
+
+                        xor_contribution += db_bit
+                                ? (max_code_value - query_value)
+                                : query_value;
+                    }
+                }
+
+                lut_out[m * 16 + code_val] = xor_contribution;
+            }
+        }
+    } else {
+        for (size_t m = 0; m < M; m++) {
+            const size_t dim_start = m * 4;
+
+            for (int code_val = 0; code_val < 16; code_val++) {
+                float inner_product = 0.0f;
+                int popcount = 0;
+
+                for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                    const size_t dim_idx = dim_start + dim_offset;
+
+                    if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
+                        inner_product += rotated_qq[dim_idx];
+                        popcount++;
+                    }
+                }
+                lut_out[m * 16 + code_val] = query_factors.c1 * inner_product +
+                        query_factors.c2 * popcount;
+            }
+        }
+    }
+}
+
+void IndexIVFRaBitQFastScan::search_preassigned(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        const idx_t* assign,
+        const float* centroid_dis,
+        float* distances,
+        idx_t* labels,
+        bool store_pairs,
+        const IVFSearchParameters* params,
+        IndexIVFStats* stats) const {
+    FAISS_THROW_IF_NOT(is_trained);
+    FAISS_THROW_IF_NOT(k > 0);
+    FAISS_THROW_IF_NOT_MSG(
+            !store_pairs, "store_pairs not supported for RaBitQFastScan");
+    FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
+
+    size_t nprobe = this->nprobe;
+    if (params) {
+        FAISS_THROW_IF_NOT(params->max_codes == 0);
+        nprobe = params->nprobe;
+    }
+
+    std::vector<QueryFactorsData> query_factors_storage(n * nprobe);
+    FastScanDistancePostProcessing context;
+    context.query_factors = query_factors_storage.data();
+    context.nprobe = nprobe;
+
+    const CoarseQuantized cq = {nprobe, centroid_dis, assign};
+    search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
+}
+
+void IndexIVFRaBitQFastScan::compute_LUT(
+        size_t n,
+        const float* x,
+        const CoarseQuantized& cq,
+        AlignedTable<float>& dis_tables,
+        AlignedTable<float>& biases,
+        const FastScanDistancePostProcessing& context) const {
+    FAISS_THROW_IF_NOT(is_trained);
+    FAISS_THROW_IF_NOT(by_residual);
+
+    size_t nprobe = cq.nprobe;
+
+    size_t dim12 = 16 * M;
+
+    dis_tables.resize(n * nprobe * dim12);
+    biases.resize(n * nprobe);
+
+    if (n * nprobe > 0) {
+        memset(biases.get(), 0, sizeof(float) * n * nprobe);
+    }
+    std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
+
+#pragma omp parallel for if (n * nprobe > 1000)
+    for (idx_t ij = 0; ij < n * nprobe; ij++) {
+        idx_t i = ij / nprobe;
+        float* xij = &xrel[ij * d];
+        idx_t cij = cq.ids[ij];
+
+        if (cij >= 0) {
+            quantizer->compute_residual(x + i * d, xij, cij);
+
+            // Create QueryFactorsData for this query-list combination
+            QueryFactorsData query_factors_data;
+
+            compute_residual_LUT(
+                    xij,
+                    query_factors_data,
+                    dis_tables.get() + ij * dim12,
+                    x + i * d);
+
+            // Store query factors using compact indexing (ij directly)
+            if (context.query_factors != nullptr) {
+                context.query_factors[ij] = query_factors_data;
+            }
+
+        } else {
+            memset(xij, -1, sizeof(float) * d);
+            memset(dis_tables.get() + ij * dim12, -1, sizeof(float) * dim12);
+        }
+    }
+}
+
+void IndexIVFRaBitQFastScan::reconstruct_from_offset(
+        int64_t list_no,
+        int64_t offset,
+        float* recons) const {
+    // Get centroid for this list
+    std::vector<float> centroid(d);
+    quantizer->reconstruct(list_no, centroid.data());
+
+    // Unpack bit pattern from packed format
+    const size_t bit_pattern_size = (d + 7) / 8;
+    std::vector<uint8_t> fastscan_code(bit_pattern_size, 0);
+
+    InvertedLists::ScopedCodes list_codes(invlists, list_no);
+    for (size_t m = 0; m < M; m++) {
+        uint8_t c =
+                pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
+
+        size_t byte_idx = m / 2;
+        if (m % 2 == 0) {
+            fastscan_code[byte_idx] =
+                    (fastscan_code[byte_idx] & 0xF0) | (c & 0x0F);
+        } else {
+            fastscan_code[byte_idx] =
+                    (fastscan_code[byte_idx] & 0x0F) | ((c & 0x0F) << 4);
+        }
+    }
+
+    // Get dp_multiplier directly from flat_storage
+    InvertedLists::ScopedIds list_ids(invlists, list_no);
+    idx_t global_id = list_ids[offset];
+
+    float dp_multiplier = 1.0f;
+    if (global_id >= 0) {
+        const size_t storage_size = compute_per_vector_storage_size();
+        const size_t storage_capacity = flat_storage.size() / storage_size;
+
+        if (static_cast<size_t>(global_id) < storage_capacity) {
+            const uint8_t* base_ptr =
+                    flat_storage.data() + global_id * storage_size;
+            const auto& base_factors =
+                    *reinterpret_cast<const SignBitFactors*>(base_ptr);
+            dp_multiplier = base_factors.dp_multiplier;
+        }
+    }
+
+    // Decode residual directly using dp_multiplier
+    std::vector<float> residual(d);
+    decode_fastscan_to_residual(
+            fastscan_code.data(), residual.data(), dp_multiplier);
+
+    // Reconstruct: x = centroid + residual
+    for (size_t j = 0; j < d; j++) {
+        recons[j] = centroid[j] + residual[j];
+    }
+}
+
+void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
+        const {
+    FAISS_THROW_IF_NOT(is_trained);
+    FAISS_THROW_IF_NOT(n > 0);
+    FAISS_THROW_IF_NOT(bytes != nullptr);
+    FAISS_THROW_IF_NOT(x != nullptr);
+
+    size_t coarse_size = coarse_code_size();
+    size_t total_code_size = code_size + coarse_size;
+    std::vector<float> centroid(d);
+    std::vector<float> residual(d);
+    const size_t bit_pattern_size = (d + 7) / 8;
+
+#pragma omp parallel for if (n > 1000)
+    for (idx_t i = 0; i < n; i++) {
+        const uint8_t* code_i = bytes + i * total_code_size;
+        float* x_i = x + i * d;
+
+        idx_t list_no = decode_listno(code_i);
+
+        if (list_no >= 0 && list_no < nlist) {
+            quantizer->reconstruct(list_no, centroid.data());
+
+            const uint8_t* fastscan_code = code_i + coarse_size;
+
+            const uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
+            const auto& base_factors =
+                    *reinterpret_cast<const SignBitFactors*>(factors_ptr);
+
+            decode_fastscan_to_residual(
+                    fastscan_code, residual.data(), base_factors.dp_multiplier);
+
+            for (size_t j = 0; j < d; j++) {
+                x_i[j] = centroid[j] + residual[j];
+            }
+        } else {
+            memset(x_i, 0, sizeof(float) * d);
+        }
+    }
+}
+
+void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
+        const uint8_t* fastscan_code,
+        float* residual,
+        float dp_multiplier) const {
+    memset(residual, 0, sizeof(float) * d);
+
+    const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
+
+    for (size_t j = 0; j < d; j++) {
+        bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
+
+        float bit_as_float = bit_value ? 1.0f : 0.0f;
+        residual[j] = (bit_as_float - 0.5f) * dp_multiplier * 2 * inv_d_sqrt;
+    }
+}
+
+// Implementation of virtual make_knn_handler method
+SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
+        bool is_max,
+        int /* impl */,
+        idx_t n,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const IDSelector* /* sel */,
+        const FastScanDistancePostProcessing& context,
+        const float* /* normalizers */) const {
+    const size_t ex_bits = rabitq.nb_bits - 1;
+    const bool is_multibit = ex_bits > 0;
+
+    if (is_max) {
+        return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
+                this, n, k, distances, labels, &context, is_multibit);
+    } else {
+        return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
+                this, n, k, distances, labels, &context, is_multibit);
+    }
+}
+
+/*********************************************************
+ * IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler implementation
+ *********************************************************/
+
+template <class C>
+IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
+        const IndexIVFRaBitQFastScan* idx,
+        size_t nq_val,
+        size_t k_val,
+        float* distances,
+        int64_t* labels,
+        const FastScanDistancePostProcessing* ctx,
+        bool multibit)
+        : simd_result_handlers::ResultHandlerCompare<C, true>(
+                  nq_val,
+                  0,
+                  nullptr),
+          index(idx),
+          heap_distances(distances),
+          heap_labels(labels),
+          nq(nq_val),
+          k(k_val),
+          context(ctx),
+          is_multibit(multibit) {
+    current_list_no = 0;
+    probe_indices.clear();
+
+    // Initialize heaps in constructor (standard pattern from HeapHandler)
+    for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
+        float* heap_dis = heap_distances + q * k;
+        int64_t* heap_ids = heap_labels + q * k;
+        heap_heapify<Cfloat>(k, heap_dis, heap_ids);
+    }
+}
+
+template <class C>
+void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
+        size_t q,
+        size_t b,
+        simd16uint16 d0,
+        simd16uint16 d1) {
+    // Store the original local query index before adjust_with_origin changes it
+    size_t local_q = q;
+    this->adjust_with_origin(q, d0, d1);
+
+    ALIGNED(32) uint16_t d32tab[32];
+    d0.store(d32tab);
+    d1.store(d32tab + 16);
+
+    float* const heap_dis = heap_distances + q * k;
+    int64_t* const heap_ids = heap_labels + q * k;
+
+    FAISS_THROW_IF_NOT_FMT(
+            !probe_indices.empty() && local_q < probe_indices.size(),
+            "set_list_context() must be called before handle() - probe_indices size: %zu, local_q: %zu, global_q: %zu",
+            probe_indices.size(),
+            local_q,
+            q);
+
+    // Access query factors directly from array via ProcessingContext
+    if (!context || !context->query_factors) {
+        FAISS_THROW_MSG(
+                "Query factors not available: FastScanDistancePostProcessing with query_factors required");
+    }
+
+    // Use probe_rank from probe_indices for compact storage indexing
+    size_t probe_rank = probe_indices[local_q];
+    size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
+    size_t storage_idx = q * nprobe + probe_rank;
+
+    const auto& query_factors = context->query_factors[storage_idx];
+
+    const float one_a =
+            this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
+    const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;
+
+    uint64_t idx_base = this->j0 + b * 32;
+    if (idx_base >= this->ntotal) {
+        return;
+    }
+
+    size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
+
+    // Stats tracking for two-stage search
+    // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
+    // n_multibit_evaluations: candidates requiring full multi-bit distance
+    size_t local_1bit_evaluations = 0;
+    size_t local_multibit_evaluations = 0;
+
+    // Process each candidate vector in the SIMD batch
+    for (size_t j = 0; j < max_positions; j++) {
+        const int64_t result_id = this->adjust_id(b, j);
+
+        if (result_id < 0) {
+            continue;
+        }
+
+        const float normalized_distance = d32tab[j] * one_a + bias;
+
+        // Get database factors from flat_storage
+        const size_t storage_size = index->compute_per_vector_storage_size();
+        const uint8_t* base_ptr =
+                index->flat_storage.data() + result_id * storage_size;
+
+        if (is_multibit) {
+            // Track candidates actually considered for two-stage filtering
+            local_1bit_evaluations++;
+
+            // Multi-bit: use SignBitFactorsWithError and two-stage search
+            const SignBitFactorsWithError& full_factors =
+                    *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
+
+            // Compute 1-bit adjusted distance using shared helper
+            float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
+                    normalized_distance,
+                    full_factors,
+                    query_factors,
+                    index->centered,
+                    index->qb,
+                    index->d);
+
+            // Compute lower bound using error bound
+            float lower_bound =
+                    compute_lower_bound(dist_1bit, result_id, local_q, q);
+
+            // Adaptive filtering: decide whether to compute full distance
+            const bool is_similarity =
+                    index->metric_type == MetricType::METRIC_INNER_PRODUCT;
+            bool should_refine = is_similarity
+                    ? (lower_bound > heap_dis[0]) // IP: keep if better
+                    : (lower_bound < heap_dis[0]); // L2: keep if better
+
+            if (should_refine) {
+                local_multibit_evaluations++;
+
+                // Compute local_offset: position within current inverted list
+                size_t local_offset = this->j0 + b * 32 + j;
+
+                // Compute full multi-bit distance
+                float dist_full = compute_full_multibit_distance(
+                        result_id, local_q, q, local_offset);
+
+                // Update heap if this distance is better
+                if (Cfloat::cmp(heap_dis[0], dist_full)) {
+                    heap_replace_top<Cfloat>(
+                            k, heap_dis, heap_ids, dist_full, result_id);
+                }
+            }
+        } else {
+            const auto& db_factors =
+                    *reinterpret_cast<const SignBitFactors*>(base_ptr);
+
+            // Compute adjusted distance using shared helper
+            float adjusted_distance =
+                    rabitq_utils::compute_1bit_adjusted_distance(
+                            normalized_distance,
+                            db_factors,
+                            query_factors,
+                            index->centered,
+                            index->qb,
+                            index->d);
+
+            if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
+                heap_replace_top<Cfloat>(
+                        k, heap_dis, heap_ids, adjusted_distance, result_id);
+            }
+        }
+    }
+
+    // Update global stats atomically
+#pragma omp atomic
+    rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
+#pragma omp atomic
+    rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
+}
+
+template <class C>
+void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
+        size_t list_no,
+        const std::vector<int>& probe_map) {
+    current_list_no = list_no;
+    probe_indices = probe_map;
+}
+
+template <class C>
+void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::begin(
+        const float* norms) {
+    this->normalizers = norms;
+}
+
+template <class C>
+void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
+#pragma omp parallel for
+    for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
+        float* heap_dis = heap_distances + q * k;
+        int64_t* heap_ids = heap_labels + q * k;
+        heap_reorder<Cfloat>(k, heap_dis, heap_ids);
+    }
+}
+
+template <class C>
+float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::compute_lower_bound(
+        float dist_1bit,
+        size_t db_idx,
+        size_t local_q,
+        size_t global_q) const {
+    // Access f_error from SignBitFactorsWithError in flat storage
+    const size_t storage_size = index->compute_per_vector_storage_size();
+    const uint8_t* base_ptr =
+            index->flat_storage.data() + db_idx * storage_size;
+    const SignBitFactorsWithError& db_factors =
+            *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
+    float f_error = db_factors.f_error;
+
+    // Get g_error from query factors
+    // Use local_q to access probe_indices (batch-local), global_q for storage
+    float g_error = 0.0f;
+    if (context && context->query_factors) {
+        size_t probe_rank = probe_indices[local_q];
+        size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
+        size_t storage_idx = global_q * nprobe + probe_rank;
+        g_error = context->query_factors[storage_idx].g_error;
+    }
+
+    // Compute error adjustment: f_error * g_error
+    float error_adjustment = f_error * g_error;
+
+    return dist_1bit - error_adjustment;
+}
+
+template <class C>
+float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
+        compute_full_multibit_distance(
+                size_t db_idx,
+                size_t local_q,
+                size_t global_q,
+                size_t local_offset) const {
+    const size_t ex_bits = index->rabitq.nb_bits - 1;
+    const size_t dim = index->d;
+
+    const size_t storage_size = index->compute_per_vector_storage_size();
+    const uint8_t* base_ptr =
+            index->flat_storage.data() + db_idx * storage_size;
+
+    const size_t ex_code_size = (dim * ex_bits + 7) / 8;
+    const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
+    const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
+            base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
+
+    // Use local_q to access probe_indices (batch-local), global_q for storage
+    size_t probe_rank = probe_indices[local_q];
+    size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
+    size_t storage_idx = global_q * nprobe + probe_rank;
+    const auto& query_factors = context->query_factors[storage_idx];
+
+    size_t list_no = current_list_no;
+    InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
+
+    std::vector<uint8_t> unpacked_code(index->code_size);
+    CodePackerPQ4 packer(index->M2, index->bbs);
+    packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
+    const uint8_t* sign_bits = unpacked_code.data();
+
+    return rabitq_utils::compute_full_multibit_distance(
+            sign_bits,
+            ex_code,
+            ex_fac,
+            query_factors.rotated_q.data(),
+            query_factors.qr_to_c_L2sqr,
+            query_factors.qr_norm_L2sqr,
+            dim,
+            ex_bits,
+            index->metric_type);
+}
+
+} // namespace faiss
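
As a sanity check on the layout above, here is a small worked example of the sizes computed in the constructor and compute_per_vector_storage_size(). It relies only on figures stated in the code's own comments (SignBitFactors is 8 bytes); actual struct sizes are implementation-defined, so treat this as a sketch rather than a guarantee:

    #include <cassert>
    #include <cstddef>

    int main() {
        // 1-bit case (nb_bits = 1, hence ex_bits = 0) at d = 128:
        const std::size_t d = 128;
        const std::size_t bit_pattern_size = (d + 7) / 8; // 16 sign-bit bytes
        const std::size_t factors = 8; // sizeof(SignBitFactors) per comments
        assert(bit_pattern_size + factors == 24); // code_size per vector

        // FastScan view of the same bit pattern: M = (d + 3) / 4 sub-codes
        // of 4 bits each, i.e. 32 nibbles packing the 16 sign-bit bytes.
        const std::size_t M = (d + 3) / 4;
        assert(M * 4 / 8 == bit_pattern_size);
        return 0;
    }

Note also that decode_fastscan_to_residual maps each sign bit to ±1 via the (bit − 0.5) · 2 factor, so every decoded residual component comes out as ±dp_multiplier/√d.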