faiss 0.3.4 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +11 -8
  5. data/vendor/faiss/faiss/Clustering.cpp +0 -16
  6. data/vendor/faiss/faiss/IVFlib.cpp +213 -0
  7. data/vendor/faiss/faiss/IVFlib.h +42 -0
  8. data/vendor/faiss/faiss/Index.h +1 -1
  9. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
  10. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
  11. data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
  13. data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
  14. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  15. data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
  16. data/vendor/faiss/faiss/IndexIVF.h +5 -2
  17. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
  19. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
  20. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
  24. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
  25. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
  26. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  27. data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
  28. data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
  29. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
  30. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
  31. data/vendor/faiss/faiss/clone_index.cpp +38 -3
  32. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
  33. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
  34. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
  35. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
  36. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  37. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
  38. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
  39. data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
  40. data/vendor/faiss/faiss/impl/HNSW.h +5 -4
  41. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  42. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
  43. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
  44. data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
  45. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
  46. data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
  47. data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
  48. data/vendor/faiss/faiss/impl/io.h +2 -2
  49. data/vendor/faiss/faiss/impl/io_macros.h +2 -0
  50. data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
  51. data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
  52. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
  53. data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
  54. data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
  55. data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
  56. data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
  57. data/vendor/faiss/faiss/index_factory.cpp +16 -5
  58. data/vendor/faiss/faiss/index_io.h +4 -0
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
  61. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
  62. data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
  63. data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
  64. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
  65. data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
  66. data/vendor/faiss/faiss/utils/hamming.h +7 -3
  67. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
  68. data/vendor/faiss/faiss/utils/utils.cpp +4 -4
  69. data/vendor/faiss/faiss/utils/utils.h +3 -3
  70. metadata +16 -4
@@ -312,11 +312,14 @@ struct IndexIVF : Index, IndexIVFInterface {
312
312
 
313
313
  /** Get a scanner for this index (store_pairs means ignore labels)
314
314
  *
315
- * The default search implementation uses this to compute the distances
315
+ * The default search implementation uses this to compute the distances.
316
+ * Use sel instead of params->sel, because sel is initialized with
317
+ * params->sel, but may get overridden by IndexIVF's internal logic.
316
318
  */
317
319
  virtual InvertedListScanner* get_InvertedListScanner(
318
320
  bool store_pairs = false,
319
- const IDSelector* sel = nullptr) const;
321
+ const IDSelector* sel = nullptr,
322
+ const IVFSearchParameters* params = nullptr) const;
320
323
 
321
324
  /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
322
325
  */
@@ -253,7 +253,8 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
253
253
 
254
254
  InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
255
255
  bool store_pairs,
