faiss 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
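
The largest additions in this release are the new RaBitQ FastScan indexes (IndexRaBitQFastScan, IndexIVFRaBitQFastScan), the shared RaBitQUtils and FastScanDistancePostProcessing helpers, and IndexIVFFlatPanorama, plus the factory and serialization changes that wire them in. The hunk reproduced below is the new IndexRaBitQFastScan.cpp. As a minimal, illustrative sketch of how the class is driven (assuming only the constructor, train, add, and search signatures visible in the hunk; the data and parameters here are toy setup, not part of the diff):

    #include <faiss/IndexRaBitQFastScan.h>
    #include <vector>

    int main() {
        const faiss::idx_t d = 64;
        const faiss::idx_t nb = 1000;
        std::vector<float> xb(d * nb, 0.5f); // toy vectors

        // 1-bit RaBitQ codes laid out for the 4-bit FastScan SIMD kernels
        faiss::IndexRaBitQFastScan index(d, faiss::METRIC_L2, /*bbs=*/32);
        index.train(nb, xb.data()); // averages vectors into `center`, trains rabitq
        index.add(nb, xb.data());   // packs sign bits + per-vector factors

        const faiss::idx_t k = 5;
        std::vector<float> distances(k);
        std::vector<faiss::idx_t> labels(k);
        index.search(1, xb.data(), k, distances.data(), labels.data());
        return 0;
    }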
data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp (new file)
@@ -0,0 +1,586 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/IndexRaBitQFastScan.h>
#include <faiss/impl/FastScanDistancePostProcessing.h>
#include <faiss/impl/RaBitQUtils.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/utils/utils.h>
#include <algorithm>
#include <cmath>

namespace faiss {

static inline size_t roundup(size_t a, size_t b) {
    return (a + b - 1) / b * b;
}

IndexRaBitQFastScan::IndexRaBitQFastScan() = default;

IndexRaBitQFastScan::IndexRaBitQFastScan(idx_t d, MetricType metric, int bbs)
        : rabitq(d, metric) {
    // RaBitQ-specific validation
    FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
    FAISS_THROW_IF_NOT_MSG(
            metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
            "RaBitQ FastScan only supports L2 and Inner Product metrics");

    // RaBitQ uses 1 bit per dimension packed into 4-bit FastScan
    // sub-quantizers; each FastScan sub-quantizer handles 4 RaBitQ dimensions
    const size_t M_fastscan = (d + 3) / 4;
    constexpr size_t nbits_fastscan = 4;

    // init_fastscan will validate bbs % 32 == 0 and nbits_fastscan == 4
    init_fastscan(static_cast<int>(d), M_fastscan, nbits_fastscan, metric, bbs);

    // Override code_size to include space for factors after bit patterns.
    // RaBitQ stores 1 bit per dimension, requiring (d + 7) / 8 bytes.
    const size_t bit_pattern_size = (d + 7) / 8;
    code_size = bit_pattern_size + sizeof(FactorsData);

    // Set RaBitQ-specific parameters
    qb = 8;
    center.resize(d, 0.0f);

    // Pre-allocate storage vectors for efficiency
    factors_storage.clear();
}

IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
        : rabitq(orig.rabitq) {
    // RaBitQ-specific validation
    FAISS_THROW_IF_NOT_MSG(orig.d > 0, "Dimension must be positive");
    FAISS_THROW_IF_NOT_MSG(
            orig.metric_type == METRIC_L2 ||
                    orig.metric_type == METRIC_INNER_PRODUCT,
            "RaBitQ FastScan only supports L2 and Inner Product metrics");

    // RaBitQ uses 1 bit per dimension packed into 4-bit FastScan
    // sub-quantizers; each FastScan sub-quantizer handles 4 RaBitQ dimensions
    const size_t M_fastscan = (orig.d + 3) / 4;
    constexpr size_t nbits_fastscan = 4;

    // Initialize FastScan base with the original index's parameters
    init_fastscan(
            static_cast<int>(orig.d),
            M_fastscan,
            nbits_fastscan,
            orig.metric_type,
            bbs);

    // Override code_size to include space for factors after bit patterns.
    // RaBitQ stores 1 bit per dimension, requiring (d + 7) / 8 bytes.
    const size_t bit_pattern_size = (orig.d + 7) / 8;
    code_size = bit_pattern_size + sizeof(FactorsData);

    // Copy properties from original index
    ntotal = orig.ntotal;
    ntotal2 = roundup(ntotal, bbs);
    is_trained = orig.is_trained;
    orig_codes = orig.codes.data();
    qb = orig.qb;
    centered = orig.centered;
    center = orig.center;

    // If the original index has data, extract factors and pack codes
    if (ntotal > 0) {
        // Allocate space for factors
        factors_storage.resize(ntotal);

        // Extract factors from original codes for each vector
        const float* centroid_data = center.data();

        // Use the original RaBitQ quantizer to decode and compute factors
        std::vector<float> decoded_vectors(ntotal * orig.d);
        orig.sa_decode(ntotal, orig.codes.data(), decoded_vectors.data());

        for (idx_t i = 0; i < ntotal; i++) {
            FactorsData& fac = factors_storage[i];
            const float* x_row = decoded_vectors.data() + i * orig.d;

            // Use shared utilities for computing factors
            fac = rabitq_utils::compute_vector_factors(
                    x_row, orig.d, centroid_data, orig.metric_type);
        }

        // Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format.
        // This follows the same pattern as the IndexPQFastScan constructor.
        AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
        memset(fastscan_codes.get(), 0, ntotal * code_size);

        // Convert from RaBitQ 1-bit-per-dimension to FastScan
        // 4-bit-per-sub-quantizer
        for (idx_t i = 0; i < ntotal; i++) {
            const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
            uint8_t* fs_code = fastscan_codes.get() + i * code_size;

            // Convert each dimension's bit (same logic as compute_codes)
            for (size_t j = 0; j < orig.d; j++) {
                // Extract bit from original RaBitQ format
                const size_t orig_byte_idx = j / 8;
                const size_t orig_bit_offset = j % 8;
                const bool bit_value =
                        (orig_code[orig_byte_idx] >> orig_bit_offset) & 1;

                // Use RaBitQUtils for consistent bit setting
                if (bit_value) {
                    rabitq_utils::set_bit_fastscan(fs_code, j);
                }
            }
        }

        // Pack the converted codes using pq4_pack_codes with custom stride
        codes.resize(ntotal2 * M2 / 2);
        pq4_pack_codes(
                fastscan_codes.get(),
                ntotal,
                M,
                ntotal2,
                bbs,
                M2,
                codes.get(),
                code_size);
    }
}

void IndexRaBitQFastScan::train(idx_t n, const float* x) {
    // compute a centroid
    std::vector<float> centroid(d, 0);
    for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
        for (size_t j = 0; j < d; j++) {
            centroid[j] += x[i * d + j];
        }
    }

    if (n != 0) {
        for (size_t j = 0; j < d; j++) {
            centroid[j] /= (float)n;
        }
    }

    center = std::move(centroid);

    rabitq.train(n, x);
    is_trained = true;
}

void IndexRaBitQFastScan::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT(is_trained);

    // Handle blocking to avoid excessive allocations
    constexpr idx_t bs = 65536;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min(n, i0 + bs);
            if (verbose) {
                printf("IndexRaBitQFastScan::add %zd/%zd\n",
                       size_t(i1),
                       size_t(n));
            }
            add(i1 - i0, x + i0 * d);
        }
        return;
    }
    InterruptCallback::check();

    // Create codes with embedded factors using our compute_codes
    AlignedTable<uint8_t> tmp_codes(n * code_size);
    compute_codes(tmp_codes.get(), n, x);

    // Extract and store factors from embedded codes for handler access
    const size_t bit_pattern_size = (d + 7) / 8;
    factors_storage.resize(ntotal + n);
    for (idx_t i = 0; i < n; i++) {
        const uint8_t* code = tmp_codes.get() + i * code_size;
        const uint8_t* factors_ptr = code + bit_pattern_size;
        const FactorsData& embedded_factors =
                *reinterpret_cast<const FactorsData*>(factors_ptr);
        factors_storage[ntotal + i] = embedded_factors;
    }

    // Resize main storage (same logic as parent)
    ntotal2 = roundup(ntotal + n, bbs);
    size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4
    size_t old_size = codes.size();
    if (new_size > old_size) {
        codes.resize(new_size);
        memset(codes.get() + old_size, 0, new_size - old_size);
    }

    // Use our custom packing function with the correct stride
    pq4_pack_codes_range(
            tmp_codes.get(),
            M, // number of sub-quantizers (bit patterns only)
            ntotal,
            ntotal + n, // range to pack
            bbs,
            M2, // block parameters
            codes.get(), // output
            code_size); // CUSTOM STRIDE: includes factor space

    ntotal += n;
}

void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
        const {
    FAISS_ASSERT(codes != nullptr);
    FAISS_ASSERT(x != nullptr);
    FAISS_ASSERT(
            (metric_type == MetricType::METRIC_L2 ||
             metric_type == MetricType::METRIC_INNER_PRODUCT));
    if (n == 0) {
        return;
    }

    // Hoist loop-invariant computations
    const float* centroid_data = center.data();
    const size_t bit_pattern_size = (d + 7) / 8;

    memset(codes, 0, n * code_size);

#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        uint8_t* const code = codes + i * code_size;
        const float* const x_row = x + i * d;

        // Pack bits directly into FastScan format
        for (size_t j = 0; j < d; j++) {
            const float x_val = x_row[j];
            const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
            const float or_minus_c = x_val - centroid_val;
            const bool xb = (or_minus_c > 0.0f);

            if (xb) {
                rabitq_utils::set_bit_fastscan(code, j);
            }
        }

        // Calculate and append factors after the bit data
        FactorsData factors = rabitq_utils::compute_vector_factors(
                x_row, d, centroid_data, metric_type);

        // Append factors at the end of the code
        uint8_t* factors_ptr = code + bit_pattern_size;
        *reinterpret_cast<FactorsData*>(factors_ptr) = factors;
    }
}
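
// Resulting per-vector code layout (a sketch inferred from the code above;
// the exact bit order inside the pattern is whatever
// rabitq_utils::set_bit_fastscan produces):
//
//   bytes [0, (d + 7) / 8)          sign bits, one per dimension:
//                                   bit j set iff x[j] - center[j] > 0
//   bytes [(d + 7) / 8, code_size)  one FactorsData with the per-vector
//                                   correction terms read back in add()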

void IndexRaBitQFastScan::compute_float_LUT(
        float* lut,
        idx_t n,
        const float* x,
        const FastScanDistancePostProcessing& context) const {
    FAISS_THROW_IF_NOT(is_trained);

    // Pre-allocate working buffers to avoid repeated allocations
    std::vector<float> rotated_q(d);
    std::vector<uint8_t> rotated_qq(d);

    // Compute lookup tables for FastScan SIMD operations.
    // For each query vector, computes distance contributions for all
    // possible 4-bit codes per sub-quantizer. Also computes and stores
    // query factors for distance reconstruction.
    for (idx_t i = 0; i < n; i++) {
        const float* query = x + i * d;

        // Compute query factors and store them in the array if available
        rabitq_utils::QueryFactorsData query_factors_data =
                rabitq_utils::compute_query_factors(
                        query,
                        d,
                        center.data(),
                        qb,
                        centered,
                        metric_type,
                        rotated_q,
                        rotated_qq);

        // Store query factors in the context array if provided
        if (context.query_factors) {
            context.query_factors[i] = query_factors_data;
        }

        // Create lookup table storing distance contributions for all possible
        // 4-bit codes per sub-quantizer for FastScan SIMD operations
        float* query_lut = lut + i * M * 16;

        if (centered) {
            // For centered mode, we use the signed odd integer quantization
            // scheme. Formula:
            //   int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
            // We precompute the XOR contribution for each sub-quantizer.

            const float max_code_value = (1 << qb) - 1;

            for (size_t m = 0; m < M; m++) {
                const size_t dim_start = m * 4;

                for (int code_val = 0; code_val < 16; code_val++) {
                    float xor_contribution = 0.0f;

                    // Process 4 bits per sub-quantizer
                    for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
                        const size_t dim_idx = dim_start + dim_offset;

                        if (dim_idx < d) {
                            const bool db_bit = (code_val >> dim_offset) & 1;
                            const float query_value = rotated_qq[dim_idx];

                            // XOR contribution:
                            //   if db_bit == 0: XOR result = query_value
                            //   if db_bit == 1: XOR result =
                            //       (2^qb - 1) - query_value
                            xor_contribution += db_bit
                                    ? (max_code_value - query_value)
                                    : query_value;
                        }
                    }

                    // Store the XOR contribution (will be scaled by -2 *
                    // int_dot_scale during distance computation)
                    query_lut[m * 16 + code_val] = xor_contribution;
                }
            }

        } else {
            // For non-centered quantization, use the traditional AND dot
            // product. Compute lookup table entries by processing popcount
            // and inner product together.
            for (size_t m = 0; m < M; m++) {
                const size_t dim_start = m * 4;

                for (int code_val = 0; code_val < 16; code_val++) {
                    float inner_product = 0.0f;
                    int popcount = 0;

                    // Process 4 bits per sub-quantizer
                    for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
                        const size_t dim_idx = dim_start + dim_offset;

                        if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
                            inner_product += rotated_qq[dim_idx];
                            popcount++;
                        }
                    }

                    // Store pre-computed distance contribution
                    query_lut[m * 16 + code_val] =
                            query_factors_data.c1 * inner_product +
                            query_factors_data.c2 * popcount;
                }
            }
        }
    }
}
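
// Worked check of the centered-mode identity above (a sketch; q_j is the
// quantized query rotated_qq[j], b_j the stored bit, Q = (1 << qb) - 1):
//   xor_dot = sum_j (b_j ? Q - q_j : q_j)
//   Q * d - 2 * xor_dot = sum_j (2 * b_j - 1) * (2 * q_j - Q)
// i.e. int_dot is the dot product of the signed bits (+/-1) with the signed
// odd integers (2 * q_j - Q), which the handler later rescales by
// int_dot_scale.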

void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
        const {
    const float* centroid_in =
            (center.data() == nullptr) ? nullptr : center.data();
    const uint8_t* codes = bytes;
    FAISS_ASSERT(codes != nullptr);
    FAISS_ASSERT(x != nullptr);

    const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
    const size_t bit_pattern_size = (d + 7) / 8;

#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        // Access code using the correct FastScan format
        const uint8_t* code = codes + i * code_size;

        // Extract factors directly from embedded codes
        const uint8_t* factors_ptr = code + bit_pattern_size;
        const FactorsData& fac =
                *reinterpret_cast<const FactorsData*>(factors_ptr);

        for (size_t j = 0; j < d; j++) {
            // Use RaBitQUtils for consistent bit extraction
            bool bit_value = rabitq_utils::extract_bit_fastscan(code, j);
            float bit = bit_value ? 1.0f : 0.0f;

            // Compute the output using the RaBitQ reconstruction formula
            x[i * d + j] = (bit - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt +
                    ((centroid_in == nullptr) ? 0 : centroid_in[j]);
        }
    }
}
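
// Closed form of the reconstruction above (a sketch): with b in {0, 1},
//   (b - 0.5) * dp_multiplier * 2 / sqrt(d) = (2b - 1) * dp_multiplier / sqrt(d)
// so each decoded coordinate is center[j] plus a +/- dp_multiplier / sqrt(d)
// step whose sign comes from the stored bit.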

void IndexRaBitQFastScan::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");

    // Per-query factors, owned by this call and filled in during LUT
    // computation
    std::vector<rabitq_utils::QueryFactorsData> query_factors_storage(n);

    // Use the faster search_dispatch_implem flow from IndexFastScan,
    // passing the query factors array through the post-processing context
    FastScanDistancePostProcessing context;
    context.query_factors = query_factors_storage.data();
    if (metric_type == METRIC_L2) {
        search_dispatch_implem<true>(n, x, k, distances, labels, context);
    } else {
        search_dispatch_implem<false>(n, x, k, distances, labels, context);
    }
}

// Template implementations for RaBitQHeapHandler
template <class C, bool with_id_map>
RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
        const IndexRaBitQFastScan* index,
        size_t nq_val,
        size_t k_val,
        float* distances,
        int64_t* labels,
        const IDSelector* sel_in,
        const FastScanDistancePostProcessing& ctx)
        : RHC(nq_val, index->ntotal, sel_in),
          rabitq_index(index),
          heap_distances(distances),
          heap_labels(labels),
          nq(nq_val),
          k(k_val),
          context(ctx) {
    // Initialize heaps for all queries in the constructor;
    // this allows us to support direct normalizer assignment
#pragma omp parallel for if (nq > 100)
    for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
        float* heap_dis = heap_distances + q * k;
        int64_t* heap_ids = heap_labels + q * k;
        heap_heapify<Cfloat>(k, heap_dis, heap_ids);
    }
}

template <class C, bool with_id_map>
void RaBitQHeapHandler<C, with_id_map>::handle(
        size_t q,
        size_t b,
        simd16uint16 d0,
        simd16uint16 d1) {
    ALIGNED(32) uint16_t d32tab[32];
    d0.store(d32tab);
    d1.store(d32tab + 16);

    // Get heap pointers and query factors (computed once per batch)
    float* const heap_dis = heap_distances + q * k;
    int64_t* const heap_ids = heap_labels + q * k;

    // Access query factors from the query_factors pointer
    rabitq_utils::QueryFactorsData query_factors_data = {};
    if (context.query_factors) {
        query_factors_data = context.query_factors[q];
    }

    // Compute normalizers once per batch
    const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
    const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;

    // Compute loop bounds to avoid redundant bounds checking
    const size_t base_db_idx = this->j0 + b * 32;
    const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
            ? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
            : 0;

    // Process distances in batch
    for (size_t i = 0; i < max_vectors; i++) {
        const size_t db_idx = base_db_idx + i;

        // Normalize distance from LUT lookup
        const float normalized_distance = d32tab[i] * one_a + bias;

        // Access factors from storage (populated from embedded codes during
        // add())
        const auto& db_factors = rabitq_index->factors_storage[db_idx];

        float adjusted_distance;

        if (rabitq_index->centered) {
            // For centered mode: normalized_distance contains the raw XOR
            // contribution. Apply the signed odd integer quantization formula:
            //   int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
            int64_t int_dot = ((1 << rabitq_index->qb) - 1) * rabitq_index->d;
            int_dot -= 2 * static_cast<int64_t>(normalized_distance);

            adjusted_distance = query_factors_data.qr_to_c_L2sqr +
                    db_factors.or_minus_c_l2sqr -
                    2 * db_factors.dp_multiplier * int_dot *
                            query_factors_data.int_dot_scale;
        } else {
            // For non-centered quantization: use the traditional formula
            float final_dot = normalized_distance - query_factors_data.c34;
            adjusted_distance = db_factors.or_minus_c_l2sqr +
                    query_factors_data.qr_to_c_L2sqr -
                    2 * db_factors.dp_multiplier * final_dot;
        }

        // Apply inner product correction if needed
        if (query_factors_data.qr_norm_L2sqr != 0.0f) {
            adjusted_distance = -0.5f *
                    (adjusted_distance - query_factors_data.qr_norm_L2sqr);
        }

        // Add to heap if better than current worst
        if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
            heap_replace_top<Cfloat>(
                    k, heap_dis, heap_ids, adjusted_distance, db_idx);
        }
    }
}
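
// Why the reconstruction above has this shape (a sketch inferred from the
// fields used in this file): expanding the L2 distance around the centroid c,
//   ||q - x||^2 = ||q - c||^2 + ||x - c||^2 - 2 * <q - c, x - c>
// which lines up with qr_to_c_L2sqr + or_minus_c_l2sqr minus twice the
// estimated cross term, dp_multiplier rescaling the quantized dot product
// back toward <q - c, x - c>. The qr_norm_L2sqr branch then converts that L2
// estimate into an inner-product-style score.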

template <class C, bool with_id_map>
void RaBitQHeapHandler<C, with_id_map>::begin(const float* norms) {
    normalizers = norms;
    // Heap initialization is now done in the constructor
}

template <class C, bool with_id_map>
void RaBitQHeapHandler<C, with_id_map>::end() {
    // Reorder final results
#pragma omp parallel for if (nq > 100)
    for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
        float* heap_dis = heap_distances + q * k;
        int64_t* heap_ids = heap_labels + q * k;
        heap_reorder<Cfloat>(k, heap_dis, heap_ids);
    }
}

// Implementation of the virtual make_knn_handler method
void* IndexRaBitQFastScan::make_knn_handler(
        bool is_max,
        int /*impl*/,
        idx_t n,
        idx_t k,
        size_t /*ntotal*/,
        float* distances,
        idx_t* labels,
        const IDSelector* sel,
        const FastScanDistancePostProcessing& context) const {
    if (is_max) {
        return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
                this, n, k, distances, labels, sel, context);
    } else {
        return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
                this, n, k, distances, labels, sel, context);
    }
}

// Explicit template instantiations for the required comparator types
template struct RaBitQHeapHandler<CMin<uint16_t, int>, false>;
template struct RaBitQHeapHandler<CMax<uint16_t, int>, false>;
template struct RaBitQHeapHandler<CMin<uint16_t, int>, true>;
template struct RaBitQHeapHandler<CMax<uint16_t, int>, true>;

} // namespace faiss
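
Since the converting constructor in this file re-derives factors by decoding the original codes, an existing IndexRaBitQ can be repacked into the FastScan layout after the fact. A minimal sketch, assuming a trained and populated IndexRaBitQ; the helper name is illustrative, not part of the diff:

    #include <faiss/IndexRaBitQ.h>
    #include <faiss/IndexRaBitQFastScan.h>
    #include <vector>

    // Illustrative helper: repack a flat RaBitQ index for SIMD search.
    void repack_to_fastscan(const faiss::IndexRaBitQ& flat) {
        // bbs must be a multiple of 32 (checked by init_fastscan)
        faiss::IndexRaBitQFastScan fast(flat, /*bbs=*/32);

        std::vector<float> q(flat.d, 0.0f); // toy query
        std::vector<float> dist(10);
        std::vector<faiss::idx_t> ids(10);
        fast.search(1, q.data(), 10, dist.data(), ids.data());
    }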