faiss 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp (new file)
@@ -0,0 +1,650 @@
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #include <faiss/IndexIVFRaBitQFastScan.h>
+
+ #include <algorithm>
+ #include <cstdio>
+
+ #include <faiss/impl/FaissAssert.h>
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
+ #include <faiss/impl/RaBitQUtils.h>
+ #include <faiss/impl/pq4_fast_scan.h>
+ #include <faiss/impl/simd_result_handlers.h>
+ #include <faiss/invlists/BlockInvertedLists.h>
+ #include <faiss/utils/distances.h>
+ #include <faiss/utils/utils.h>
+
+ namespace faiss {
+
+ // Import shared utilities from RaBitQUtils
+ using rabitq_utils::FactorsData;
+ using rabitq_utils::QueryFactorsData;
+
+ inline size_t roundup(size_t a, size_t b) {
+     return (a + b - 1) / b * b;
+ }
+
+ /*********************************************************
+  * IndexIVFRaBitQFastScan implementation
+  *********************************************************/
+
+ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan() = default;
+
+ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
+         Index* quantizer,
+         size_t d,
+         size_t nlist,
+         MetricType metric,
+         int bbs,
+         bool own_invlists)
+         : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
+           rabitq(d, metric) {
+     FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
+     FAISS_THROW_IF_NOT_MSG(
+             metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
+             "RaBitQ only supports L2 and Inner Product metrics");
+     FAISS_THROW_IF_NOT_MSG(bbs % 32 == 0, "Batch size must be multiple of 32");
+     FAISS_THROW_IF_NOT_MSG(quantizer != nullptr, "Quantizer cannot be null");
+
+     by_residual = true;
+     qb = 8; // RaBitQ quantization bits
+     centered = false;
+
+     // FastScan-specific parameters: 4 bits per sub-quantizer
+     const size_t M_fastscan = (d + 3) / 4;
+     constexpr size_t nbits_fastscan = 4;
+
+     this->bbs = bbs;
+     this->fine_quantizer = &rabitq;
+     this->M = M_fastscan;
+     this->nbits = nbits_fastscan;
+     this->ksub = (1 << nbits_fastscan);
+     this->M2 = roundup(M_fastscan, 2);
+
+     // Override code_size to include space for factors after bit patterns
+     const size_t bit_pattern_size = (d + 7) / 8;
+     this->code_size = bit_pattern_size + sizeof(FactorsData);
+
+     is_trained = false;
+
+     if (own_invlists) {
+         replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
+     }
+
+     factors_storage.clear();
+ }
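
For orientation, the sizes set up in this constructor can be traced through on a concrete example. A minimal sketch, assuming d = 128 and a two-float factor struct (the real rabitq_utils::FactorsData layout lives in RaBitQUtils.h and is not part of this diff):

    #include <cassert>
    #include <cstddef>

    // Hypothetical stand-in for rabitq_utils::FactorsData; the diff shows the
    // struct is appended to each code, but not its exact layout.
    struct FactorsDataSketch {
        float or_minus_c_l2sqr; // ||x - centroid||^2, used by the result handler
        float dp_multiplier;    // dot-product rescaling factor
    };

    int main() {
        const size_t d = 128;                        // assumed dimension
        const size_t M = (d + 3) / 4;                // 32 sub-quantizers, 4 dims each
        const size_t M2 = (M + 1) / 2 * 2;           // roundup(M, 2) = 32
        const size_t bit_pattern_size = (d + 7) / 8; // 16 bytes, one sign bit per dim
        const size_t code_size = bit_pattern_size + sizeof(FactorsDataSketch);
        assert(M == 32 && M2 == 32);
        assert(code_size == 24); // 16-byte bit pattern + 8-byte factor struct
        return 0;
    }
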
+
+ // Constructor that converts an existing IndexIVFRaBitQ to FastScan format
+ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
+         const IndexIVFRaBitQ& orig,
+         int /* bbs */)
+         : IndexIVFFastScan(
+                   orig.quantizer,
+                   orig.d,
+                   orig.nlist,
+                   0,
+                   orig.metric_type,
+                   false),
+           rabitq(orig.rabitq) {}
+
+ void IndexIVFRaBitQFastScan::preprocess_code_metadata(
+         idx_t n,
+         const uint8_t* flat_codes,
+         idx_t start_global_idx) {
+     // Extract and store factors from codes for use during search
+     const size_t bit_pattern_size = (d + 7) / 8;
+     factors_storage.resize(start_global_idx + n);
+
+     for (idx_t i = 0; i < n; i++) {
+         const uint8_t* code = flat_codes + i * code_size;
+         const uint8_t* factors_ptr = code + bit_pattern_size;
+         const FactorsData& embedded_factors =
+                 *reinterpret_cast<const FactorsData*>(factors_ptr);
+         factors_storage[start_global_idx + i] = embedded_factors;
+     }
+ }
+
+ size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
+     // Use code_size as stride to skip embedded factor data during packing
+     return code_size;
+ }
+
+ void IndexIVFRaBitQFastScan::train_encoder(
+         idx_t n,
+         const float* x,
+         const idx_t* assign) {
+     FAISS_THROW_IF_NOT(n > 0);
+     FAISS_THROW_IF_NOT(x != nullptr);
+     FAISS_THROW_IF_NOT(assign != nullptr || !by_residual);
+
+     rabitq.train(n, x);
+     is_trained = true;
+     init_code_packer();
+ }
+
+ void IndexIVFRaBitQFastScan::encode_vectors(
+         idx_t n,
+         const float* x,
+         const idx_t* list_nos,
+         uint8_t* codes,
+         bool include_listnos) const {
+     FAISS_THROW_IF_NOT(n > 0);
+     FAISS_THROW_IF_NOT(x != nullptr);
+     FAISS_THROW_IF_NOT(list_nos != nullptr);
+     FAISS_THROW_IF_NOT(codes != nullptr);
+     FAISS_THROW_IF_NOT(is_trained);
+
+     size_t coarse_size = include_listnos ? coarse_code_size() : 0;
+     size_t total_code_size = code_size + coarse_size;
+     memset(codes, 0, total_code_size * n);
+
+     const size_t bit_pattern_size = (d + 7) / 8;
+
+ #pragma omp parallel if (n > 1000)
+     {
+         std::vector<float> centroid(d);
+
+ #pragma omp for
+         for (idx_t i = 0; i < n; i++) {
+             int64_t list_no = list_nos[i];
+
+             if (list_no >= 0) {
+                 const float* xi = x + i * d;
+                 uint8_t* code_out = codes + i * total_code_size;
+                 uint8_t* fastscan_code = code_out + coarse_size;
+
+                 // Reconstruct centroid for residual computation
+                 quantizer->reconstruct(list_no, centroid.data());
+
+                 // Encode vector to FastScan format (bit pattern only)
+                 encode_vector_to_fastscan(xi, centroid.data(), fastscan_code);
+
+                 // Compute and embed factors after the bit pattern
+                 // Pass original vector and centroid (same as old add_with_ids)
+                 FactorsData factors = rabitq_utils::compute_vector_factors(
+                         xi, d, centroid.data(), rabitq.metric_type);
+
+                 uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
+                 *reinterpret_cast<FactorsData*>(factors_ptr) = factors;
+
+                 // Include coarse codes if requested
+                 if (include_listnos) {
+                     encode_listno(list_no, code_out);
+                 }
+             }
+         }
+     }
+ }
+
+ void IndexIVFRaBitQFastScan::encode_vector_to_fastscan(
+         const float* xi,
+         const float* centroid,
+         uint8_t* fastscan_code) const {
+     memset(fastscan_code, 0, code_size);
+
+     for (size_t j = 0; j < d; j++) {
+         const float x_val = xi[j];
+         const float centroid_val = (centroid != nullptr) ? centroid[j] : 0.0f;
+         const float or_minus_c = x_val - centroid_val;
+         const bool xb = (or_minus_c > 0.0f);
+
+         if (xb) {
+             rabitq_utils::set_bit_fastscan(fastscan_code, j);
+         }
+     }
+ }
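
The encoder above keeps only the sign of each residual coordinate. A standalone sketch of that idea, assuming a plain LSB-first bit order (the actual layout is whatever rabitq_utils::set_bit_fastscan defines in RaBitQUtils.h, which this diff does not show):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Pack sign bits: bit j is set iff x[j] > centroid[j].
    std::vector<uint8_t> encode_signs(const float* x, const float* c, size_t d) {
        std::vector<uint8_t> bits((d + 7) / 8, 0);
        for (size_t j = 0; j < d; j++) {
            const float r = x[j] - (c != nullptr ? c[j] : 0.0f);
            if (r > 0.0f) {
                bits[j / 8] |= uint8_t(1u << (j % 8)); // assumed LSB-first order
            }
        }
        return bits;
    }
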
+
+ bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
+     return true;
+ }
+
+ // Computes lookup table for residual vectors in RaBitQ FastScan format
+ void IndexIVFRaBitQFastScan::compute_residual_LUT(
+         const float* residual,
+         QueryFactorsData& query_factors,
+         float* lut_out,
+         const float* original_query) const {
+     FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
+
+     std::vector<float> rotated_q(d);
+     std::vector<uint8_t> rotated_qq(d);
+
+     // Use RaBitQUtils to compute query factors - eliminates code duplication
+     query_factors = rabitq_utils::compute_query_factors(
+             residual,
+             d,
+             nullptr,
+             qb,
+             centered,
+             metric_type,
+             rotated_q,
+             rotated_qq);
+
+     // Override query norm for inner product if original query is provided
+     if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
+         original_query != nullptr) {
+         query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
+     }
+
+     if (centered) {
+         const float max_code_value = (1 << qb) - 1;
+
+         for (size_t m = 0; m < M; m++) {
+             const size_t dim_start = m * 4;
+
+             for (int code_val = 0; code_val < 16; code_val++) {
+                 float xor_contribution = 0.0f;
+
+                 for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                     const size_t dim_idx = dim_start + dim_offset;
+
+                     if (dim_idx < d) {
+                         const bool db_bit = (code_val >> dim_offset) & 1;
+                         const float query_value = rotated_qq[dim_idx];
+
+                         xor_contribution += db_bit
+                                 ? (max_code_value - query_value)
+                                 : query_value;
+                     }
+                 }
+
+                 lut_out[m * 16 + code_val] = xor_contribution;
+             }
+         }
+     } else {
+         for (size_t m = 0; m < M; m++) {
+             const size_t dim_start = m * 4;
+
+             for (int code_val = 0; code_val < 16; code_val++) {
+                 float inner_product = 0.0f;
+                 int popcount = 0;
+
+                 for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                     const size_t dim_idx = dim_start + dim_offset;
+
+                     if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
+                         inner_product += rotated_qq[dim_idx];
+                         popcount++;
+                     }
+                 }
+                 lut_out[m * 16 + code_val] = query_factors.c1 * inner_product +
+                         query_factors.c2 * popcount;
+             }
+         }
+     }
+ }
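
The table built above holds 16 entries per group of four dimensions, and a scan consumes one 4-bit nibble of a database code per group. A scalar sketch of that accumulation (the SIMD kernel in pq4_fast_scan does this for 32 codes at a time; the function name here is illustrative):

    #include <cstddef>
    #include <cstdint>

    // Sum LUT entries for one database vector; nibble m selects one of the
    // 16 precomputed per-group contributions.
    float accumulate_lut(const float* lut, const uint8_t* nibbles, size_t M) {
        float acc = 0.0f;
        for (size_t m = 0; m < M; m++) {
            acc += lut[m * 16 + (nibbles[m] & 0x0F)];
        }
        return acc;
    }
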
+
+ void IndexIVFRaBitQFastScan::search_preassigned(
+         idx_t n,
+         const float* x,
+         idx_t k,
+         const idx_t* assign,
+         const float* centroid_dis,
+         float* distances,
+         idx_t* labels,
+         bool store_pairs,
+         const IVFSearchParameters* params,
+         IndexIVFStats* stats) const {
+     FAISS_THROW_IF_NOT(is_trained);
+     FAISS_THROW_IF_NOT(k > 0);
+     FAISS_THROW_IF_NOT_MSG(
+             !store_pairs, "store_pairs not supported for RaBitQFastScan");
+     FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
+
+     size_t nprobe = this->nprobe;
+     if (params) {
+         FAISS_THROW_IF_NOT(params->max_codes == 0);
+         nprobe = params->nprobe;
+     }
+
+     std::vector<QueryFactorsData> query_factors_storage(n * nprobe);
+     FastScanDistancePostProcessing context;
+     context.query_factors = query_factors_storage.data();
+     context.nprobe = nprobe;
+
+     const CoarseQuantized cq = {nprobe, centroid_dis, assign};
+     search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
+ }
+
+ void IndexIVFRaBitQFastScan::compute_LUT(
+         size_t n,
+         const float* x,
+         const CoarseQuantized& cq,
+         AlignedTable<float>& dis_tables,
+         AlignedTable<float>& biases,
+         const FastScanDistancePostProcessing& context) const {
+     FAISS_THROW_IF_NOT(is_trained);
+     FAISS_THROW_IF_NOT(by_residual);
+
+     size_t nprobe = cq.nprobe;
+
+     size_t dim12 = 16 * M;
+
+     dis_tables.resize(n * nprobe * dim12);
+     biases.resize(n * nprobe);
+
+     if (n * nprobe > 0) {
+         memset(biases.get(), 0, sizeof(float) * n * nprobe);
+     }
+     std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
+
+ #pragma omp parallel for if (n * nprobe > 1000)
+     for (idx_t ij = 0; ij < n * nprobe; ij++) {
+         idx_t i = ij / nprobe;
+         float* xij = &xrel[ij * d];
+         idx_t cij = cq.ids[ij];
+
+         if (cij >= 0) {
+             quantizer->compute_residual(x + i * d, xij, cij);
+
+             // Create QueryFactorsData for this query-list combination
+             QueryFactorsData query_factors_data;
+
+             compute_residual_LUT(
+                     xij,
+                     query_factors_data,
+                     dis_tables.get() + ij * dim12,
+                     x + i * d);
+
+             // Store query factors using compact indexing (ij directly)
+             if (context.query_factors) {
+                 context.query_factors[ij] = query_factors_data;
+             }
+
+         } else {
+             memset(xij, -1, sizeof(float) * d);
+             memset(dis_tables.get() + ij * dim12, -1, sizeof(float) * dim12);
+         }
+     }
+ }
+
+ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
+         int64_t list_no,
+         int64_t offset,
+         float* recons) const {
+     // Unpack codes from packed format
+     size_t coarse_size = coarse_code_size();
+     const size_t bit_pattern_size = (d + 7) / 8;
+     std::vector<uint8_t> code(
+             coarse_size + bit_pattern_size + sizeof(FactorsData), 0);
+
+     encode_listno(list_no, code.data());
+     InvertedLists::ScopedCodes list_codes(invlists, list_no);
+
+     // Unpack the bit pattern from packed format to FastScan layout
+     uint8_t* fastscan_code = code.data() + coarse_size;
+     for (size_t m = 0; m < M; m++) {
+         uint8_t c =
+                 pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
+
+         // Write the 4-bit code value to FastScan format
+         // Each byte stores two 4-bit codes (lower and upper nibbles)
+         size_t byte_idx = m / 2;
+         if (m % 2 == 0) {
+             // Even m: write to lower 4 bits
+             fastscan_code[byte_idx] =
+                     (fastscan_code[byte_idx] & 0xF0) | (c & 0x0F);
+         } else {
+             // Odd m: write to upper 4 bits
+             fastscan_code[byte_idx] =
+                     (fastscan_code[byte_idx] & 0x0F) | ((c & 0x0F) << 4);
+         }
+     }
+
+     // Get the global index to retrieve factors
+     // Need to look up the ID from inverted lists
+     InvertedLists::ScopedIds list_ids(invlists, list_no);
+     idx_t global_id = list_ids[offset];
+
+     // Get factors from factors_storage using global ID
+     if (global_id >= 0 &&
+         static_cast<size_t>(global_id) < factors_storage.size()) {
+         const FactorsData& factors = factors_storage[global_id];
+
+         // Embed factors into the unpacked code
+         uint8_t* factors_ptr = code.data() + coarse_size + bit_pattern_size;
+         *reinterpret_cast<FactorsData*>(factors_ptr) = factors;
+     }
+
+     // Now use sa_decode which expects unpacked codes with embedded factors
+     sa_decode(1, code.data(), recons);
+ }
+
+ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
+         const {
+     FAISS_THROW_IF_NOT(is_trained);
+     FAISS_THROW_IF_NOT(n > 0);
+     FAISS_THROW_IF_NOT(bytes != nullptr);
+     FAISS_THROW_IF_NOT(x != nullptr);
+
+     size_t coarse_size = coarse_code_size();
+     size_t total_code_size = code_size + coarse_size;
+     std::vector<float> centroid(d);
+     std::vector<float> residual(d);
+
+ #pragma omp parallel for if (n > 1000)
+     for (idx_t i = 0; i < n; i++) {
+         const uint8_t* code_i = bytes + i * total_code_size;
+         float* x_i = x + i * d;
+
+         idx_t list_no = decode_listno(code_i);
+
+         if (list_no >= 0 && list_no < nlist) {
+             quantizer->reconstruct(list_no, centroid.data());
+
+             const uint8_t* fastscan_code = code_i + coarse_size;
+
+             decode_fastscan_to_residual(fastscan_code, residual.data());
+
+             for (size_t j = 0; j < d; j++) {
+                 x_i[j] = centroid[j] + residual[j];
+             }
+         } else {
+             memset(x_i, 0, sizeof(float) * d);
+         }
+     }
+ }
+
+ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
+         const uint8_t* fastscan_code,
+         float* residual) const {
+     memset(residual, 0, sizeof(float) * d);
+
+     const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
+     const size_t bit_pattern_size = (d + 7) / 8;
+
+     // Extract factors directly from embedded codes
+     const uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
+     const FactorsData& fac = *reinterpret_cast<const FactorsData*>(factors_ptr);
+
+     for (size_t j = 0; j < d; j++) {
+         // Use RaBitQUtils for consistent bit extraction
+         bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
+
+         float bit_as_float = bit_value ? 1.0f : 0.0f;
+         residual[j] =
+                 (bit_as_float - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt;
+     }
+ }
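
Written out, the decoder above maps each stored sign bit b_j ∈ {0, 1} straight back to a signed residual coordinate, so every reconstructed dimension has the same magnitude:

    residual[j] = (b_j - 0.5) * 2 * dp_multiplier / sqrt(d) = ±dp_multiplier / sqrt(d)

This is the lossy part of RaBitQ: only the sign pattern and a handful of per-vector scalars survive encoding.
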
+
+ // Implementation of virtual make_knn_handler method
+ SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
+         bool is_max,
+         int /* impl */,
+         idx_t n,
+         idx_t k,
+         float* distances,
+         idx_t* labels,
+         const IDSelector* /* sel */,
+         const FastScanDistancePostProcessing& context,
+         const float* /* normalizers */) const {
+     if (is_max) {
+         return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
+                 this, n, k, distances, labels, &context);
+     } else {
+         return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
+                 this, n, k, distances, labels, &context);
+     }
+ }
+
+ /*********************************************************
+  * IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler implementation
+  *********************************************************/
+
+ template <class C>
+ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
+         const IndexIVFRaBitQFastScan* idx,
+         size_t nq_val,
+         size_t k_val,
+         float* distances,
+         int64_t* labels,
+         const FastScanDistancePostProcessing* ctx)
+         : simd_result_handlers::ResultHandlerCompare<C, true>(
+                   nq_val,
+                   0,
+                   nullptr),
+           index(idx),
+           heap_distances(distances),
+           heap_labels(labels),
+           nq(nq_val),
+           k(k_val),
+           context(ctx) {
+     current_list_no = 0;
+     probe_indices.clear();
+
+     // Initialize heaps in constructor (standard pattern from HeapHandler)
+     for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
+         float* heap_dis = heap_distances + q * k;
+         int64_t* heap_ids = heap_labels + q * k;
+         heap_heapify<Cfloat>(k, heap_dis, heap_ids);
+     }
+ }
+
+ template <class C>
+ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
+         size_t q,
+         size_t b,
+         simd16uint16 d0,
+         simd16uint16 d1) {
+     // Store the original local query index before adjust_with_origin changes it
+     size_t local_q = q;
+     this->adjust_with_origin(q, d0, d1);
+
+     ALIGNED(32) uint16_t d32tab[32];
+     d0.store(d32tab);
+     d1.store(d32tab + 16);
+
+     float* const heap_dis = heap_distances + q * k;
+     int64_t* const heap_ids = heap_labels + q * k;
+
+     FAISS_THROW_IF_NOT_FMT(
+             !probe_indices.empty() && local_q < probe_indices.size(),
+             "set_list_context() must be called before handle() - probe_indices size: %zu, local_q: %zu, global_q: %zu",
+             probe_indices.size(),
+             local_q,
+             q);
+
+     // Access query factors directly from array via ProcessingContext
+     if (!context || !context->query_factors) {
+         FAISS_THROW_MSG(
+                 "Query factors not available: FastScanDistancePostProcessing with query_factors required");
+     }
+
+     // Use probe_rank from probe_indices for compact storage indexing
+     size_t probe_rank = probe_indices[local_q];
+     size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
+     size_t storage_idx = q * nprobe + probe_rank;
+
+     const auto& query_factors = context->query_factors[storage_idx];
+
+     const float one_a =
+             this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
+     const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;
+
+     uint64_t idx_base = this->j0 + b * 32;
+     if (idx_base >= this->ntotal) {
+         return;
+     }
+
+     size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
+     // Process each candidate vector in the SIMD batch
+     for (int j = 0; j < static_cast<int>(max_positions); j++) {
+         const int64_t result_id = this->adjust_id(b, j);
+
+         if (result_id < 0) {
+             continue;
+         }
+
+         const float normalized_distance = d32tab[j] * one_a + bias;
+
+         // Get database factors using global index (factors are stored by global
+         // index)
+         const auto& db_factors = index->factors_storage[result_id];
+         float adjusted_distance;
+
+         // Distance computation depends on quantization mode
+         if (index->centered) {
+             int64_t int_dot = ((1 << index->qb) - 1) * index->d;
+             int_dot -= 2 * static_cast<int64_t>(normalized_distance);
+
+             adjusted_distance = query_factors.qr_to_c_L2sqr +
+                     db_factors.or_minus_c_l2sqr -
+                     2 * db_factors.dp_multiplier * int_dot *
+                             query_factors.int_dot_scale;
+
+         } else {
+             float final_dot = normalized_distance - query_factors.c34;
+             adjusted_distance = db_factors.or_minus_c_l2sqr +
+                     query_factors.qr_to_c_L2sqr -
+                     2 * db_factors.dp_multiplier * final_dot;
+         }
+
+         // Convert L2 to inner product if needed
+         if (query_factors.qr_norm_L2sqr != 0.0f) {
+             adjusted_distance =
+                     -0.5f * (adjusted_distance - query_factors.qr_norm_L2sqr);
+         }
+
+         if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
+             heap_replace_top<Cfloat>(
+                     k, heap_dis, heap_ids, adjusted_distance, result_id);
+         }
+     }
+ }
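
In the non-centered branch above, the per-candidate correction restates the estimator used throughout this file. Reading the stored field names literally (e.g. or_minus_c_l2sqr as the database vector's squared distance to its centroid):

    est_dot  = normalized_distance - c34
    dist_L2  = or_minus_c_l2sqr + qr_to_c_L2sqr - 2 * dp_multiplier * est_dot
    dist_IP  = -0.5 * (dist_L2 - qr_norm_L2sqr)   // only when qr_norm_L2sqr != 0

so the SIMD kernel only ever accumulates the cheap 4-bit table sums, and the final distances are assembled per surviving candidate from the per-vector and per-query factors.
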
+
+ template <class C>
+ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
+         size_t list_no,
+         const std::vector<int>& probe_map) {
+     current_list_no = list_no;
+     probe_indices = probe_map;
+ }
+
+ template <class C>
+ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::begin(
+         const float* norms) {
+     this->normalizers = norms;
+ }
+
+ template <class C>
+ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
+ #pragma omp parallel for
+     for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
+         float* heap_dis = heap_distances + q * k;
+         int64_t* heap_ids = heap_labels + q * k;
+         heap_reorder<Cfloat>(k, heap_dis, heap_ids);
+     }
+ }
+
+ // Explicit template instantiations
+ template struct IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<
+         CMin<uint16_t, int64_t>>;
+ template struct IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<
+         CMax<uint16_t, int64_t>>;
+
+ } // namespace faiss
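
Finally, a minimal usage sketch for the new index, put together from the constructor and search path defined above; the parameters and data here are placeholders rather than recommendations:

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFRaBitQFastScan.h>
    #include <vector>

    int main() {
        const size_t d = 128, nlist = 64;  // toy parameters
        const faiss::idx_t nb = 10000;
        faiss::IndexFlatL2 coarse(d);      // coarse quantizer over IVF centroids
        faiss::IndexIVFRaBitQFastScan index(
                &coarse, d, nlist, faiss::METRIC_L2, /*bbs=*/32,
                /*own_invlists=*/true);

        std::vector<float> xb(nb * d);     // placeholder vectors; fill with real data
        index.train(nb, xb.data());        // trains coarse quantizer + RaBitQ encoder
        index.add(nb, xb.data());

        const faiss::idx_t k = 10;
        std::vector<float> distances(k);
        std::vector<faiss::idx_t> labels(k);
        index.nprobe = 8;                  // probes per query, per the search path above
        index.search(1, xb.data(), k, distances.data(), labels.data());
        return 0;
    }
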