faiss 0.3.4 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/ext.cpp +2 -3
  4. data/ext/faiss/index.cpp +13 -14
  5. data/ext/faiss/index_binary.cpp +2 -0
  6. data/ext/faiss/kmeans.cpp +2 -0
  7. data/ext/faiss/pca_matrix.cpp +2 -0
  8. data/ext/faiss/product_quantizer.cpp +2 -0
  9. data/ext/faiss/utils.cpp +3 -0
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +11 -8
  12. data/vendor/faiss/faiss/Clustering.cpp +0 -16
  13. data/vendor/faiss/faiss/IVFlib.cpp +213 -0
  14. data/vendor/faiss/faiss/IVFlib.h +42 -0
  15. data/vendor/faiss/faiss/Index.h +1 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
  18. data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
  19. data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
  20. data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
  21. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  22. data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
  23. data/vendor/faiss/faiss/IndexIVF.h +5 -2
  24. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
  25. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
  26. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
  27. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  28. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
  29. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
  30. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
  31. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
  32. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
  33. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  34. data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
  35. data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
  36. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
  38. data/vendor/faiss/faiss/clone_index.cpp +38 -3
  39. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
  40. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
  41. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
  43. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  44. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
  45. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
  46. data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
  47. data/vendor/faiss/faiss/impl/HNSW.h +5 -4
  48. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  49. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
  50. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
  51. data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
  52. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
  53. data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
  54. data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
  55. data/vendor/faiss/faiss/impl/io.h +2 -2
  56. data/vendor/faiss/faiss/impl/io_macros.h +2 -0
  57. data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
  58. data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
  59. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
  60. data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
  61. data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
  62. data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
  63. data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
  64. data/vendor/faiss/faiss/index_factory.cpp +16 -5
  65. data/vendor/faiss/faiss/index_io.h +4 -0
  66. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
  67. data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
  68. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
  69. data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
  70. data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
  71. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
  72. data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
  73. data/vendor/faiss/faiss/utils/hamming.h +7 -3
  74. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
  75. data/vendor/faiss/faiss/utils/utils.cpp +4 -4
  76. data/vendor/faiss/faiss/utils/utils.h +3 -3
  77. metadata +16 -4
@@ -8,9 +8,7 @@
8
8
  #include <faiss/IndexHNSW.h>
9
9
 
10
10
  #include <omp.h>
11
- #include <cassert>
12
11
  #include <cinttypes>
13
- #include <cmath>
14
12
  #include <cstdio>
15
13
  #include <cstdlib>
16
14
  #include <cstring>
