faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
@@ -0,0 +1,252 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <vector>
11
+
12
+ #include <faiss/IndexIVFFastScan.h>
13
+ #include <faiss/IndexIVFRaBitQ.h>
14
+ #include <faiss/IndexRaBitQFastScan.h>
15
+ #include <faiss/impl/RaBitQStats.h>
16
+ #include <faiss/impl/RaBitQUtils.h>
17
+ #include <faiss/impl/RaBitQuantizer.h>
18
+ #include <faiss/impl/simd_result_handlers.h>
19
+ #include <faiss/utils/AlignedTable.h>
20
+ #include <faiss/utils/Heap.h>
21
+
22
+ namespace faiss {
23
+
24
+ // Forward declarations
25
+ struct FastScanDistancePostProcessing;
26
+
27
+ // Import shared utilities from RaBitQUtils
28
+ using rabitq_utils::QueryFactorsData;
29
+ using rabitq_utils::SignBitFactors;
30
+ using rabitq_utils::SignBitFactorsWithError;
31
+
32
+ /** Fast-scan version of IndexIVFRaBitQ that processes vectors in batches
33
+ * using SIMD operations. Combines the inverted file structure of IVF
34
+ * with RaBitQ's bit-level quantization and FastScan's batch processing.
35
+ *
36
+ * Key features:
37
+ * - Inherits from IndexIVFFastScan for IVF structure and search algorithms
38
+ * - Processes 32 database vectors at a time using SIMD
39
+ * - Separates factors from quantized bits for efficient processing
40
+ * - Supports both L2 and inner product metrics
41
+ * - Maintains compatibility with existing IVF search parameters
42
+ *
43
+ * Implementation details:
44
+ * - Batch size (bbs) is typically 32 for optimal SIMD performance
45
+ * - Factors are stored separately from packed codes for cache efficiency
46
+ * - Query factors are computed once per search and reused across lists
47
+ * - Uses specialized result handlers for RaBitQ distance corrections
48
+ */
49
+ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
50
+ RaBitQuantizer rabitq;
51
+
52
+ /// Default number of bits to quantize a query with
53
+ uint8_t qb = 8;
54
+
55
+ /// Use zero-centered scalar quantizer for queries
56
+ bool centered = false;
57
+
58
+ /// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
59
+ ///
60
+ /// 1-bit codes (sign bits) are stored in the inherited `codes` array from
61
+ /// IndexFastScan in packed FastScan format for SIMD processing.
62
+ ///
63
+ /// This flat_storage holds per-vector factors and refinement-bit codes:
64
+ /// Layout for 1-bit: [SignBitFactors (8 bytes)]
65
+ /// Layout for multi-bit: [SignBitFactorsWithError
66
+ /// (12B)][ref_codes][ExtraBitsFactors (8B)]
67
+ std::vector<uint8_t> flat_storage;
68
+
69
+ // Constructors
70
+
71
+ IndexIVFRaBitQFastScan();
72
+
73
+ IndexIVFRaBitQFastScan(
74
+ Index* quantizer,
75
+ size_t d,
76
+ size_t nlist,
77
+ MetricType metric = METRIC_L2,
78
+ int bbs = 32,
79
+ bool own_invlists = true,
80
+ uint8_t nb_bits = 1);
81
+
82
+ /// Build from an existing IndexIVFRaBitQ
83
+ explicit IndexIVFRaBitQFastScan(const IndexIVFRaBitQ& orig, int bbs = 32);
84
+
85
+ // Required overrides
86
+
87
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
88
+
89
+ void encode_vectors(
90
+ idx_t n,
91
+ const float* x,
92
+ const idx_t* list_nos,
93
+ uint8_t* codes,
94
+ bool include_listnos = false) const override;
95
+
96
+ protected:
97
+ /// Extract and store RaBitQ factors from encoded vectors
98
+ void preprocess_code_metadata(
99
+ idx_t n,
100
+ const uint8_t* flat_codes,
101
+ idx_t start_global_idx) override;
102
+
103
+ /// Return code_size as stride to skip embedded factor data during packing
104
+ size_t code_packing_stride() const override;
105
+
106
+ public:
107
+ /// Reconstruct a single vector from an inverted list
108
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
109
+ const override;
110
+
111
+ /// Override sa_decode to handle RaBitQ reconstruction
112
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
113
+
114
+ /// Compute storage size per vector in flat_storage based on nb_bits
115
+ size_t compute_per_vector_storage_size() const;
116
+
117
+ private:
118
+ /// Compute query factors and lookup table for a residual vector
119
+ /// (similar to IndexRaBitQFastScan::compute_float_LUT)
120
+ void compute_residual_LUT(
121
+ const float* residual,
122
+ QueryFactorsData& query_factors,
123
+ float* lut_out,
124
+ const float* original_query = nullptr) const;
125
+
126
+ /// Decode FastScan code to RaBitQ residual vector with explicit
127
+ /// dp_multiplier
128
+ void decode_fastscan_to_residual(
129
+ const uint8_t* fastscan_code,
130
+ float* residual,
131
+ float dp_multiplier) const;
132
+
133
+ public:
134
+ /// Implementation methods for IVFRaBitQFastScan specialization
135
+ bool lookup_table_is_3d() const override;
136
+
137
+ void compute_LUT(
138
+ size_t n,
139
+ const float* x,
140
+ const CoarseQuantized& cq,
141
+ AlignedTable<float>& dis_tables,
142
+ AlignedTable<float>& biases,
143
+ const FastScanDistancePostProcessing& context) const override;
144
+
145
+ void search_preassigned(
146
+ idx_t n,
147
+ const float* x,
148
+ idx_t k,
149
+ const idx_t* assign,
150
+ const float* centroid_dis,
151
+ float* distances,
152
+ idx_t* labels,
153
+ bool store_pairs,
154
+ const IVFSearchParameters* params = nullptr,
155
+ IndexIVFStats* stats = nullptr) const override;
156
+
157
+ /// Override to create RaBitQ-specific handlers
158
+ SIMDResultHandlerToFloat* make_knn_handler(
159
+ bool is_max,
160
+ int /* impl */,
161
+ idx_t n,
162
+ idx_t k,
163
+ float* distances,
164
+ idx_t* labels,
165
+ const IDSelector* sel,
166
+ const FastScanDistancePostProcessing& context,
167
+ const float* normalizers = nullptr) const override;
168
+
169
+ /** SIMD result handler for IndexIVFRaBitQFastScan that applies
170
+ * RaBitQ-specific distance corrections during batch processing.
171
+ *
172
+ * This handler processes batches of 32 distance computations from SIMD
173
+ * kernels, applies RaBitQ distance formula adjustments (factors and
174
+ * normalizers), and immediately updates result heaps. This eliminates the
175
+ * need for post-processing and provides significant performance benefits.
176
+ *
177
+ * Key optimizations:
178
+ * - Direct heap integration with no intermediate result storage
179
+ * - Batch-level computation of normalizers and query factors
180
+ * - Specialized handling for both centered and non-centered quantization
181
+ * modes
182
+ * - Efficient inner product metric corrections
183
+ * - Uses runtime boolean for multi-bit mode
184
+ *
185
+ * @tparam C Comparator type (CMin/CMax) for heap operations
186
+ */
187
+ template <class C>
188
+ struct IVFRaBitQHeapHandler
189
+ : simd_result_handlers::ResultHandlerCompare<C, true> {
190
+ const IndexIVFRaBitQFastScan* index;
191
+ float* heap_distances; // [nq * k]
192
+ int64_t* heap_labels; // [nq * k]
193
+ const size_t nq, k;
194
+ size_t current_list_no = 0;
195
+ std::vector<int>
196
+ probe_indices; // probe index for each query in current batch
197
+ const FastScanDistancePostProcessing*
198
+ context; // Processing context with query factors
199
+ const bool is_multibit; // Whether to use multi-bit two-stage search
200
+
201
+ // Use float-based comparator for heap operations
202
+ using Cfloat = typename std::conditional<
203
+ C::is_max,
204
+ CMax<float, int64_t>,
205
+ CMin<float, int64_t>>::type;
206
+
207
+ IVFRaBitQHeapHandler(
208
+ const IndexIVFRaBitQFastScan* idx,
209
+ size_t nq_val,
210
+ size_t k_val,
211
+ float* distances,
212
+ int64_t* labels,
213
+ const FastScanDistancePostProcessing* ctx = nullptr,
214
+ bool multibit = false);
215
+
216
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1)
217
+ override;
218
+
219
+ /// Override base class virtual method to receive context information
220
+ void set_list_context(size_t list_no, const std::vector<int>& probe_map)
221
+ override;
222
+
223
+ void begin(const float* norms) override;
224
+
225
+ void end() override;
226
+
227
+ private:
228
+ /// Compute full multi-bit distance for a candidate vector (multi-bit
229
+ /// only)
230
+ /// @param db_idx Global database vector index
231
+ /// @param local_q Batch-local query index (for probe_indices access)
232
+ /// @param global_q Global query index (for storage indexing)
233
+ /// @param local_offset Offset within the current inverted list
234
+ float compute_full_multibit_distance(
235
+ size_t db_idx,
236
+ size_t local_q,
237
+ size_t global_q,
238
+ size_t local_offset) const;
239
+
240
+ /// Compute lower bound using 1-bit distance and error bound (multi-bit
241
+ /// only)
242
+ /// @param local_q Batch-local query index (for probe_indices access)
243
+ /// @param global_q Global query index (for storage indexing)
244
+ float compute_lower_bound(
245
+ float dist_1bit,
246
+ size_t db_idx,
247
+ size_t local_q,
248
+ size_t global_q) const;
249
+ };
250
+ };
251
+
252
+ } // namespace faiss
@@ -331,7 +331,7 @@ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
331
331
  /*
332
332
  Check that the encoder is a single vector transform followed by a LSH
333
333
  that just does thresholding.
334
- If this is not the case, the linear transform + threhsolds of the IndexLSH
334
+ If this is not the case, the linear transform + thresholds of the IndexLSH
335
335
  should be merged into the VectorTransform (which is feasible).
336
336
  */
