faiss 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
@@ -0,0 +1,216 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <vector>
11
+
12
+ #include <faiss/IndexIVFFastScan.h>
13
+ #include <faiss/IndexIVFRaBitQ.h>
14
+ #include <faiss/IndexRaBitQFastScan.h>
15
+ #include <faiss/impl/RaBitQUtils.h>
16
+ #include <faiss/impl/RaBitQuantizer.h>
17
+ #include <faiss/impl/simd_result_handlers.h>
18
+ #include <faiss/utils/AlignedTable.h>
19
+ #include <faiss/utils/Heap.h>
20
+
21
+ namespace faiss {
22
+
23
+ // Forward declarations
24
+ struct FastScanDistancePostProcessing;
25
+
26
+ // Import shared utilities from RaBitQUtils
27
+ using rabitq_utils::FactorsData;
28
+ using rabitq_utils::QueryFactorsData;
29
+
30
+ /** Fast-scan version of IndexIVFRaBitQ that processes vectors in batches
31
+ * using SIMD operations. Combines the inverted file structure of IVF
32
+ * with RaBitQ's bit-level quantization and FastScan's batch processing.
33
+ *
34
+ * Key features:
35
+ * - Inherits from IndexIVFFastScan for IVF structure and search algorithms
36
+ * - Processes 32 database vectors at a time using SIMD
37
+ * - Separates factors from quantized bits for efficient processing
38
+ * - Supports both L2 and inner product metrics
39
+ * - Maintains compatibility with existing IVF search parameters
40
+ *
41
+ * Implementation details:
42
+ * - Batch size (bbs) is typically 32 for optimal SIMD performance
43
+ * - Factors are stored separately from packed codes for cache efficiency
44
+ * - Query factors are computed once per search and reused across lists
45
+ * - Uses specialized result handlers for RaBitQ distance corrections
46
+ */
47
+ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
48
+ RaBitQuantizer rabitq;
49
+
50
+ /// Default number of bits to quantize a query with
51
+ uint8_t qb = 8;
52
+
53
+ /// Use zero-centered scalar quantizer for queries
54
+ bool centered = false;
55
+
56
+ /// Extracted factors storage for batch processing
57
+ /// Size: ntotal, stores factors separately from packed codes
58
+ std::vector<FactorsData> factors_storage;
59
+
60
+ // Constructors
61
+
62
+ IndexIVFRaBitQFastScan();
63
+
64
+ IndexIVFRaBitQFastScan(
65
+ Index* quantizer,
66
+ size_t d,
67
+ size_t nlist,
68
+ MetricType metric = METRIC_L2,
69
+ int bbs = 32,
70
+ bool own_invlists = true);
71
+
72
+ /// Build from an existing IndexIVFRaBitQ
73
+ explicit IndexIVFRaBitQFastScan(const IndexIVFRaBitQ& orig, int bbs = 32);
74
+
75
+ // Required overrides
76
+
77
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
78
+
79
+ void encode_vectors(
80
+ idx_t n,
81
+ const float* x,
82
+ const idx_t* list_nos,
83
+ uint8_t* codes,
84
+ bool include_listnos = false) const override;
85
+
86
+ protected:
87
+ /// Extract and store RaBitQ factors from encoded vectors
88
+ void preprocess_code_metadata(
89
+ idx_t n,
90
+ const uint8_t* flat_codes,
91
+ idx_t start_global_idx) override;
92
+
93
+ /// Return code_size as stride to skip embedded factor data during packing
94
+ size_t code_packing_stride() const override;
95
+
96
+ public:
97
+ /// Reconstruct a single vector from an inverted list
98
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
99
+ const override;
100
+
101
+ /// Override sa_decode to handle RaBitQ reconstruction
102
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
103
+
104
+ private:
105
+ /// Encode a vector to FastScan format without computing factors
106
+ void encode_vector_to_fastscan(
107
+ const float* xi,
108
+ const float* centroid,
109
+ uint8_t* fastscan_code) const;
110
+
111
+ /// Compute query factors and lookup table for a residual vector
112
+ /// (similar to IndexRaBitQFastScan::compute_float_LUT)
113
+ void compute_residual_LUT(
114
+ const float* residual,
115
+ QueryFactorsData& query_factors,
116
+ float* lut_out,
117
+ const float* original_query = nullptr) const;
118
+
119
+ /// Decode FastScan code to RaBitQ residual vector
120
+ void decode_fastscan_to_residual(
121
+ const uint8_t* fastscan_code,
122
+ float* residual) const;
123
+
124
+ public:
125
+ /// Implementation methods for IVFRaBitQFastScan specialization
126
+ bool lookup_table_is_3d() const override;
127
+
128
+ void compute_LUT(
129
+ size_t n,
130
+ const float* x,
131
+ const CoarseQuantized& cq,
132
+ AlignedTable<float>& dis_tables,
133
+ AlignedTable<float>& biases,
134
+ const FastScanDistancePostProcessing& context) const override;
135
+
136
+ void search_preassigned(
137
+ idx_t n,
138
+ const float* x,
139
+ idx_t k,
140
+ const idx_t* assign,
141
+ const float* centroid_dis,
142
+ float* distances,
143
+ idx_t* labels,
144
+ bool store_pairs,
145
+ const IVFSearchParameters* params = nullptr,
146
+ IndexIVFStats* stats = nullptr) const override;
147
+
148
+ /// Override to create RaBitQ-specific handlers
149
+ SIMDResultHandlerToFloat* make_knn_handler(
150
+ bool is_max,
151
+ int /* impl */,
152
+ idx_t n,
153
+ idx_t k,
154
+ float* distances,
155
+ idx_t* labels,
156
+ const IDSelector* sel,
157
+ const FastScanDistancePostProcessing& context,
158
+ const float* normalizers = nullptr) const override;
159
+
160
+ /** SIMD result handler for IndexIVFRaBitQFastScan that applies
161
+ * RaBitQ-specific distance corrections during batch processing.
162
+ *
163
+ * This handler processes batches of 32 distance computations from SIMD
164
+ * kernels, applies RaBitQ distance formula adjustments (factors and
165
+ * normalizers), and immediately updates result heaps. This eliminates the
166
+ * need for post-processing and provides significant performance benefits.
167
+ *
168
+ * Key optimizations:
169
+ * - Direct heap integration with no intermediate result storage
170
+ * - Batch-level computation of normalizers and query factors
171
+ * - Specialized handling for both centered and non-centered quantization
172
+ * modes
173
+ * - Efficient inner product metric corrections
174
+ *
175
+ * @tparam C Comparator type (CMin/CMax) for heap operations
176
+ */
177
+ template <class C>
178
+ struct IVFRaBitQHeapHandler
179
+ : simd_result_handlers::ResultHandlerCompare<C, true> {
180
+ const IndexIVFRaBitQFastScan* index;
181
+ float* heap_distances; // [nq * k]
182
+ int64_t* heap_labels; // [nq * k]
183
+ const size_t nq, k;
184
+ size_t current_list_no = 0;
185
+ std::vector<int>
186
+ probe_indices; // probe index for each query in current batch
187
+ const FastScanDistancePostProcessing*
188
+ context; // Processing context with query factors
189
+
190
+ // Use float-based comparator for heap operations
191
+ using Cfloat = typename std::conditional<
192
+ C::is_max,
193
+ CMax<float, int64_t>,
194
+ CMin<float, int64_t>>::type;
195
+
196
+ IVFRaBitQHeapHandler(
197
+ const IndexIVFRaBitQFastScan* idx,
198
+ size_t nq_val,
199
+ size_t k_val,
200
+ float* distances,
201
+ int64_t* labels,
202
+ const FastScanDistancePostProcessing* ctx = nullptr);
203
+
204
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final;
205
+
206
+ /// Override base class virtual method to receive context information
207
+ void set_list_context(size_t list_no, const std::vector<int>& probe_map)
208
+ override;
209
+
210
+ void begin(const float* norms) override;
211
+
212
+ void end() override;
213
+ };
214
+ };
215
+
216
+ } // namespace faiss
@@ -331,7 +331,7 @@ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
331
331
  /*
332
332
  Check that the encoder is a single vector transform followed by a LSH
333
333
  that just does thresholding.
334
- If this is not the case, the linear transform + threhsolds of the IndexLSH
334
+ If this is not the case, the linear transform + thresholds of the IndexLSH
335
335
  should be merged into the VectorTransform (which is feasible).
336
336
  */
