faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp (new file)
@@ -0,0 +1,731 @@
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #include <faiss/IndexRaBitQFastScan.h>
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
+ #include <faiss/impl/RaBitQUtils.h>
+ #include <faiss/impl/RaBitQuantizerMultiBit.h>
+ #include <faiss/impl/pq4_fast_scan.h>
+ #include <faiss/utils/utils.h>
+ #include <algorithm>
+ #include <cmath>
+
+ namespace faiss {
+
+ static inline size_t roundup(size_t a, size_t b) {
+     return (a + b - 1) / b * b;
+ }
+
+ size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
+     const size_t ex_bits = rabitq.nb_bits - 1;
+
+     if (ex_bits == 0) {
+         // 1-bit: only SignBitFactors
+         return sizeof(rabitq_utils::SignBitFactors);
+     } else {
+         // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + mag-codes
+         return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
+                 (d * ex_bits + 7) / 8;
+     }
+ }
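
Note: the per-vector storage arithmetic is worth making concrete. Below is a minimal sketch of the same computation with the SignBitFactors / SignBitFactorsWithError sizes hard-coded from the comments later in this file (8 and 12 bytes); the size of ExtraBitsFactors is not stated in this diff, so it is a placeholder assumption here.

    #include <cstddef>
    #include <cstdio>

    // Mirrors compute_per_vector_storage_size(); the 8/12-byte struct sizes
    // come from the comments in compute_codes() below, and kExtraBitsFactors
    // is an assumed placeholder.
    size_t per_vector_storage(size_t d, unsigned nb_bits) {
        const size_t ex_bits = nb_bits - 1;
        if (ex_bits == 0) {
            return 8; // SignBitFactors only
        }
        const size_t kExtraBitsFactors = 8; // assumed size
        return 12 + kExtraBitsFactors + (d * ex_bits + 7) / 8;
    }

    int main() {
        // d = 128, nb_bits = 3: 12 + 8 + (128 * 2 + 7) / 8 = 12 + 8 + 32 = 52
        printf("%zu\n", per_vector_storage(128, 3));
    }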
+
+ IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
+
+ IndexRaBitQFastScan::IndexRaBitQFastScan(
+         idx_t d,
+         MetricType metric,
+         int bbs,
+         uint8_t nb_bits)
+         : rabitq(d, metric, nb_bits) {
+     // RaBitQ-specific validation
+     FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
+     FAISS_THROW_IF_NOT_MSG(
+             metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
+             "RaBitQ FastScan only supports L2 and Inner Product metrics");
+     FAISS_THROW_IF_NOT_MSG(
+             nb_bits >= 1 && nb_bits <= 9, "nb_bits must be between 1 and 9");
+
+     // RaBitQ uses 1 bit per dimension packed into 4-bit FastScan
+     // sub-quantizers; each FastScan sub-quantizer handles 4 RaBitQ dimensions
+     const size_t M_fastscan = (d + 3) / 4;
+     constexpr size_t nbits_fastscan = 4;
+
+     // init_fastscan will validate bbs % 32 == 0 and nbits_fastscan == 4
+     init_fastscan(static_cast<int>(d), M_fastscan, nbits_fastscan, metric, bbs);
+
+     // Compute code_size directly using RaBitQuantizer
+     code_size = rabitq.compute_code_size(d, nb_bits);
+
+     // Set RaBitQ-specific parameters
+     qb = 8;
+     center.resize(d, 0.0f);
+
+     // Initialize empty flat storage
+     flat_storage.clear();
+ }
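
Note: with one sign bit per dimension and 4 dimensions per FastScan sub-quantizer, M grows as ceil(d / 4), and each sub-quantizer later contributes one 16-entry LUT row (2^4 possible codes). A quick sanity check of the arithmetic used above:

    #include <cassert>

    int main() {
        // M_fastscan = (d + 3) / 4, i.e. ceil(d / 4)
        assert((64 + 3) / 4 == 16);  // d = 64  -> 16 sub-quantizers
        assert((128 + 3) / 4 == 32); // d = 128 -> 32 sub-quantizers
        assert((130 + 3) / 4 == 33); // d = 130 -> last one partially used
    }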
+
+ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
+         : rabitq(orig.rabitq) {
+     // RaBitQ-specific validation
+     FAISS_THROW_IF_NOT_MSG(orig.d > 0, "Dimension must be positive");
+     FAISS_THROW_IF_NOT_MSG(
+             orig.metric_type == METRIC_L2 ||
+                     orig.metric_type == METRIC_INNER_PRODUCT,
+             "RaBitQ FastScan only supports L2 and Inner Product metrics");
+
+     // RaBitQ uses 1 bit per dimension packed into 4-bit FastScan
+     // sub-quantizers; each FastScan sub-quantizer handles 4 RaBitQ dimensions
+     const size_t M_fastscan = (orig.d + 3) / 4;
+     constexpr size_t nbits_fastscan = 4;
+
+     // Initialize FastScan base with the original index's parameters
+     init_fastscan(
+             static_cast<int>(orig.d),
+             M_fastscan,
+             nbits_fastscan,
+             orig.metric_type,
+             bbs);
+
+     code_size = rabitq.compute_code_size(d, rabitq.nb_bits);
+
+     // Copy properties from original index
+     ntotal = orig.ntotal;
+     ntotal2 = roundup(ntotal, bbs);
+     is_trained = orig.is_trained;
+     orig_codes = orig.codes.data();
+     qb = orig.qb;
+     centered = orig.centered;
+     center = orig.center;
+
+     // If the original index has data, extract factors and pack codes
+     if (ntotal > 0) {
+         // Compute per-vector storage size for flat storage
+         const size_t storage_size = compute_per_vector_storage_size();
+
+         // Allocate flat storage
+         flat_storage.resize(ntotal * storage_size);
+
+         // Copy factors directly from original codes
+         const size_t bit_pattern_size = (d + 7) / 8;
+         for (idx_t i = 0; i < ntotal; i++) {
+             const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
+             const uint8_t* source_factors_ptr = orig_code + bit_pattern_size;
+             uint8_t* storage = flat_storage.data() + i * storage_size;
+             memcpy(storage, source_factors_ptr, storage_size);
+         }
+
+         // Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format.
+         // This follows the same pattern as the IndexPQFastScan constructor.
+         AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
+         memset(fastscan_codes.get(), 0, ntotal * code_size);
+
+         // Convert from RaBitQ 1-bit-per-dimension to FastScan
+         // 4-bit-per-sub-quantizer
+         for (idx_t i = 0; i < ntotal; i++) {
+             const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
+             uint8_t* fs_code = fastscan_codes.get() + i * code_size;
+
+             // Convert each dimension's bit (same logic as compute_codes)
+             for (size_t j = 0; j < orig.d; j++) {
+                 // Extract bit from original RaBitQ format
+                 const size_t orig_byte_idx = j / 8;
+                 const size_t orig_bit_offset = j % 8;
+                 const bool bit_value =
+                         (orig_code[orig_byte_idx] >> orig_bit_offset) & 1;
+
+                 // Use RaBitQUtils for consistent bit setting
+                 if (bit_value) {
+                     rabitq_utils::set_bit_fastscan(fs_code, j);
+                 }
+             }
+         }
+
+         // Pack the converted codes using pq4_pack_codes with custom stride
+         codes.resize(ntotal2 * M2 / 2);
+         pq4_pack_codes(
+                 fastscan_codes.get(),
+                 ntotal,
+                 M,
+                 ntotal2,
+                 bbs,
+                 M2,
+                 codes.get(),
+                 code_size);
+     }
+ }
+
+ void IndexRaBitQFastScan::train(idx_t n, const float* x) {
+     // compute a centroid
+     std::vector<float> centroid(d, 0);
+     for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
+         for (size_t j = 0; j < d; j++) {
+             centroid[j] += x[i * d + j];
+         }
+     }
+
+     if (n != 0) {
+         for (size_t j = 0; j < d; j++) {
+             centroid[j] /= (float)n;
+         }
+     }
+
+     center = std::move(centroid);
+
+     rabitq.train(n, x);
+     is_trained = true;
+ }
+
+ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
+     FAISS_THROW_IF_NOT(is_trained);
+
+     // Handle blocking to avoid excessive allocations
+     constexpr idx_t bs = 65536;
+     if (n > bs) {
+         for (idx_t i0 = 0; i0 < n; i0 += bs) {
+             idx_t i1 = std::min(n, i0 + bs);
+             if (verbose) {
+                 printf("IndexRaBitQFastScan::add %zd/%zd\n",
+                        size_t(i1),
+                        size_t(n));
+             }
+             add(i1 - i0, x + i0 * d);
+         }
+         return;
+     }
+     InterruptCallback::check();
+
+     // Create codes with embedded factors using our compute_codes
+     AlignedTable<uint8_t> tmp_codes(n * code_size);
+     compute_codes(tmp_codes.get(), n, x);
+
+     const size_t storage_size = compute_per_vector_storage_size();
+     flat_storage.resize((ntotal + n) * storage_size);
+
+     // Populate flat storage (no sign-bit copying needed)
+     const size_t bit_pattern_size = (d + 7) / 8;
+     for (idx_t i = 0; i < n; i++) {
+         const uint8_t* code = tmp_codes.get() + i * code_size;
+         const idx_t vec_idx = ntotal + i;
+
+         // Copy factors data directly to flat storage (no reordering needed)
+         const uint8_t* source_factors_ptr = code + bit_pattern_size;
+         uint8_t* storage = flat_storage.data() + vec_idx * storage_size;
+         memcpy(storage, source_factors_ptr, storage_size);
+     }
+
+     // Resize main storage (same logic as parent)
+     ntotal2 = roundup(ntotal + n, bbs);
+     size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4
+     size_t old_size = codes.size();
+     if (new_size > old_size) {
+         codes.resize(new_size);
+         memset(codes.get() + old_size, 0, new_size - old_size);
+     }
+
+     // Use our custom packing function with the correct stride
+     pq4_pack_codes_range(
+             tmp_codes.get(),
+             M,          // number of sub-quantizers (bit patterns only)
+             ntotal,
+             ntotal + n, // range to pack
+             bbs,
+             M2,         // block parameters
+             codes.get(),  // output
+             code_size);   // custom stride: includes the factor space
+
+     ntotal += n;
+ }
+
+ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
+         const {
+     FAISS_ASSERT(codes != nullptr);
+     FAISS_ASSERT(x != nullptr);
+     FAISS_ASSERT(
+             (metric_type == MetricType::METRIC_L2 ||
+              metric_type == MetricType::METRIC_INNER_PRODUCT));
+     if (n == 0) {
+         return;
+     }
+
+     // Hoist loop-invariant computations
+     const float* centroid_data = center.data();
+     const size_t bit_pattern_size = (d + 7) / 8;
+     const size_t ex_bits = rabitq.nb_bits - 1;
+     const size_t ex_code_size = (d * ex_bits + 7) / 8;
+
+     memset(codes, 0, n * code_size);
+
+ #pragma omp parallel for if (n > 1000)
+     for (int64_t i = 0; i < n; i++) {
+         uint8_t* const code = codes + i * code_size;
+         const float* const x_row = x + i * d;
+
+         // Compute residual once, reuse for both sign bits and ex-bits
+         std::vector<float> residual(d);
+         for (size_t j = 0; j < d; j++) {
+             const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
+             residual[j] = x_row[j] - centroid_val;
+         }
+
+         // Pack sign bits directly into FastScan format using the
+         // precomputed residual
+         for (size_t j = 0; j < d; j++) {
+             if (residual[j] > 0.0f) {
+                 rabitq_utils::set_bit_fastscan(code, j);
+             }
+         }
+
+         SignBitFactorsWithError factors = rabitq_utils::compute_vector_factors(
+                 x_row, d, centroid_data, metric_type, ex_bits > 0);
+
+         if (ex_bits == 0) {
+             // 1-bit: store only SignBitFactors (8 bytes)
+             memcpy(code + bit_pattern_size, &factors, sizeof(SignBitFactors));
+         } else {
+             // Multi-bit: store full SignBitFactorsWithError (12 bytes)
+             memcpy(code + bit_pattern_size,
+                    &factors,
+                    sizeof(SignBitFactorsWithError));
+
+             // Add mag-codes and ExtraBitsFactors using the precomputed
+             // residual
+             uint8_t* ex_code =
+                     code + bit_pattern_size + sizeof(SignBitFactorsWithError);
+             ExtraBitsFactors ex_factors_temp;
+
+             rabitq_multibit::quantize_ex_bits(
+                     residual.data(),
+                     d,
+                     rabitq.nb_bits,
+                     ex_code,
+                     ex_factors_temp,
+                     metric_type,
+                     centroid_data);
+
+             memcpy(ex_code + ex_code_size,
+                    &ex_factors_temp,
+                    sizeof(ExtraBitsFactors));
+         }
+     }
+ }
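
Note: for reference, the per-code byte layout written by compute_codes() above (sizes as given in the comments in the code; the multi-bit tail is present only when nb_bits > 1):

    offset 0             : sign bits, (d + 7) / 8 bytes, FastScan bit order
    + bit_pattern_size   : SignBitFactors (8 bytes, 1-bit case) or
                           SignBitFactorsWithError (12 bytes, multi-bit case)
    + sizeof(factors)    : mag-codes, (d * ex_bits + 7) / 8 bytes (multi-bit)
    + ex_code_size       : ExtraBitsFactors (multi-bit)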
+
+ void IndexRaBitQFastScan::compute_float_LUT(
+         float* lut,
+         idx_t n,
+         const float* x,
+         const FastScanDistancePostProcessing& context) const {
+     FAISS_THROW_IF_NOT(is_trained);
+
+     // Pre-allocate working buffers to avoid repeated allocations
+     std::vector<float> rotated_q(d);
+     std::vector<uint8_t> rotated_qq(d);
+
+     // Compute lookup tables for FastScan SIMD operations.
+     // For each query vector, compute distance contributions for all
+     // possible 4-bit codes per sub-quantizer. Also compute and store
+     // query factors for distance reconstruction.
+     for (idx_t i = 0; i < n; i++) {
+         const float* query = x + i * d;
+
+         // Compute query factors and store them in the array if available
+         rabitq_utils::QueryFactorsData query_factors_data =
+                 rabitq_utils::compute_query_factors(
+                         query,
+                         d,
+                         center.data(),
+                         qb,
+                         centered,
+                         metric_type,
+                         rotated_q,
+                         rotated_qq);
+
+         // Store query factors in the context array if provided
+         if (context.query_factors != nullptr) {
+             query_factors_data.rotated_q = rotated_q;
+             context.query_factors[i] = query_factors_data;
+         }
+
+         // Create a lookup table storing distance contributions for all
+         // possible 4-bit codes per sub-quantizer for FastScan SIMD operations
+         float* query_lut = lut + i * M * 16;
+
+         if (centered) {
+             // For centered mode, we use the signed odd integer quantization
+             // scheme. Formula:
+             //     int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
+             // We precompute the XOR contribution for each sub-quantizer.
+
+             const float max_code_value = (1 << qb) - 1;
+
+             for (size_t m = 0; m < M; m++) {
+                 const size_t dim_start = m * 4;
+
+                 for (int code_val = 0; code_val < 16; code_val++) {
+                     float xor_contribution = 0.0f;
+
+                     // Process 4 bits per sub-quantizer
+                     for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                         const size_t dim_idx = dim_start + dim_offset;
+
+                         if (dim_idx < d) {
+                             const bool db_bit = (code_val >> dim_offset) & 1;
+                             const float query_value = rotated_qq[dim_idx];
+
+                             // XOR contribution:
+                             //   if db_bit == 0: XOR result = query_value
+                             //   if db_bit == 1: XOR result =
+                             //       (2^qb - 1) - query_value
+                             xor_contribution += db_bit
+                                     ? (max_code_value - query_value)
+                                     : query_value;
+                         }
+                     }
+
+                     // Store the XOR contribution (it will be scaled by
+                     // -2 * int_dot_scale during distance computation)
+                     query_lut[m * 16 + code_val] = xor_contribution;
+                 }
+             }
+
+         } else {
+             // For non-centered quantization, use the traditional AND dot
+             // product. Compute lookup table entries by processing the
+             // popcount and inner product together.
+             for (size_t m = 0; m < M; m++) {
+                 const size_t dim_start = m * 4;
+
+                 for (int code_val = 0; code_val < 16; code_val++) {
+                     float inner_product = 0.0f;
+                     int popcount = 0;
+
+                     // Process 4 bits per sub-quantizer
+                     for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
+                         const size_t dim_idx = dim_start + dim_offset;
+
+                         if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
+                             inner_product += rotated_qq[dim_idx];
+                             popcount++;
+                         }
+                     }
+
+                     // Store the precomputed distance contribution
+                     query_lut[m * 16 + code_val] =
+                             query_factors_data.c1 * inner_product +
+                             query_factors_data.c2 * popcount;
+                 }
+             }
+         }
+     }
+ }
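
Note: a self-contained sketch of the centered branch above, building one 16-entry LUT row exactly as the inner loops do (qb = 8; the four quantized query values are made up for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const float maxv = (1 << 8) - 1; // (1 << qb) - 1 = 255
        const uint8_t q[4] = {10, 200, 0, 255}; // assumed rotated_qq values
        float lut[16];
        for (int code = 0; code < 16; code++) {
            float s = 0;
            for (int j = 0; j < 4; j++) {
                // db_bit set -> (2^qb - 1) - query_value, else query_value
                s += ((code >> j) & 1) ? (maxv - q[j]) : q[j];
            }
            lut[code] = s;
        }
        printf("lut[0b0000]=%g lut[0b1111]=%g\n", lut[0], lut[15]); // 465, 555
    }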
+
+ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
+         const {
+     const float* centroid_in =
+             (center.data() == nullptr) ? nullptr : center.data();
+     const uint8_t* codes = bytes;
+     FAISS_ASSERT(codes != nullptr);
+     FAISS_ASSERT(x != nullptr);
+
+     const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
+     const size_t bit_pattern_size = (d + 7) / 8;
+
+ #pragma omp parallel for if (n > 1000)
+     for (int64_t i = 0; i < n; i++) {
+         // Access code using correct FastScan format
+         const uint8_t* code = codes + i * code_size;
+
+         // Extract factors directly from embedded codes
+         const uint8_t* factors_ptr = code + bit_pattern_size;
+         const rabitq_utils::SignBitFactors* fac =
+                 reinterpret_cast<const rabitq_utils::SignBitFactors*>(
+                         factors_ptr);
+
+         for (size_t j = 0; j < d; j++) {
+             // Use RaBitQUtils for consistent bit extraction
+             bool bit_value = rabitq_utils::extract_bit_fastscan(code, j);
+             float bit = bit_value ? 1.0f : 0.0f;
+
+             // Compute the output using RaBitQ reconstruction formula
+             x[i * d + j] = (bit - 0.5f) * fac->dp_multiplier * 2 * inv_d_sqrt +
+                     ((centroid_in == nullptr) ? 0 : centroid_in[j]);
+         }
+     }
+ }
+
+ void IndexRaBitQFastScan::search(
+         idx_t n,
+         const float* x,
+         idx_t k,
+         float* distances,
+         idx_t* labels,
+         const SearchParameters* params) const {
+     FAISS_THROW_IF_NOT_MSG(
+             !params, "search params not supported for this index");
+
+     // Query factors live in this vector for the duration of the search;
+     // the context below only carries a pointer to them.
+     std::vector<rabitq_utils::QueryFactorsData> query_factors_storage(n);
+
+     // Use the faster search_dispatch_implem flow from IndexFastScan.
+     // Pass the query factors array; the factors are filled in during
+     // LUT computation.
+     FastScanDistancePostProcessing context;
+     context.query_factors = query_factors_storage.data();
+     if (metric_type == METRIC_L2) {
+         search_dispatch_implem<true>(n, x, k, distances, labels, context);
+     } else {
+         search_dispatch_implem<false>(n, x, k, distances, labels, context);
+     }
+ }
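
Note: taken together, train/add/search follow the usual faiss index lifecycle. A minimal usage sketch (constructor arguments follow the validation above: bbs a multiple of 32, nb_bits in [1, 9]; whether any of them have defaults is not visible in this file, so all four are passed explicitly):

    #include <faiss/IndexRaBitQFastScan.h>

    #include <random>
    #include <vector>

    int main() {
        const int d = 64;
        const int nb = 1000;

        // METRIC_L2, bbs = 32, nb_bits = 4
        faiss::IndexRaBitQFastScan index(d, faiss::METRIC_L2, 32, 4);

        std::vector<float> xb(size_t(nb) * d);
        std::mt19937 rng(123);
        std::uniform_real_distribution<float> u(0.0f, 1.0f);
        for (auto& v : xb) {
            v = u(rng);
        }

        index.train(nb, xb.data()); // estimates the centroid, trains rabitq
        index.add(nb, xb.data());   // packs sign bits, factors, mag-codes

        const faiss::idx_t k = 5;
        std::vector<float> dis(k);
        std::vector<faiss::idx_t> ids(k);
        index.search(1, xb.data(), k, dis.data(), ids.data());
        return 0;
    }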
+
+ // Template implementations for RaBitQHeapHandler
+ template <class C, bool with_id_map>
+ RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
+         const IndexRaBitQFastScan* index,
+         size_t nq_val,
+         size_t k_val,
+         float* distances,
+         int64_t* labels,
+         const IDSelector* sel_in,
+         const FastScanDistancePostProcessing& ctx,
+         bool multi_bit)
+         : RHC(nq_val, index->ntotal, sel_in),
+           rabitq_index(index),
+           heap_distances(distances),
+           heap_labels(labels),
+           nq(nq_val),
+           k(k_val),
+           context(ctx),
+           is_multi_bit(multi_bit) {
+     // Initialize heaps for all queries in the constructor.
+     // This allows us to support direct normalizer assignment.
+ #pragma omp parallel for if (nq > 100)
+     for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
+         float* heap_dis = heap_distances + q * k;
+         int64_t* heap_ids = heap_labels + q * k;
+         heap_heapify<Cfloat>(k, heap_dis, heap_ids);
+     }
+ }
+
+ template <class C, bool with_id_map>
+ void RaBitQHeapHandler<C, with_id_map>::handle(
+         size_t q,
+         size_t b,
+         simd16uint16 d0,
+         simd16uint16 d1) {
+     ALIGNED(32) uint16_t d32tab[32];
+     d0.store(d32tab);
+     d1.store(d32tab + 16);
+
+     // Get heap pointers and query factors (computed once per batch)
+     float* const heap_dis = heap_distances + q * k;
+     int64_t* const heap_ids = heap_labels + q * k;
+
+     // Access query factors from the query_factors pointer
+     rabitq_utils::QueryFactorsData query_factors_data = {};
+     if (context.query_factors != nullptr) {
+         query_factors_data = context.query_factors[q];
+     }
+
+     // Compute normalizers once per batch
+     const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
+     const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;
+
+     // Compute loop bounds to avoid redundant bounds checking
+     const size_t base_db_idx = this->j0 + b * 32;
+     const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
+             ? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
+             : 0;
+
+     // Get storage size once
+     const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
+
+     // Stats tracking for multi-bit two-stage search only:
+     //   n_1bit_evaluations: candidates evaluated using the 1-bit lower bound
+     //   n_multibit_evaluations: candidates requiring full multi-bit distance
+     size_t local_1bit_evaluations = 0;
+     size_t local_multibit_evaluations = 0;
+
+     // Process distances in batch
+     for (size_t i = 0; i < max_vectors; i++) {
+         const size_t db_idx = base_db_idx + i;
+
+         // Normalize distance from LUT lookup
+         const float normalized_distance = d32tab[i] * one_a + bias;
+
+         // Access factors from flat storage
+         const uint8_t* base_ptr =
+                 rabitq_index->flat_storage.data() + db_idx * storage_size;
+
+         if (is_multi_bit) {
+             // Track candidates actually considered for two-stage filtering
+             local_1bit_evaluations++;
+
+             const SignBitFactorsWithError& full_factors =
+                     *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
+
+             float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
+                     normalized_distance,
+                     full_factors,
+                     query_factors_data,
+                     rabitq_index->centered,
+                     rabitq_index->qb,
+                     rabitq_index->d);
+
+             float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
+
+             // Adaptive filtering: decide whether to compute the full distance
+             const bool is_similarity = rabitq_index->metric_type ==
+                     MetricType::METRIC_INNER_PRODUCT;
+             bool should_refine = is_similarity
+                     ? (lower_bound > heap_dis[0])  // IP: keep if better
+                     : (lower_bound < heap_dis[0]); // L2: keep if better
+
+             if (should_refine) {
+                 local_multibit_evaluations++;
+                 float dist_full = compute_full_multibit_distance(db_idx, q);
+
+                 if (Cfloat::cmp(heap_dis[0], dist_full)) {
+                     heap_replace_top<Cfloat>(
+                             k, heap_dis, heap_ids, dist_full, db_idx);
+                 }
+             }
+         } else {
+             const rabitq_utils::SignBitFactors& db_factors =
+                     *reinterpret_cast<const rabitq_utils::SignBitFactors*>(
+                             base_ptr);
+
+             float adjusted_distance =
+                     rabitq_utils::compute_1bit_adjusted_distance(
+                             normalized_distance,
+                             db_factors,
+                             query_factors_data,
+                             rabitq_index->centered,
+                             rabitq_index->qb,
+                             rabitq_index->d);
+
+             // Add to heap if better than current worst
+             if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
+                 heap_replace_top<Cfloat>(
+                         k, heap_dis, heap_ids, adjusted_distance, db_idx);
+             }
+         }
+     }
+
+     // Update global stats atomically
+ #pragma omp atomic
+     rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
+ #pragma omp atomic
+     rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
+ }
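
Note: the two-stage predicate above is compact enough to restate standalone: the bound is dist_1bit minus the error budget f_error * g_error (see compute_lower_bound below), and a candidate pays for the full multi-bit distance only if that bound would beat the current heap top, with the comparison direction set by the metric. A sketch mirroring the code:

    // Standalone restatement of the filtering rule used in handle().
    // heap_top is the current worst result kept for this query.
    bool should_refine(
            float dist_1bit,
            float f_error,
            float g_error,
            bool is_inner_product,
            float heap_top) {
        const float bound = dist_1bit - f_error * g_error;
        return is_inner_product ? (bound > heap_top)  // IP: larger is better
                                : (bound < heap_top); // L2: smaller is better
    }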
+
+ template <class C, bool with_id_map>
+ void RaBitQHeapHandler<C, with_id_map>::begin(const float* norms) {
+     normalizers = norms;
+     // Heap initialization is now done in the constructor
+ }
+
+ template <class C, bool with_id_map>
+ void RaBitQHeapHandler<C, with_id_map>::end() {
+     // Reorder final results
+ #pragma omp parallel for if (nq > 100)
+     for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
+         float* heap_dis = heap_distances + q * k;
+         int64_t* heap_ids = heap_labels + q * k;
+         heap_reorder<Cfloat>(k, heap_dis, heap_ids);
+     }
+ }
+
+ template <class C, bool with_id_map>
+ float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
+         float dist_1bit,
+         size_t db_idx,
+         size_t q) const {
+     // Access f_error directly from SignBitFactorsWithError in flat storage
+     const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
+     const uint8_t* base_ptr =
+             rabitq_index->flat_storage.data() + db_idx * storage_size;
+     const SignBitFactorsWithError& db_factors =
+             *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
+     float f_error = db_factors.f_error;
+
+     // Get g_error from the query factors (query-dependent error term)
+     float g_error = 0.0f;
+     if (context.query_factors != nullptr) {
+         g_error = context.query_factors[q].g_error;
+     }
+
+     // Compute error adjustment: f_error * g_error
+     float error_adjustment = f_error * g_error;
+
+     return dist_1bit - error_adjustment;
+ }
+
+ template <class C, bool with_id_map>
+ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
+         size_t db_idx,
+         size_t q) const {
+     const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
+     const size_t dim = rabitq_index->d;
+
+     const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
+     const uint8_t* base_ptr =
+             rabitq_index->flat_storage.data() + db_idx * storage_size;
+
+     const size_t ex_code_size = (dim * ex_bits + 7) / 8;
+     const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
+     const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
+             base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
+
+     // Get a query factors reference (avoid copying)
+     const rabitq_utils::QueryFactorsData& query_factors =
+             context.query_factors[q];
+
+     // Get sign bits from the FastScan packed format
+     std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
+     CodePackerPQ4 packer(rabitq_index->M2, rabitq_index->bbs);
+     packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
+     const uint8_t* sign_bits = unpacked_code.data();
+
+     return rabitq_utils::compute_full_multibit_distance(
+             sign_bits,
+             ex_code,
+             ex_fac,
+             query_factors.rotated_q.data(),
+             query_factors.qr_to_c_L2sqr,
+             query_factors.qr_norm_L2sqr,
+             dim,
+             ex_bits,
+             rabitq_index->metric_type);
+ }
+
+ // Implementation of the virtual make_knn_handler method
+ SIMDResultHandlerToFloat* IndexRaBitQFastScan::make_knn_handler(
+         bool is_max,
+         int /*impl*/,
+         idx_t n,
+         idx_t k,
+         size_t /*ntotal*/,
+         float* distances,
+         idx_t* labels,
+         const IDSelector* sel,
+         const FastScanDistancePostProcessing& context) const {
+     // Use a runtime boolean for multi-bit mode
+     const bool multi_bit = rabitq.nb_bits > 1;
+
+     if (is_max) {
+         return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
+                 this, n, k, distances, labels, sel, context, multi_bit);
+     } else {
+         return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
+                 this, n, k, distances, labels, sel, context, multi_bit);
+     }
+ }
+
+ } // namespace faiss