@@ -124,7 +122,7 @@ void hnsw_add_vertices(
124
122
  int i1 = n;
125
123
 
126
124
  for (int pt_level = hist.size() - 1;
127
- pt_level >= !index_hnsw.init_level0;
125
+ pt_level >= int(!index_hnsw.init_level0);
128
126
  pt_level--) {
129
127
  int i0 = i1 - hist[pt_level];
130
128
 
@@ -212,7 +210,9 @@ IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
212
210
  : Index(d, metric), hnsw(M) {}
213
211
 
214
212
  IndexHNSW::IndexHNSW(Index* storage, int M)
215
- : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {}
213
+ : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {
214
+ metric_arg = storage->metric_arg;
215
+ }
216
216
 
217
217
  IndexHNSW::~IndexHNSW() {
218
218
  if (own_fields) {
@@ -237,19 +237,19 @@ void hnsw_search(
237
237
  idx_t n,
238
238
  const float* x,
239
239
  BlockResultHandler& bres,
240
- const SearchParameters* params_in) {
240
+ const SearchParameters* params) {
241
241
  FAISS_THROW_IF_NOT_MSG(
242
242
  index->storage,
243
243
  "No storage index, please use IndexHNSWFlat (or variants) "
244
244
  "instead of IndexHNSW directly");
245
- const SearchParametersHNSW* params = nullptr;
246
245
  const HNSW& hnsw = index->hnsw;
247
246
 
248
247
  int efSearch = hnsw.efSearch;
249
- if (params_in) {
250
- params = dynamic_cast<const SearchParametersHNSW*>(params_in);
251
- FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
252
- efSearch = params->efSearch;
248
+ if (params) {
249
+ if (const SearchParametersHNSW* hnsw_params =
250
+ dynamic_cast<const SearchParametersHNSW*>(params)) {
251
+ efSearch = hnsw_params->efSearch;
252
+ }
253
253
  }
254
254
  size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
255
255
 
@@ -294,13 +294,13 @@ void IndexHNSW::search(
294
294
  idx_t k,
295
295
  float* distances,
296
296
  idx_t* labels,
297
- const SearchParameters* params_in) const {
297
+ const SearchParameters* params) const {
298
298
  FAISS_THROW_IF_NOT(k > 0);
299
299
 
300
300
  using RH = HeapBlockResultHandler<HNSW::C>;
301
301
  RH bres(n, distances, labels, k);
302
302
 
303
- hnsw_search(this, n, x, bres, params_in);
303
+ hnsw_search(this, n, x, bres, params);
304
304
 
305
305
  if (is_similarity_metric(this->metric_type)) {
306
306
  // we need to revert the negated distances
@@ -408,17 +408,10 @@ void IndexHNSW::search_level_0(
408
408
  idx_t* labels,
409
409
  int nprobe,
410
410
  int search_type,
411
- const SearchParameters* params_in) const {
411
+ const SearchParameters* params) const {
412
412
  FAISS_THROW_IF_NOT(k > 0);
413
413
  FAISS_THROW_IF_NOT(nprobe > 0);
414
414
 
415
- const SearchParametersHNSW* params = nullptr;
416
-
417
- if (params_in) {
418
- params = dynamic_cast<const SearchParametersHNSW*>(params_in);
419
- FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
420
- }
421
-
422
415
  storage_idx_t ntotal = hnsw.levels.size();
423
416
 
424
417
  using RH = HeapBlockResultHandler<HNSW::C>;
@@ -138,7 +138,7 @@ struct IndexHNSWPQ : IndexHNSW {
138
138
  void train(idx_t n, const float* x) override;
139
139
  };
140
140
 
141
- /** SQ index topped with with a HNSW structure to access elements
141
+ /** SQ index topped with a HNSW structure to access elements
142
142
  * more efficiently.
143
143
  */
144
144
  struct IndexHNSWSQ : IndexHNSW {
@@ -455,7 +455,7 @@ void IndexIVF::search_preassigned(
455
455
  #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
456
456
  {
457
457
  std::unique_ptr<InvertedListScanner> scanner(
458
- get_InvertedListScanner(store_pairs, sel));
458
+ get_InvertedListScanner(store_pairs, sel, params));
459
459
 
460
460
  /*****************************************************
461
461
  * Depending on parallel_mode, there are two possible ways
@@ -796,7 +796,7 @@ void IndexIVF::range_search_preassigned(
796
796
  {
797
797
  RangeSearchPartialResult pres(result);
798
798
  std::unique_ptr<InvertedListScanner> scanner(
799
- get_InvertedListScanner(store_pairs, sel));
799
+ get_InvertedListScanner(store_pairs, sel, params));
800
800
  FAISS_THROW_IF_NOT(scanner.get());
801
801
  all_pres[omp_get_thread_num()] = &pres;
802
802
 
@@ -912,7 +912,8 @@ void IndexIVF::range_search_preassigned(
912
912
 
913
913
  InvertedListScanner* IndexIVF::get_InvertedListScanner(
914
914
  bool /*store_pairs*/,
915
- const IDSelector* /* sel */) const {
915
+ const IDSelector* /* sel */,
916
+ const IVFSearchParameters* /* params */) const {
916
917
  FAISS_THROW_MSG("get_InvertedListScanner not implemented");
917
918
  }
918
919
 
@@ -1290,6 +1291,14 @@ size_t InvertedListScanner::scan_codes(
1290
1291
 
1291
1292
  if (!keep_max) {
1292
1293
  for (size_t j = 0; j < list_size; j++) {
1294
+ if (sel != nullptr) {
1295
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1296
+ if (!sel->is_member(id)) {
1297
+ codes += code_size;
1298
+ continue;
1299
+ }
1300
+ }
1301
+
1293
1302
  float dis = distance_to_code(codes);
1294
1303
  if (dis < simi[0]) {
1295
1304
  int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
@@ -1300,6 +1309,14 @@ size_t InvertedListScanner::scan_codes(
1300
1309
  }
1301
1310
  } else {
1302
1311
  for (size_t j = 0; j < list_size; j++) {
1312
+ if (sel != nullptr) {
1313
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1314
+ if (!sel->is_member(id)) {
1315
+ codes += code_size;
1316
+ continue;
1317
+ }
1318
+ }
1319
+
1303
1320
  float dis = distance_to_code(codes);
1304
1321
  if (dis > simi[0]) {
1305
1322
  int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
@@ -312,11 +312,14 @@ struct IndexIVF : Index, IndexIVFInterface {
312
312
 
313
313
  /** Get a scanner for this index (store_pairs means ignore labels)
314
314
  *
315
- * The default search implementation uses this to compute the distances
315
+ * The default search implementation uses this to compute the distances.
316
+ * Use sel instead of params->sel, because sel is initialized with
317
+ * params->sel, but may get overridden by IndexIVF's internal logic.
316
318
  */
317
319
  virtual InvertedListScanner* get_InvertedListScanner(
318
320
  bool store_pairs = false,
319
- const IDSelector* sel = nullptr) const;
321
+ const IDSelector* sel = nullptr,
322
+ const IVFSearchParameters* params = nullptr) const;
320
323
 
321
324
  /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
322
325
  */
@@ -253,7 +253,8 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
253
253
 
254
254
  InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
255
255
  bool store_pairs,
256
- const IDSelector* sel) const {
256
+ const IDSelector* sel,
257
+ const IVFSearchParameters*) const {
257
258
  FAISS_THROW_IF_NOT(!sel);
258
259
  if (metric_type == METRIC_INNER_PRODUCT) {
259
260
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
@@ -52,7 +52,8 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
52
52
 
53
53
  InvertedListScanner* get_InvertedListScanner(
54
54
  bool store_pairs,
55
- const IDSelector* sel) const override;
55
+ const IDSelector* sel,
56
+ const IVFSearchParameters* params) const override;
56
57
 
57
58
  void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
58
59
 
@@ -223,7 +223,8 @@ InvertedListScanner* get_InvertedListScanner1(
223
223
 
224
224
  InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
225
225
  bool store_pairs,
226
- const IDSelector* sel) const {
226
+ const IDSelector* sel,
227
+ const IVFSearchParameters*) const {
227
228
  if (sel) {
228
229
  return get_InvertedListScanner1<true>(this, store_pairs, sel);
229
230
  } else {
@@ -44,7 +44,8 @@ struct IndexIVFFlat : IndexIVF {
44
44
 
45
45
  InvertedListScanner* get_InvertedListScanner(
46
46
  bool store_pairs,
47
- const IDSelector* sel) const override;
47
+ const IDSelector* sel,
48
+ const IVFSearchParameters* params) const override;
48
49
 
49
50
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
50
51
  const override;
@@ -1321,7 +1321,8 @@ InvertedListScanner* get_InvertedListScanner2(
1321
1321
 
1322
1322
  InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
1323
1323
  bool store_pairs,
1324
- const IDSelector* sel) const {
1324
+ const IDSelector* sel,
1325
+ const IVFSearchParameters*) const {
1325
1326
  if (sel) {
1326
1327
  return get_InvertedListScanner2<true>(*this, store_pairs, sel);
1327
1328
  } else {
@@ -134,7 +134,8 @@ struct IndexIVFPQ : IndexIVF {
134
134
 
135
135
  InvertedListScanner* get_InvertedListScanner(
136
136
  bool store_pairs,
137
- const IDSelector* sel) const override;
137
+ const IDSelector* sel,
138
+ const IVFSearchParameters* params) const override;
138
139
 
139
140
  /// build precomputed table
140
141
  void precompute_table();
@@ -0,0 +1,277 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
#include <faiss/IndexIVFRaBitQ.h>

#include <omp.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/RaBitQuantizer.h>
19
+
20
+ namespace faiss {
21
+
22
+ IndexIVFRaBitQ::IndexIVFRaBitQ(
23
+ Index* quantizer,
24
+ const size_t d,
25
+ const size_t nlist,
26
+ MetricType metric)
27
+ : IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d, metric) {
28
+ code_size = rabitq.code_size;
29
+ invlists->code_size = code_size;
30
+ is_trained = false;
31
+
32
+ by_residual = true;
33
+ }
34
+
35
+ IndexIVFRaBitQ::IndexIVFRaBitQ() {
36
+ by_residual = true;
37
+ }
38
+
39
+ void IndexIVFRaBitQ::train_encoder(
40
+ idx_t n,
41
+ const float* x,
42
+ const idx_t* assign) {
43
+ rabitq.train(n, x);
44
+ }
45
+
46
+ void IndexIVFRaBitQ::encode_vectors(
47
+ idx_t n,
48
+ const float* x,
49
+ const idx_t* list_nos,
50
+ uint8_t* codes,
51
+ bool include_listnos) const {
52
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
53
+ memset(codes, 0, (code_size + coarse_size) * n);
54
+
55
+ #pragma omp parallel if (n > 1000)
56
+ {
57
+ std::vector<float> centroid(d);
58
+
59
+ #pragma omp for
60
+ for (idx_t i = 0; i < n; i++) {
61
+ int64_t list_no = list_nos[i];
62
+ if (list_no >= 0) {
63
+ const float* xi = x + i * d;
64
+ uint8_t* code = codes + i * (code_size + coarse_size);
65
+
66
+ // both by_residual and !by_residual lead to the same code
67
+ quantizer->reconstruct(list_no, centroid.data());
68
+ rabitq.compute_codes_core(
69
+ xi, code + coarse_size, 1, centroid.data());
70
+
71
+ if (coarse_size) {
72
+ encode_listno(list_no, code);
73
+ }
74
+ }
75
+ }
76
+ }
77
+ }
78
+
79
+ void IndexIVFRaBitQ::add_core(
80
+ idx_t n,
81
+ const float* x,
82
+ const idx_t* xids,
83
+ const idx_t* precomputed_idx,
84
+ void* inverted_list_context) {
85
+ FAISS_THROW_IF_NOT(is_trained);
86
+
87
+ DirectMapAdd dm_add(direct_map, n, xids);
88
+
89
+ #pragma omp parallel
90
+ {
91
+ std::vector<uint8_t> one_code(code_size);
92
+ std::vector<float> centroid(d);
93
+
94
+ int nt = omp_get_num_threads();
95
+ int rank = omp_get_thread_num();
96
+
97
+ // each thread takes care of a subset of lists
98
+ for (size_t i = 0; i < n; i++) {
99
+ int64_t list_no = precomputed_idx[i];
100
+ if (list_no >= 0 && list_no % nt == rank) {
101
+ int64_t id = xids ? xids[i] : ntotal + i;
102
+
103
+ const float* xi = x + i * d;
104
+
105
+ // both by_residual and !by_residual lead to the same code
106
+ quantizer->reconstruct(list_no, centroid.data());
107
+ rabitq.compute_codes_core(
108
+ xi, one_code.data(), 1, centroid.data());
109
+
110
+ size_t ofs = invlists->add_entry(
111
+ list_no, id, one_code.data(), inverted_list_context);
112
+
113
+ dm_add.add(i, list_no, ofs);
114
+
115
+ } else if (rank == 0 && list_no == -1) {
116
+ dm_add.add(i, -1, 0);
117
+ }
118
+ }
119
+ }
120
+
121
+ ntotal += n;
122
+ }
123
+
124
+ struct RaBitInvertedListScanner : InvertedListScanner {
125
+ const IndexIVFRaBitQ& ivf_rabitq;
126
+
127
+ std::vector<float> reconstructed_centroid;
128
+ std::vector<float> query_vector;
129
+
130
+ std::unique_ptr<FlatCodesDistanceComputer> dc;
131
+
132
+ uint8_t qb = 0;
133
+
134
+ RaBitInvertedListScanner(
135
+ const IndexIVFRaBitQ& ivf_rabitq_in,
136
+ bool store_pairs = false,
137
+ const IDSelector* sel = nullptr,
138
+ uint8_t qb_in = 0)
139
+ : InvertedListScanner(store_pairs, sel),
140
+ ivf_rabitq{ivf_rabitq_in},
141
+ qb{qb_in} {
142
+ keep_max = is_similarity_metric(ivf_rabitq.metric_type);
143
+ code_size = ivf_rabitq.code_size;
144
+ }
145
+
146
+ /// from now on we handle this query.
147
+ void set_query(const float* query_vector_in) override {
148
+ query_vector.assign(query_vector_in, query_vector_in + ivf_rabitq.d);
149
+
150
+ internal_try_setup_dc();
151
+ }
152
+
153
+ /// following codes come from this inverted list
154
+ void set_list(idx_t list_no, float coarse_dis) override {
155
+ this->list_no = list_no;
156
+
157
+ reconstructed_centroid.resize(ivf_rabitq.d);
158
+ ivf_rabitq.quantizer->reconstruct(
159
+ list_no, reconstructed_centroid.data());
160
+
161
+ internal_try_setup_dc();
162
+ }
163
+
164
+ /// compute a single query-to-code distance
165
+ float distance_to_code(const uint8_t* code) const override {
166
+ return dc->distance_to_code(code);
167
+ }
168
+
169
+ void internal_try_setup_dc() {
170
+ if (!query_vector.empty() && !reconstructed_centroid.empty()) {
171
+ // both query_vector and centroid are available!
172
+ // set up DistanceComputer
173
+ dc.reset(ivf_rabitq.rabitq.get_distance_computer(
174
+ qb, reconstructed_centroid.data()));
175
+
176
+ dc->set_query(query_vector.data());
177
+ }
178
+ }
179
+ };
180
+
181
+ InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner(
182
+ bool store_pairs,
183
+ const IDSelector* sel,
184
+ const IVFSearchParameters* search_params_in) const {
185
+ uint8_t used_qb = qb;
186
+ if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
187
+ search_params_in)) {
188
+ used_qb = params->qb;
189
+ }
190
+
191
+ return new RaBitInvertedListScanner(*this, store_pairs, sel, used_qb);
192
+ }
193
+
194
+ void IndexIVFRaBitQ::reconstruct_from_offset(
195
+ int64_t list_no,
196
+ int64_t offset,
197
+ float* recons) const {
198
+ const uint8_t* code = invlists->get_single_code(list_no, offset);
199
+
200
+ std::vector<float> centroid(d);
201
+ quantizer->reconstruct(list_no, centroid.data());
202
+
203
+ rabitq.decode_core(code, recons, 1, centroid.data());
204
+ }
205
+
206
+ void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
207
+ size_t coarse_size = coarse_code_size();
208
+
209
+ #pragma omp parallel
210
+ {
211
+ std::vector<float> centroid(d);
212
+
213
+ #pragma omp for
214
+ for (idx_t i = 0; i < n; i++) {
215
+ const uint8_t* code = codes + i * (code_size + coarse_size);
216
+ int64_t list_no = decode_listno(code);
217
+ float* xi = x + i * d;
218
+
219
+ quantizer->reconstruct(list_no, centroid.data());
220
+ rabitq.decode_core(code + coarse_size, xi, 1, centroid.data());
221
+ }
222
+ }
223
+ }
224
+
225
+ struct IVFRaBitDistanceComputer : DistanceComputer {
226
+ const float* q = nullptr;
227
+ const IndexIVFRaBitQ* parent = nullptr;
228
+
229
+ void set_query(const float* x) override;
230
+
231
+ float operator()(idx_t i) override;
232
+
233
+ float symmetric_dis(idx_t i, idx_t j) override;
234
+ };
235
+
236
+ void IVFRaBitDistanceComputer::set_query(const float* x) {
237
+ q = x;
238
+ }
239
+
240
+ float IVFRaBitDistanceComputer::operator()(idx_t i) {
241
+ // find the appropriate list
242
+ idx_t lo = parent->direct_map.get(i);
243
+ uint64_t list_no = lo_listno(lo);
244
+ uint64_t offset = lo_offset(lo);
245
+
246
+ const uint8_t* code = parent->invlists->get_single_code(list_no, offset);
247
+
248
+ // ok, we know the appropriate cluster that we need
249
+ std::vector<float> centroid(parent->d);
250
+ parent->quantizer->reconstruct(list_no, centroid.data());
251
+
252
+ // compute the distance
253
+ float distance = 0;
254
+
255
+ std::unique_ptr<FlatCodesDistanceComputer> dc(
256
+ parent->rabitq.get_distance_computer(parent->qb, centroid.data()));
257
+ dc->set_query(q);
258
+ distance = dc->distance_to_code(code);
259
+
260
+ // deallocate
261
+ parent->invlists->release_codes(list_no, code);
262
+
263
+ // done
264
+ return distance;
265
+ }
266
+
267
+ float IVFRaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) {
268
+ FAISS_THROW_MSG("Not implemented");
269
+ }
270
+
271
+ DistanceComputer* IndexIVFRaBitQ::get_distance_computer() const {
272
+ IVFRaBitDistanceComputer* dc = new IVFRaBitDistanceComputer;
273
+ dc->parent = this;
274
+ return dc;
275
+ }
276
+
277
+ } // namespace faiss
@@ -0,0 +1,70 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+
13
+ #include <faiss/Index.h>
14
+ #include <faiss/IndexIVF.h>
15
+
16
+ #include <faiss/impl/RaBitQuantizer.h>
17
+
18
+ namespace faiss {
19
+
20
+ struct IVFRaBitQSearchParameters : IVFSearchParameters {
21
+ uint8_t qb = 0;
22
+ };
23
+
24
+ // * by_residual is true, just by design
25
+ struct IndexIVFRaBitQ : IndexIVF {
26
+ RaBitQuantizer rabitq;
27
+
28
+ // the default number of bits to quantize a query with.
29
+ // use '0' to disable quantization and use raw fp32 values.
30
+ uint8_t qb = 0;
31
+
32
+ IndexIVFRaBitQ(
33
+ Index* quantizer,
34
+ const size_t d,
35
+ const size_t nlist,
36
+ MetricType metric = METRIC_L2);
37
+
38
+ IndexIVFRaBitQ();
39
+
40
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
41
+
42
+ void encode_vectors(
43
+ idx_t n,
44
+ const float* x,
45
+ const idx_t* list_nos,
46
+ uint8_t* codes,
47
+ bool include_listnos = false) const override;
48
+
49
+ void add_core(
50
+ idx_t n,
51
+ const float* x,
52
+ const idx_t* xids,
53
+ const idx_t* precomputed_idx,
54
+ void* inverted_list_context = nullptr) override;
55
+
56
+ InvertedListScanner* get_InvertedListScanner(
57
+ bool store_pairs,
58
+ const IDSelector* sel,
59
+ const IVFSearchParameters* params) const override;
60
+
61
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
62
+ const override;
63
+
64
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
65
+
66
+ // unfortunately
67
+ DistanceComputer* get_distance_computer() const override;
68
+ };
69
+
70
+ } // namespace faiss
@@ -301,7 +301,8 @@ struct BuildScanner {
301
301
 
302
302
  InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
303
303
  bool store_pairs,
304
- const IDSelector* sel) const {
304
+ const IDSelector* sel,
305
+ const IVFSearchParameters*) const {
305
306
  FAISS_THROW_IF_NOT(!sel);
306
307
  BuildScanner bs;
307
308
  return dispatch_HammingComputer(code_size, bs, this, store_pairs);
@@ -71,7 +71,8 @@ struct IndexIVFSpectralHash : IndexIVF {
71
71
 
72
72
  InvertedListScanner* get_InvertedListScanner(
73
73
  bool store_pairs,
74
- const IDSelector* sel) const override;
74
+ const IDSelector* sel,
75
+ const IVFSearchParameters* params) const override;
75
76
 
76
77
  /** replace the vector transform for an empty (and possibly untrained) index
77
78
  */