337
337
 
@@ -79,7 +79,7 @@ struct IndexIVFSpectralHash : IndexIVF {
79
79
  */
80
80
  void replace_vt(VectorTransform* vt, bool own = false);
81
81
 
82
- /** convenience function to get the VT from an index constucted by an
82
+ /** convenience function to get the VT from an index constructed by an
83
83
  * index_factory (should end in "LSH") */
84
84
  void replace_vt(IndexPreTransform* index, bool own = false);
85
85
 
@@ -154,7 +154,7 @@ void IndexNNDescent::add(idx_t n, const float* x) {
154
154
 
155
155
  if (ntotal != 0) {
156
156
  fprintf(stderr,
157
- "WARNING NNDescent doest not support dynamic insertions,"
157
+ "WARNING NNDescent does not support dynamic insertions,"
158
158
  "multiple insertions would lead to re-building the index");
159
159
  }
160
160
 
@@ -261,7 +261,7 @@ void IndexNSG::check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const {
261
261
  }
262
262
  FAISS_THROW_IF_NOT_MSG(
263
263
  total_count < n / 10,
264
- "There are too much invalid entries in the knn graph. "
264
+ "There are too many invalid entries in the knn graph. "
265
265
  "It may be an invalid knn graph.");
266
266
  }
267
267
 
@@ -29,7 +29,7 @@ struct IndexNeuralNetCodec : IndexFlatCodes {
29
29
  void sa_encode(idx_t n, const float* x, uint8_t* codes) const override;
30
30
  void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
31
31
 
32
- ~IndexNeuralNetCodec() {}
32
+ ~IndexNeuralNetCodec() override {}
33
33
  };
34
34
 
35
35
  struct IndexQINCo : IndexNeuralNetCodec {
@@ -164,7 +164,7 @@ struct MultiIndexQuantizer : Index {
164
164
  // block size used in MultiIndexQuantizer::search
165
165
  FAISS_API extern int multi_index_quantizer_search_bs;
166
166
 
167
- /** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes
167
+ /** MultiIndexQuantizer where the PQ assignment is performed by sub-indexes
168
168
  */
169
169
  struct MultiIndexQuantizer2 : MultiIndexQuantizer {
170
170
  /// M Indexes on d / M dimensions
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <memory>
11
11
 
12
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
12
13
  #include <faiss/impl/pq4_fast_scan.h>
13
14
  #include <faiss/utils/utils.h>
14
15
 
@@ -53,8 +54,11 @@ void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
53
54
  pq.compute_codes(x, codes, n);
54
55
  }
55
56
 
56
- void IndexPQFastScan::compute_float_LUT(float* lut, idx_t n, const float* x)
57
- const {
57
+ void IndexPQFastScan::compute_float_LUT(
58
+ float* lut,
59
+ idx_t n,
60
+ const float* x,
61
+ const FastScanDistancePostProcessing&) const {
58
62
  if (metric_type == METRIC_L2) {
59
63
  pq.compute_distance_tables(n, x, lut);
60
64
  } else {
@@ -45,7 +45,11 @@ struct IndexPQFastScan : IndexFastScan {
45
45
 
46
46
  void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
47
47
 
48
- void compute_float_LUT(float* lut, idx_t n, const float* x) const override;
48
+ void compute_float_LUT(
49
+ float* lut,
50
+ idx_t n,
51
+ const float* x,
52
+ const FastScanDistancePostProcessing& context) const override;
49
53
 
50
54
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
51
55
  };
@@ -55,16 +55,17 @@ void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
55
55
 
56
56
  FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
57
57
  FlatCodesDistanceComputer* dc =
58
- rabitq.get_distance_computer(qb, center.data());
58
+ rabitq.get_distance_computer(qb, center.data(), centered);
59
59
  dc->code_size = rabitq.code_size;
60
60
  dc->codes = codes.data();
61
61
  return dc;
62
62
  }
63
63
 
64
64
  FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
65
- const uint8_t qb) const {
65
+ const uint8_t qb,
66
+ bool centered) const {
66
67
  FlatCodesDistanceComputer* dc =
67
- rabitq.get_distance_computer(qb, center.data());
68
+ rabitq.get_distance_computer(qb, center.data(), centered);
68
69
  dc->code_size = rabitq.code_size;
69
70
  dc->codes = codes.data();
70
71
  return dc;
@@ -76,6 +77,7 @@ struct Run_search_with_dc_res {
76
77
  using T = void;
77
78
 
78
79
  uint8_t qb = 0;
80
+ bool centered = false;
79
81
 
80
82
  template <class BlockResultHandler>
81
83
  void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
@@ -87,7 +89,7 @@ struct Run_search_with_dc_res {
87
89
  #pragma omp parallel // if (res.nq > 100)
88
90
  {
89
91
  std::unique_ptr<FlatCodesDistanceComputer> dc(
90
- index->get_quantized_distance_computer(qb));
92
+ index->get_quantized_distance_computer(qb, centered));
91
93
  SingleResultHandler resi(res);
92
94
  #pragma omp for
93
95
  for (int64_t q = 0; q < res.nq; q++) {
@@ -114,14 +116,15 @@ void IndexRaBitQ::search(
114
116
  float* distances,
115
117
  idx_t* labels,
116
118
  const SearchParameters* params_in) const {
117
- uint8_t used_qb = qb;
118
- if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
119
- used_qb = params->qb;
120
- }
121
-
122
119
  const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
123
120
  Run_search_with_dc_res r;
124
- r.qb = used_qb;
121
+ if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
122
+ r.qb = params->qb;
123
+ r.centered = params->centered;
124
+ } else {
125
+ r.qb = this->qb;
126
+ r.centered = this->centered;
127
+ }
125
128
 
126
129
  dispatch_knn_ResultHandler(
127
130
  n, distances, labels, k, metric_type, sel, r, this, x);
@@ -14,6 +14,7 @@ namespace faiss {
14
14
 
15
15
  struct RaBitQSearchParameters : SearchParameters {
16
16
  uint8_t qb = 0;
17
+ bool centered = false;
17
18
  };
18
19
 
19
20
  struct IndexRaBitQ : IndexFlatCodes {
@@ -26,9 +27,12 @@ struct IndexRaBitQ : IndexFlatCodes {
26
27
  // use '0' to disable quantization and use raw fp32 values.
27
28
  uint8_t qb = 0;
28
29
 
30
+ // quantize the query with a zero-centered scalar quantizer.
31
+ bool centered = false;
32
+
29
33
  IndexRaBitQ();
30
34
 
31
- IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
35
+ explicit IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
32
36
 
33
37
  void train(idx_t n, const float* x) override;
34
38
 
@@ -42,7 +46,8 @@ struct IndexRaBitQ : IndexFlatCodes {
42
46
  // returns a quantized-to-qb bits DC if qb_in > 0
43
47
  // returns a default fp32-based DC if qb_in == 0
44
48
  FlatCodesDistanceComputer* get_quantized_distance_computer(
45
- const uint8_t qb_in) const;
49
+ const uint8_t qb_in,
50
+ bool centered) const;
46
51
 
47
52
  // Don't rely on sa_decode(), bcz it is good for IP, but not for L2.
48
53
  // As a result, use get_FlatCodesDistanceComputer() for the search.