256
- const IDSelector* sel) const {
256
+ const IDSelector* sel,
257
+ const IVFSearchParameters*) const {
257
258
  FAISS_THROW_IF_NOT(!sel);
258
259
  if (metric_type == METRIC_INNER_PRODUCT) {
259
260
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
@@ -52,7 +52,8 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
52
52
 
53
53
  InvertedListScanner* get_InvertedListScanner(
54
54
  bool store_pairs,
55
- const IDSelector* sel) const override;
55
+ const IDSelector* sel,
56
+ const IVFSearchParameters* params) const override;
56
57
 
57
58
  void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
58
59
 
@@ -223,7 +223,8 @@ InvertedListScanner* get_InvertedListScanner1(
223
223
 
224
224
  InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
225
225
  bool store_pairs,
226
- const IDSelector* sel) const {
226
+ const IDSelector* sel,
227
+ const IVFSearchParameters*) const {
227
228
  if (sel) {
228
229
  return get_InvertedListScanner1<true>(this, store_pairs, sel);
229
230
  } else {
@@ -44,7 +44,8 @@ struct IndexIVFFlat : IndexIVF {
44
44
 
45
45
  InvertedListScanner* get_InvertedListScanner(
46
46
  bool store_pairs,
47
- const IDSelector* sel) const override;
47
+ const IDSelector* sel,
48
+ const IVFSearchParameters* params) const override;
48
49
 
49
50
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
50
51
  const override;
@@ -1321,7 +1321,8 @@ InvertedListScanner* get_InvertedListScanner2(
1321
1321
 
1322
1322
  InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
1323
1323
  bool store_pairs,
1324
- const IDSelector* sel) const {
1324
+ const IDSelector* sel,
1325
+ const IVFSearchParameters*) const {
1325
1326
  if (sel) {
1326
1327
  return get_InvertedListScanner2<true>(*this, store_pairs, sel);
1327
1328
  } else {
@@ -134,7 +134,8 @@ struct IndexIVFPQ : IndexIVF {
134
134
 
135
135
  InvertedListScanner* get_InvertedListScanner(
136
136
  bool store_pairs,
137
- const IDSelector* sel) const override;
137
+ const IDSelector* sel,
138
+ const IVFSearchParameters* params) const override;
138
139
 
139
140
  /// build precomputed table
140
141
  void precompute_table();
@@ -0,0 +1,277 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexIVFRaBitQ.h>
9
+
10
+ #include <omp.h>
11
+
12
+ #include <cstddef>
13
+ #include <cstdint>
14
+ #include <memory>
15
+ #include <vector>
16
+
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/RaBitQuantizer.h>
19
+
20
+ namespace faiss {
21
+
22
+ IndexIVFRaBitQ::IndexIVFRaBitQ(
23
+ Index* quantizer,
24
+ const size_t d,
25
+ const size_t nlist,
26
+ MetricType metric)
27
+ : IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d, metric) {
28
+ code_size = rabitq.code_size;
29
+ invlists->code_size = code_size;
30
+ is_trained = false;
31
+
32
+ by_residual = true;
33
+ }
34
+
35
+ IndexIVFRaBitQ::IndexIVFRaBitQ() {
36
+ by_residual = true;
37
+ }
38
+
39
+ void IndexIVFRaBitQ::train_encoder(
40
+ idx_t n,
41
+ const float* x,
42
+ const idx_t* assign) {
43
+ rabitq.train(n, x);
44
+ }
45
+
46
+ void IndexIVFRaBitQ::encode_vectors(
47
+ idx_t n,
48
+ const float* x,
49
+ const idx_t* list_nos,
50
+ uint8_t* codes,
51
+ bool include_listnos) const {
52
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
53
+ memset(codes, 0, (code_size + coarse_size) * n);
54
+
55
+ #pragma omp parallel if (n > 1000)
56
+ {
57
+ std::vector<float> centroid(d);
58
+
59
+ #pragma omp for
60
+ for (idx_t i = 0; i < n; i++) {
61
+ int64_t list_no = list_nos[i];
62
+ if (list_no >= 0) {
63
+ const float* xi = x + i * d;
64
+ uint8_t* code = codes + i * (code_size + coarse_size);
65
+
66
+ // both by_residual and !by_residual lead to the same code
67
+ quantizer->reconstruct(list_no, centroid.data());
68
+ rabitq.compute_codes_core(
69
+ xi, code + coarse_size, 1, centroid.data());
70
+
71
+ if (coarse_size) {
72
+ encode_listno(list_no, code);
73
+ }
74
+ }
75
+ }
76
+ }
77
+ }
78
+
79
+ void IndexIVFRaBitQ::add_core(
80
+ idx_t n,
81
+ const float* x,
82
+ const idx_t* xids,
83
+ const idx_t* precomputed_idx,
84
+ void* inverted_list_context) {
85
+ FAISS_THROW_IF_NOT(is_trained);
86
+
87
+ DirectMapAdd dm_add(direct_map, n, xids);
88
+
89
+ #pragma omp parallel
90
+ {
91
+ std::vector<uint8_t> one_code(code_size);
92
+ std::vector<float> centroid(d);
93
+
94
+ int nt = omp_get_num_threads();
95
+ int rank = omp_get_thread_num();
96
+
97
+ // each thread takes care of a subset of lists
98
+ for (size_t i = 0; i < n; i++) {
99
+ int64_t list_no = precomputed_idx[i];
100
+ if (list_no >= 0 && list_no % nt == rank) {
101
+ int64_t id = xids ? xids[i] : ntotal + i;
102
+
103
+ const float* xi = x + i * d;
104
+
105
+ // both by_residual and !by_residual lead to the same code
106
+ quantizer->reconstruct(list_no, centroid.data());
107
+ rabitq.compute_codes_core(
108
+ xi, one_code.data(), 1, centroid.data());
109
+
110
+ size_t ofs = invlists->add_entry(
111
+ list_no, id, one_code.data(), inverted_list_context);
112
+
113
+ dm_add.add(i, list_no, ofs);
114
+
115
+ } else if (rank == 0 && list_no == -1) {
116
+ dm_add.add(i, -1, 0);
117
+ }
118
+ }
119
+ }
120
+
121
+ ntotal += n;
122
+ }
123
+
124
+ struct RaBitInvertedListScanner : InvertedListScanner {
125
+ const IndexIVFRaBitQ& ivf_rabitq;
126
+
127
+ std::vector<float> reconstructed_centroid;
128
+ std::vector<float> query_vector;
129
+
130
+ std::unique_ptr<FlatCodesDistanceComputer> dc;
131
+
132
+ uint8_t qb = 0;
133
+
134
+ RaBitInvertedListScanner(
135
+ const IndexIVFRaBitQ& ivf_rabitq_in,
136
+ bool store_pairs = false,
137
+ const IDSelector* sel = nullptr,
138
+ uint8_t qb_in = 0)
139
+ : InvertedListScanner(store_pairs, sel),
140
+ ivf_rabitq{ivf_rabitq_in},
141
+ qb{qb_in} {
142
+ keep_max = is_similarity_metric(ivf_rabitq.metric_type);
143
+ code_size = ivf_rabitq.code_size;
144
+ }
145
+
146
+ /// from now on we handle this query.
147
+ void set_query(const float* query_vector_in) override {
148
+ query_vector.assign(query_vector_in, query_vector_in + ivf_rabitq.d);
149
+
150
+ internal_try_setup_dc();
151
+ }
152
+
153
+ /// following codes come from this inverted list
154
+ void set_list(idx_t list_no, float coarse_dis) override {
155
+ this->list_no = list_no;
156
+
157
+ reconstructed_centroid.resize(ivf_rabitq.d);
158
+ ivf_rabitq.quantizer->reconstruct(
159
+ list_no, reconstructed_centroid.data());
160
+
161
+ internal_try_setup_dc();
162
+ }
163
+
164
+ /// compute a single query-to-code distance
165
+ float distance_to_code(const uint8_t* code) const override {
166
+ return dc->distance_to_code(code);
167
+ }
168
+
169
+ void internal_try_setup_dc() {
170
+ if (!query_vector.empty() && !reconstructed_centroid.empty()) {
171
+ // both query_vector and centroid are available!
172
+ // set up DistanceComputer
173
+ dc.reset(ivf_rabitq.rabitq.get_distance_computer(
174
+ qb, reconstructed_centroid.data()));
175
+
176
+ dc->set_query(query_vector.data());
177
+ }
178
+ }
179
+ };
180
+
181
+ InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner(
182
+ bool store_pairs,
183
+ const IDSelector* sel,
184
+ const IVFSearchParameters* search_params_in) const {
185
+ uint8_t used_qb = qb;
186
+ if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
187
+ search_params_in)) {
188
+ used_qb = params->qb;
189
+ }
190
+
191
+ return new RaBitInvertedListScanner(*this, store_pairs, sel, used_qb);
192
+ }
193
+
194
+ void IndexIVFRaBitQ::reconstruct_from_offset(
195
+ int64_t list_no,
196
+ int64_t offset,
197
+ float* recons) const {
198
+ const uint8_t* code = invlists->get_single_code(list_no, offset);
199
+
200
+ std::vector<float> centroid(d);
201
+ quantizer->reconstruct(list_no, centroid.data());
202
+
203
+ rabitq.decode_core(code, recons, 1, centroid.data());
204
+ }
205
+
206
+ void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
207
+ size_t coarse_size = coarse_code_size();
208
+
209
+ #pragma omp parallel
210
+ {
211
+ std::vector<float> centroid(d);
212
+
213
+ #pragma omp for
214
+ for (idx_t i = 0; i < n; i++) {
215
+ const uint8_t* code = codes + i * (code_size + coarse_size);
216
+ int64_t list_no = decode_listno(code);
217
+ float* xi = x + i * d;
218
+
219
+ quantizer->reconstruct(list_no, centroid.data());
220
+ rabitq.decode_core(code + coarse_size, xi, 1, centroid.data());
221
+ }
222
+ }
223
+ }
224
+
225
+ struct IVFRaBitDistanceComputer : DistanceComputer {
226
+ const float* q = nullptr;
227
+ const IndexIVFRaBitQ* parent = nullptr;
228
+
229
+ void set_query(const float* x) override;
230
+
231
+ float operator()(idx_t i) override;
232
+
233
+ float symmetric_dis(idx_t i, idx_t j) override;
234
+ };
235
+
236
+ void IVFRaBitDistanceComputer::set_query(const float* x) {
237
+ q = x;
238
+ }
239
+
240
+ float IVFRaBitDistanceComputer::operator()(idx_t i) {
241
+ // find the appropriate list
242
+ idx_t lo = parent->direct_map.get(i);
243
+ uint64_t list_no = lo_listno(lo);
244
+ uint64_t offset = lo_offset(lo);
245
+
246
+ const uint8_t* code = parent->invlists->get_single_code(list_no, offset);
247
+
248
+ // ok, we know the appropriate cluster that we need
249
+ std::vector<float> centroid(parent->d);
250
+ parent->quantizer->reconstruct(list_no, centroid.data());
251
+
252
+ // compute the distance
253
+ float distance = 0;
254
+
255
+ std::unique_ptr<FlatCodesDistanceComputer> dc(
256
+ parent->rabitq.get_distance_computer(parent->qb, centroid.data()));
257
+ dc->set_query(q);
258
+ distance = dc->distance_to_code(code);
259
+
260
+ // deallocate
261
+ parent->invlists->release_codes(list_no, code);
262
+
263
+ // done
264
+ return distance;
265
+ }
266
+
267
+ float IVFRaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) {
268
+ FAISS_THROW_MSG("Not implemented");
269
+ }
270
+
271
+ DistanceComputer* IndexIVFRaBitQ::get_distance_computer() const {
272
+ IVFRaBitDistanceComputer* dc = new IVFRaBitDistanceComputer;
273
+ dc->parent = this;
274
+ return dc;
275
+ }
276
+
277
+ } // namespace faiss
@@ -0,0 +1,70 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+
13
+ #include <faiss/Index.h>
14
+ #include <faiss/IndexIVF.h>
15
+
16
+ #include <faiss/impl/RaBitQuantizer.h>
17
+
18
+ namespace faiss {
19
+
20
+ struct IVFRaBitQSearchParameters : IVFSearchParameters {
21
+ uint8_t qb = 0;
22
+ };
23
+
24
+ // * by_residual is true, just by design
25
+ struct IndexIVFRaBitQ : IndexIVF {
26
+ RaBitQuantizer rabitq;
27
+
28
+ // the default number of bits to quantize a query with.
29
+ // use '0' to disable quantization and use raw fp32 values.
30
+ uint8_t qb = 0;
31
+
32
+ IndexIVFRaBitQ(
33
+ Index* quantizer,
34
+ const size_t d,
35
+ const size_t nlist,
36
+ MetricType metric = METRIC_L2);
37
+
38
+ IndexIVFRaBitQ();
39
+
40
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
41
+
42
+ void encode_vectors(
43
+ idx_t n,
44
+ const float* x,
45
+ const idx_t* list_nos,
46
+ uint8_t* codes,
47
+ bool include_listnos = false) const override;
48
+
49
+ void add_core(
50
+ idx_t n,
51
+ const float* x,
52
+ const idx_t* xids,
53
+ const idx_t* precomputed_idx,
54
+ void* inverted_list_context = nullptr) override;
55
+
56
+ InvertedListScanner* get_InvertedListScanner(
57
+ bool store_pairs,
58
+ const IDSelector* sel,
59
+ const IVFSearchParameters* params) const override;
60
+
61
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
62
+ const override;
63
+
64
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
65
+
66
+ // unfortunately
67
+ DistanceComputer* get_distance_computer() const override;
68
+ };
69
+
70
+ } // namespace faiss
@@ -301,7 +301,8 @@ struct BuildScanner {
301
301
 
302
302
  InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
303
303
  bool store_pairs,
304
- const IDSelector* sel) const {
304
+ const IDSelector* sel,
305
+ const IVFSearchParameters*) const {
305
306
  FAISS_THROW_IF_NOT(!sel);
306
307
  BuildScanner bs;
307
308
  return dispatch_HammingComputer(code_size, bs, this, store_pairs);
@@ -71,7 +71,8 @@ struct IndexIVFSpectralHash : IndexIVF {
71
71
 
72
72
  InvertedListScanner* get_InvertedListScanner(
73
73
  bool store_pairs,
74
- const IDSelector* sel) const override;
74
+ const IDSelector* sel,
75
+ const IVFSearchParameters* params) const override;
75
76
 
76
77
  /** replace the vector transform for an empty (and possibly untrained) index
77
78
  */
@@ -0,0 +1,148 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexRaBitQ.h>
9
+
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/ResultHandler.h>
12
+
13
+ namespace faiss {
14
+
15
+ IndexRaBitQ::IndexRaBitQ() = default;
16
+
17
+ IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric)
18
+ : IndexFlatCodes(0, d, metric), rabitq(d, metric) {
19
+ code_size = rabitq.code_size;
20
+
21
+ is_trained = false;
22
+ }
23
+
24
+ void IndexRaBitQ::train(idx_t n, const float* x) {
25
+ // compute a centroid
26
+ std::vector<float> centroid(d, 0);
27
+ for (size_t i = 0; i < n; i++) {
28
+ for (size_t j = 0; j < d; j++) {
29
+ centroid[j] += x[i * d + j];
30
+ }
31
+ }
32
+
33
+ if (n != 0) {
34
+ for (size_t j = 0; j < d; j++) {
35
+ centroid[j] /= (float)n;
36
+ }
37
+ }
38
+
39
+ center = std::move(centroid);
40
+
41
+ //
42
+ rabitq.train(n, x);
43
+ is_trained = true;
44
+ }
45
+
46
+ void IndexRaBitQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
47
+ FAISS_THROW_IF_NOT(is_trained);
48
+ rabitq.compute_codes_core(x, bytes, n, center.data());
49
+ }
50
+
51
+ void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
52
+ FAISS_THROW_IF_NOT(is_trained);
53
+ rabitq.decode_core(bytes, x, n, center.data());
54
+ }
55
+
56
+ FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
57
+ FlatCodesDistanceComputer* dc =
58
+ rabitq.get_distance_computer(qb, center.data());
59
+ dc->code_size = rabitq.code_size;
60
+ dc->codes = codes.data();
61
+ return dc;
62
+ }
63
+
64
+ FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
65
+ const uint8_t qb) const {
66
+ FlatCodesDistanceComputer* dc =
67
+ rabitq.get_distance_computer(qb, center.data());
68
+ dc->code_size = rabitq.code_size;
69
+ dc->codes = codes.data();
70
+ return dc;
71
+ }
72
+
73
+ namespace {
74
+
75
+ struct Run_search_with_dc_res {
76
+ using T = void;
77
+
78
+ uint8_t qb = 0;
79
+
80
+ template <class BlockResultHandler>
81
+ void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
82
+ size_t ntotal = index->ntotal;
83
+ using SingleResultHandler =
84
+ typename BlockResultHandler::SingleResultHandler;
85
+ const int d = index->d;
86
+
87
+ #pragma omp parallel // if (res.nq > 100)
88
+ {
89
+ std::unique_ptr<FlatCodesDistanceComputer> dc(
90
+ index->get_quantized_distance_computer(qb));
91
+ SingleResultHandler resi(res);
92
+ #pragma omp for
93
+ for (int64_t q = 0; q < res.nq; q++) {
94
+ resi.begin(q);
95
+ dc->set_query(xq + d * q);
96
+ for (size_t i = 0; i < ntotal; i++) {
97
+ if (res.is_in_selection(i)) {
98
+ float dis = (*dc)(i);
99
+ resi.add_result(dis, i);
100
+ }
101
+ }
102
+ resi.end();
103
+ }
104
+ }
105
+ }
106
+ };
107
+
108
+ } // namespace
109
+
110
+ void IndexRaBitQ::search(
111
+ idx_t n,
112
+ const float* x,
113
+ idx_t k,
114
+ float* distances,
115
+ idx_t* labels,
116
+ const SearchParameters* params_in) const {
117
+ uint8_t used_qb = qb;
118
+ if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
119
+ used_qb = params->qb;
120
+ }
121
+
122
+ const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
123
+ Run_search_with_dc_res r;
124
+ r.qb = used_qb;
125
+
126
+ dispatch_knn_ResultHandler(
127
+ n, distances, labels, k, metric_type, sel, r, this, x);
128
+ }
129
+
130
+ void IndexRaBitQ::range_search(
131
+ idx_t n,
132
+ const float* x,
133
+ float radius,
134
+ RangeSearchResult* result,
135
+ const SearchParameters* params_in) const {
136
+ uint8_t used_qb = qb;
137
+ if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
138
+ used_qb = params->qb;
139
+ }
140
+
141
+ const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
142
+ Run_search_with_dc_res r;
143
+ r.qb = used_qb;
144
+
145
+ dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
146
+ }
147
+
148
+ } // namespace faiss
@@ -0,0 +1,65 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/IndexFlatCodes.h>
11
+ #include <faiss/impl/RaBitQuantizer.h>
12
+
13
+ namespace faiss {
14
+
15
+ struct RaBitQSearchParameters : SearchParameters {
16
+ uint8_t qb = 0;
17
+ };
18
+
19
+ struct IndexRaBitQ : IndexFlatCodes {
20
+ RaBitQuantizer rabitq;
21
+
22
+ // center of all points
23
+ std::vector<float> center;
24
+
25
+ // the default number of bits to quantize a query with.
26
+ // use '0' to disable quantization and use raw fp32 values.
27
+ uint8_t qb = 0;
28
+
29
+ IndexRaBitQ();
30
+
31
+ IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
32
+
33
+ void train(idx_t n, const float* x) override;
34
+
35
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
36
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
37
+
38
+ // returns a quantized-to-qb bits DC if qb > 0
39
+ // returns a default fp32-based DC if qb == 0
40
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
41
+
42
+ // returns a quantized-to-qb bits DC if qb_in > 0
43
+ // returns a default fp32-based DC if qb_in == 0
44
+ FlatCodesDistanceComputer* get_quantized_distance_computer(
45
+ const uint8_t qb_in) const;
46
+
47
+ // Don't rely on sa_decode(), bcz it is good for IP, but not for L2.
48
+ // As a result, use get_FlatCodesDistanceComputer() for the search.
49
+ void search(
50
+ idx_t n,
51
+ const float* x,
52
+ idx_t k,
53
+ float* distances,
54
+ idx_t* labels,
55
+ const SearchParameters* params = nullptr) const override;
56
+
57
+ void range_search(
58
+ idx_t n,
59
+ const float* x,
60
+ float radius,
61
+ RangeSearchResult* result,
62
+ const SearchParameters* params = nullptr) const override;
63
+ };
64
+
65
+ } // namespace faiss
@@ -254,7 +254,8 @@ void IndexIVFScalarQuantizer::add_core(
254
254
 
255
255
  InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
256
256
  bool store_pairs,
257
- const IDSelector* sel) const {
257
+ const IDSelector* sel,
258
+ const IVFSearchParameters*) const {
258
259
  return sq.select_InvertedListScanner(
259
260
  metric_type, quantizer, store_pairs, sel, by_residual);
260
261
  }
@@ -96,7 +96,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
96
96
 
97
97
  InvertedListScanner* get_InvertedListScanner(
98
98
  bool store_pairs,
99
- const IDSelector* sel) const override;
99
+ const IDSelector* sel,
100
+ const IVFSearchParameters* params) const override;
100
101
 
101
102
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
102
103
  const override;