337
337
 
@@ -79,7 +79,7 @@ struct IndexIVFSpectralHash : IndexIVF {
79
79
  */
80
80
  void replace_vt(VectorTransform* vt, bool own = false);
81
81
 
82
- /** convenience function to get the VT from an index constucted by an
82
+ /** convenience function to get the VT from an index constructed by an
83
83
  * index_factory (should end in "LSH") */
84
84
  void replace_vt(IndexPreTransform* index, bool own = false);
85
85
 
@@ -154,7 +154,7 @@ void IndexNNDescent::add(idx_t n, const float* x) {
154
154
 
155
155
  if (ntotal != 0) {
156
156
  fprintf(stderr,
157
- "WARNING NNDescent doest not support dynamic insertions,"
157
+ "WARNING NNDescent does not support dynamic insertions,"
158
158
  "multiple insertions would lead to re-building the index");
159
159
  }
160
160
 
@@ -261,7 +261,7 @@ void IndexNSG::check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const {
261
261
  }
262
262
  FAISS_THROW_IF_NOT_MSG(
263
263
  total_count < n / 10,
264
- "There are too much invalid entries in the knn graph. "
264
+ "There are too many invalid entries in the knn graph. "
265
265
  "It may be an invalid knn graph.");
266
266
  }
267
267
 
@@ -29,7 +29,7 @@ struct IndexNeuralNetCodec : IndexFlatCodes {
29
29
  void sa_encode(idx_t n, const float* x, uint8_t* codes) const override;
30
30
  void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
31
31
 
32
- ~IndexNeuralNetCodec() {}
32
+ ~IndexNeuralNetCodec() override {}
33
33
  };
34
34
 
35
35
  struct IndexQINCo : IndexNeuralNetCodec {
@@ -81,6 +81,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
81
81
  const float* sdc;
82
82
  std::vector<float> precomputed_table;
83
83
  size_t ndis;
84
+ const float* q;
84
85
 
85
86
  float distance_to_code(const uint8_t* code) final {
86
87
  ndis++;
@@ -109,7 +110,8 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
109
110
  : FlatCodesDistanceComputer(
110
111
  storage.codes.data(),
111
112
  storage.code_size),
112
- pq(storage.pq) {
113
+ pq(storage.pq),
114
+ q(nullptr) {
113
115
  precomputed_table.resize(pq.M * pq.ksub);
114
116
  nb = storage.ntotal;
115
117
  d = storage.d;
@@ -123,6 +125,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
123
125
  }
124
126
 
125
127
  void set_query(const float* x) override {
128
+ q = x;
126
129
  if (metric == METRIC_L2) {
127
130
  pq.compute_distance_table(x, precomputed_table.data());
128
131
  } else {
@@ -164,7 +164,7 @@ struct MultiIndexQuantizer : Index {
164
164
  // block size used in MultiIndexQuantizer::search
165
165
  FAISS_API extern int multi_index_quantizer_search_bs;
166
166
 
167
- /** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes
167
+ /** MultiIndexQuantizer where the PQ assignment is performed by sub-indexes
168
168
  */
169
169
  struct MultiIndexQuantizer2 : MultiIndexQuantizer {
170
170
  /// M Indexes on d / M dimensions
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <memory>
11
11
 
12
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
12
13
  #include <faiss/impl/pq4_fast_scan.h>
13
14
  #include <faiss/utils/utils.h>
14
15
 
@@ -53,8 +54,11 @@ void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
53
54
  pq.compute_codes(x, codes, n);
54
55
  }
55
56
 
56
- void IndexPQFastScan::compute_float_LUT(float* lut, idx_t n, const float* x)
57
- const {
57
+ void IndexPQFastScan::compute_float_LUT(
58
+ float* lut,
59
+ idx_t n,
60
+ const float* x,
61
+ const FastScanDistancePostProcessing&) const {
58
62
  if (metric_type == METRIC_L2) {
59
63
  pq.compute_distance_tables(n, x, lut);
60
64
  } else {
@@ -45,7 +45,11 @@ struct IndexPQFastScan : IndexFastScan {
45
45
 
46
46
  void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
47
47
 
48
- void compute_float_LUT(float* lut, idx_t n, const float* x) const override;
48
+ void compute_float_LUT(
49
+ float* lut,
50
+ idx_t n,
51
+ const float* x,
52
+ const FastScanDistancePostProcessing& context) const override;
49
53
 
50
54
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
51
55
  };
@@ -197,6 +197,20 @@ void IndexPreTransform::range_search(
197
197
  n, tv.x, radius, result, extract_index_search_params(params));
198
198
  }
199
199
 
200
+ void IndexPreTransform::search_subset(
201
+ idx_t n,
202
+ const float* x,
203
+ idx_t k_base,
204
+ const idx_t* base_labels,
205
+ idx_t k,
206
+ float* distances,
207
+ idx_t* labels) const {
208
+ FAISS_THROW_IF_NOT(k > 0);
209
+ FAISS_THROW_IF_NOT(is_trained);
210
+ TransformedVectors tv(x, apply_chain(n, x));
211
+ index->search_subset(n, tv.x, k_base, base_labels, k, distances, labels);
212
+ }
213
+
200
214
  void IndexPreTransform::reset() {
201
215
  index->reset();
202
216
  ntotal = 0;
@@ -57,6 +57,15 @@ struct IndexPreTransform : Index {
57
57
  idx_t* labels,
58
58
  const SearchParameters* params = nullptr) const override;
59
59
 
60
+ void search_subset(
61
+ idx_t n,
62
+ const float* x,
63
+ idx_t k_base,
64
+ const idx_t* base_labels,
65
+ idx_t k,
66
+ float* distances,
67
+ idx_t* labels) const override;
68
+
60
69
  /* range search, no attempt is made to change the radius */
61
70
  void range_search(
62
71
  idx_t n,
@@ -9,13 +9,18 @@
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
11
  #include <faiss/impl/ResultHandler.h>
12
+ #include <memory>
12
13
 
13
14
  namespace faiss {
14
15
 
16
+ // Forward declaration from RaBitQuantizer.cpp
17
+ struct RaBitQDistanceComputer;
18
+
15
19
  IndexRaBitQ::IndexRaBitQ() = default;
16
20
 
17
- IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric)
18
- : IndexFlatCodes(0, d, metric), rabitq(d, metric) {
21
+ IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric, uint8_t nb_bits_in)
22
+ : IndexFlatCodes(0, d, metric), rabitq(d, metric, nb_bits_in) {
23
+ // Update code size based on nb_bits
19
24
  code_size = rabitq.code_size;
20
25
 
21
26
  is_trained = false;
@@ -55,16 +60,17 @@ void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
55
60
 
56
61
  FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
57
62
  FlatCodesDistanceComputer* dc =
58
- rabitq.get_distance_computer(qb, center.data());
63
+ rabitq.get_distance_computer(qb, center.data(), centered);
59
64
  dc->code_size = rabitq.code_size;
60
65
  dc->codes = codes.data();
61
66
  return dc;
62
67
  }
63
68
 
64
69
  FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
65
- const uint8_t qb) const {
70
+ const uint8_t qb,
71
+ bool centered) const {
66
72
  FlatCodesDistanceComputer* dc =
67
- rabitq.get_distance_computer(qb, center.data());
73
+ rabitq.get_distance_computer(qb, center.data(), centered);
68
74
  dc->code_size = rabitq.code_size;
69
75
  dc->codes = codes.data();
70
76
  return dc;
@@ -76,6 +82,8 @@ struct Run_search_with_dc_res {
76
82
  using T = void;
77
83
 
78
84
  uint8_t qb = 0;
85
+ bool centered = false;
86
+ uint8_t nb_bits = 1; // Number of bits per dimension
79
87
 
80
88
  template <class BlockResultHandler>
81
89
  void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
@@ -83,22 +91,87 @@ struct Run_search_with_dc_res {
83
91
  using SingleResultHandler =
84
92
  typename BlockResultHandler::SingleResultHandler;
85
93
  const int d = index->d;
94
+ size_t ex_bits = nb_bits - 1;
86
95
 
87
- #pragma omp parallel // if (res.nq > 100)
96
+ #pragma omp parallel
88
97
  {
89
- std::unique_ptr<FlatCodesDistanceComputer> dc(
90
- index->get_quantized_distance_computer(qb));
98
+ std::unique_ptr<FlatCodesDistanceComputer> dc_base(
99
+ index->get_quantized_distance_computer(qb, centered));
91
100
  SingleResultHandler resi(res);
92
101
  #pragma omp for
93
102
  for (int64_t q = 0; q < res.nq; q++) {
94
103
  resi.begin(q);
95
- dc->set_query(xq + d * q);
96
- for (size_t i = 0; i < ntotal; i++) {
97
- if (res.is_in_selection(i)) {
98
- float dis = (*dc)(i);
99
- resi.add_result(dis, i);
104
+ dc_base->set_query(xq + d * q);
105
+
106
+ // Stats tracking for multi-bit two-stage search only
107
+ // n_1bit_evaluations: candidates evaluated using the 1-bit lower bound.
108
+ // n_multibit_evaluations: candidates requiring the full
109
+ // multi-bit distance.
110
+ size_t local_1bit_evaluations = 0;
111
+ size_t local_multibit_evaluations = 0;
112
+
113
+ if (ex_bits == 0) {
114
+ // 1-bit: Standard single-stage search (no stats tracking)
115
+ for (size_t i = 0; i < ntotal; i++) {
116
+ if (res.is_in_selection(i)) {
117
+ float dis = (*dc_base)(i);
118
+ resi.add_result(dis, i);
119
+ }
120
+ }
121
+ } else {
122
+ // Multi-bit: Two-stage search with adaptive filtering
123
+ // Note: Even with query quantization (qb > 0), ex-bits
124
+ // distance computation uses the float query to maintain
125
+ // consistency with encoding-time factor computation. See
126
+ // RaBitQuantizer.cpp for details.
127
+ auto* dc = dynamic_cast<RaBitQDistanceComputer*>(
128
+ dc_base.get());
129
+ FAISS_THROW_IF_NOT_MSG(
130
+ dc != nullptr,
131
+ "Failed to cast to RaBitQDistanceComputer for two-stage search");
132
+
133
+ // Use appropriate comparison based on metric type
134
+ bool is_similarity =
135
+ is_similarity_metric(index->metric_type);
136
+
137
+ for (size_t i = 0; i < ntotal; i++) {
138
+ if (res.is_in_selection(i)) {
139
+ const uint8_t* code =
140
+ index->codes.data() + i * index->code_size;
141
+
142
+ local_1bit_evaluations++;
143
+
144
+ // Stage 1: Compute 1-bit lower bound
145
+ float lower_bound = dc->lower_bound_distance(code);
146
+
147
+ // Stage 2: Adaptive filtering using threshold
148
+ // For L2 (min-heap): refine only if
149
+ // lower_bound < resi.threshold.
150
+ // For IP (max-heap): refine only if
151
+ // lower_bound > resi.threshold.
152
+ // Note: Using resi.threshold directly (not cached) enables more aggressive filtering as the heap is updated.
153
+ bool should_refine = is_similarity
154
+ ? (lower_bound > resi.threshold)
155
+ : (lower_bound < resi.threshold);
156
+
157
+ if (should_refine) {
158
+ local_multibit_evaluations++;
159
+ // Compute full multi-bit distance
160
+ float dist_full =
161
+ dc->distance_to_code_full(code);
162
+ resi.add_result(dist_full, i);
163
+ }
164
+ }
100
165
  }
101
166
  }
167
+
168
+ // Update global stats atomically
169
+ #pragma omp atomic
170
+ rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
171
+ #pragma omp atomic
172
+ rabitq_stats.n_multibit_evaluations +=
173
+ local_multibit_evaluations;
174
+
102
175
  resi.end();
103
176
  }
104
177
  }
@@ -114,15 +187,25 @@ void IndexRaBitQ::search(
114
187
  float* distances,
115
188
  idx_t* labels,
116
189
  const SearchParameters* params_in) const {
190
+ FAISS_THROW_IF_NOT(is_trained);
191
+
192
+ // Extract search parameters
117
193
  uint8_t used_qb = qb;
194
+ bool used_centered = centered;
118
195
  if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
119
196
  used_qb = params->qb;
197
+ used_centered = params->centered;
120
198
  }
121
199
 
122
200
  const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
201
+
202
+ // Set up functor with all necessary parameters
123
203
  Run_search_with_dc_res r;
124
204
  r.qb = used_qb;
205
+ r.centered = used_centered;
206
+ r.nb_bits = rabitq.nb_bits; // Pass multi-bit info to functor
125
207
 
208
+ // Use Faiss framework for all cases (single-stage and two-stage)
126
209
  dispatch_knn_ResultHandler(
127
210
  n, distances, labels, k, metric_type, sel, r, this, x);
128
211
  }
@@ -8,12 +8,14 @@
8
8
  #pragma once
9
9
 
10
10
  #include <faiss/IndexFlatCodes.h>
11
+ #include <faiss/impl/RaBitQStats.h>
11
12
  #include <faiss/impl/RaBitQuantizer.h>
12
13
 
13
14
  namespace faiss {
14
15
 
15
16
  struct RaBitQSearchParameters : SearchParameters {
16
17
  uint8_t qb = 0;
18
+ bool centered = false;
17
19
  };
18
20
 
19
21
  struct IndexRaBitQ : IndexFlatCodes {
@@ -26,9 +28,15 @@ struct IndexRaBitQ : IndexFlatCodes {
26
28
  // use '0' to disable quantization and use raw fp32 values.
27
29
  uint8_t qb = 0;
28
30
 
31
+ // quantize the query with a zero-centered scalar quantizer.
32
+ bool centered = false;
33
+
29
34
  IndexRaBitQ();
30
35
 
31
- IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
36
+ explicit IndexRaBitQ(
37
+ idx_t d,
38
+ MetricType metric = METRIC_L2,
39
+ uint8_t nb_bits = 1);
32
40
 
33
41
  void train(idx_t n, const float* x) override;
34
42
 
@@ -42,7 +50,8 @@ struct IndexRaBitQ : IndexFlatCodes {
42
50
  // returns a quantized-to-qb bits DC if qb_in > 0
43
51
  // returns a default fp32-based DC if qb_in == 0
44
52
  FlatCodesDistanceComputer* get_quantized_distance_computer(
45
- const uint8_t qb_in) const;
53
+ const uint8_t qb_in,
54
+ bool centered) const;
46
55
 
47
56
  // Don't rely on sa_decode(), because it is good for IP, but not for L2.
48
57
  // As a result, use get_FlatCodesDistanceComputer() for the search.