faiss 0.2.3 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -13,6 +13,8 @@
13
13
  #include <algorithm>
14
14
  #include <memory>
15
15
 
16
+ #include <faiss/IndexLSH.h>
17
+ #include <faiss/IndexPreTransform.h>
16
18
  #include <faiss/VectorTransform.h>
17
19
  #include <faiss/impl/AuxIndexStructures.h>
18
20
  #include <faiss/impl/FaissAssert.h>
@@ -31,7 +33,6 @@ IndexIVFSpectralHash::IndexIVFSpectralHash(
31
33
  nbit(nbit),
32
34
  period(period),
33
35
  threshold_type(Thresh_global) {
34
- FAISS_THROW_IF_NOT(code_size % 4 == 0);
35
36
  RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
36
37
  rr->init(1234);
37
38
  vt = rr;
@@ -151,8 +152,8 @@ void binarize_with_freq(
151
152
  memset(codes, 0, (nbit + 7) / 8);
152
153
  for (size_t i = 0; i < nbit; i++) {
153
154
  float xf = (x[i] - c[i]);
154
- int xi = int(floor(xf * freq));
155
- int bit = xi & 1;
155
+ int64_t xi = int64_t(floor(xf * freq));
156
+ int64_t bit = xi & 1;
156
157
  codes[i >> 3] |= bit << (i & 7);
157
158
  }
158
159
  }
@@ -167,35 +168,33 @@ void IndexIVFSpectralHash::encode_vectors(
167
168
  bool include_listnos) const {
168
169
  FAISS_THROW_IF_NOT(is_trained);
169
170
  float freq = 2.0 / period;
170
-
171
- FAISS_THROW_IF_NOT_MSG(!include_listnos, "listnos encoding not supported");
171
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
172
172
 
173
173
  // transform with vt
174
174
  std::unique_ptr<float[]> x(vt->apply(n, x_in));
175
175
 
176
- #pragma omp parallel
177
- {
178
- std::vector<float> zero(nbit);
176
+ std::vector<float> zero(nbit);
179
177
 
180
- // each thread takes care of a subset of lists
181
178
  #pragma omp for
182
- for (idx_t i = 0; i < n; i++) {
183
- int64_t list_no = list_nos[i];
184
-
185
- if (list_no >= 0) {
186
- const float* c;
187
- if (threshold_type == Thresh_global) {
188
- c = zero.data();
189
- } else {
190
- c = trained.data() + list_no * nbit;
191
- }
192
- binarize_with_freq(
193
- nbit,
194
- freq,
195
- x.get() + i * nbit,
196
- c,
197
- codes + i * code_size);
179
+ for (idx_t i = 0; i < n; i++) {
180
+ int64_t list_no = list_nos[i];
181
+ uint8_t* code = codes + i * (code_size + coarse_size);
182
+
183
+ if (list_no >= 0) {
184
+ if (coarse_size) {
185
+ encode_listno(list_no, code);
186
+ }
187
+ const float* c;
188
+
189
+ if (threshold_type == Thresh_global) {
190
+ c = zero.data();
191
+ } else {
192
+ c = trained.data() + list_no * nbit;
198
193
  }
194
+ binarize_with_freq(
195
+ nbit, freq, x.get() + i * nbit, c, code + coarse_size);
196
+ } else {
197
+ memset(code, 0, code_size + coarse_size);
199
198
  }
200
199
  }
201
200
  }
@@ -206,9 +205,7 @@ template <class HammingComputer>
206
205
  struct IVFScanner : InvertedListScanner {
207
206
  // copied from index structure
208
207
  const IndexIVFSpectralHash* index;
209
- size_t code_size;
210
208
  size_t nbit;
211
- bool store_pairs;
212
209
 
213
210
  float period, freq;
214
211
  std::vector<float> q;
@@ -220,15 +217,16 @@ struct IVFScanner : InvertedListScanner {
220
217
 
221
218
  IVFScanner(const IndexIVFSpectralHash* index, bool store_pairs)
222
219
  : index(index),
223
- code_size(index->code_size),
224
220
  nbit(index->nbit),
225
- store_pairs(store_pairs),
226
221
  period(index->period),
227
222
  freq(2.0 / index->period),
228
223
  q(nbit),
229
224
  zero(nbit),
230
- qcode(code_size),
231
- hc(qcode.data(), code_size) {}
225
+ qcode(index->code_size),
226
+ hc(qcode.data(), index->code_size) {
227
+ this->store_pairs = store_pairs;
228
+ this->code_size = index->code_size;
229
+ }
232
230
 
233
231
  void set_query(const float* query) override {
234
232
  FAISS_THROW_IF_NOT(query);
@@ -241,8 +239,6 @@ struct IVFScanner : InvertedListScanner {
241
239
  }
242
240
  }
243
241
 
244
- idx_t list_no;
245
-
246
242
  void set_list(idx_t list_no, float /*coarse_dis*/) override {
247
243
  this->list_no = list_no;
248
244
  if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
@@ -310,13 +306,38 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
310
306
  HANDLE_CODE_SIZE(64);
311
307
  #undef HANDLE_CODE_SIZE
312
308
  default:
313
- if (code_size % 4 == 0) {
314
- return new IVFScanner<HammingComputerDefault>(
315
- this, store_pairs);
316
- } else {
317
- FAISS_THROW_MSG("not supported");
318
- }
309
+ return new IVFScanner<HammingComputerDefault>(this, store_pairs);
310
+ }
311
+ }
312
+
313
+ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
314
+ FAISS_THROW_IF_NOT(vt_in->d_out == nbit);
315
+ FAISS_THROW_IF_NOT(vt_in->d_in == d);
316
+ if (own_fields) {
317
+ delete vt;
319
318
  }
319
+ vt = vt_in;
320
+ threshold_type = Thresh_global;
321
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist &&
322
+ vt->is_trained;
323
+ own_fields = own;
324
+ }
325
+
326
+ /*
327
+ Check that the encoder is a single vector transform followed by a LSH
328
+ that just does thresholding.
329
+ If this is not the case, the linear transform + thresholds of the IndexLSH
330
+ should be merged into the VectorTransform (which is feasible).
331
+ */
332
+
333
+ void IndexIVFSpectralHash::replace_vt(IndexPreTransform* encoder, bool own) {
334
+ FAISS_THROW_IF_NOT(encoder->chain.size() == 1);
335
+ auto sub_index = dynamic_cast<IndexLSH*>(encoder->index);
336
+ FAISS_THROW_IF_NOT_MSG(sub_index, "final index should be LSH");
337
+ FAISS_THROW_IF_NOT(sub_index->nbits == nbit);
338
+ FAISS_THROW_IF_NOT(!sub_index->rotate_data);
339
+ FAISS_THROW_IF_NOT(!sub_index->train_thresholds);
340
+ replace_vt(encoder->chain[0], own);
320
341
  }
321
342
 
322
343
  } // namespace faiss
@@ -17,6 +17,7 @@
17
17
  namespace faiss {
18
18
 
19
19
  struct VectorTransform;
20
+ struct IndexPreTransform;
20
21
 
21
22
  /** Inverted list that stores binary codes of size nbit. Before the
22
23
  * binary conversion, the dimension of the vectors is transformed from
@@ -25,23 +26,29 @@ struct VectorTransform;
25
26
  * Each coordinate is subtracted from a value determined by
26
27
  * threshold_type, and split into intervals of size period. Half of
27
28
  * the interval is a 0 bit, the other half a 1.
29
+ *
28
30
  */
29
31
  struct IndexIVFSpectralHash : IndexIVF {
30
- VectorTransform* vt; // transformation from d to nbit dim
32
+ /// transformation from d to nbit dim
33
+ VectorTransform* vt;
34
+ /// own the vt
31
35
  bool own_fields;
32
36
 
37
+ /// nb of bits of the binary signature
33
38
  int nbit;
39
+ /// interval size for 0s and 1s
34
40
  float period;
35
41
 
36
42
  enum ThresholdType {
37
- Thresh_global,
38
- Thresh_centroid,
39
- Thresh_centroid_half,
40
- Thresh_median
43
+ Thresh_global, ///< global threshold at 0
44
+ Thresh_centroid, ///< compare to centroid
45
+ Thresh_centroid_half, ///< central interval around centroid
46
+ Thresh_median ///< median of training set
41
47
  };
42
48
  ThresholdType threshold_type;
43
49
 
44
- // size nlist * nbit or 0 if Thresh_global
50
+ /// Trained threshold.
51
+ /// size nlist * nbit or 0 if Thresh_global
45
52
  std::vector<float> trained;
46
53
 
47
54
  IndexIVFSpectralHash(
@@ -65,6 +72,14 @@ struct IndexIVFSpectralHash : IndexIVF {
65
72
  InvertedListScanner* get_InvertedListScanner(
66
73
  bool store_pairs) const override;
67
74
 
75
+ /** replace the vector transform for an empty (and possibly untrained) index
76
+ */
77
+ void replace_vt(VectorTransform* vt, bool own = false);
78
+
79
+ /** convenience function to get the VT from an index constructed by an
80
+ * index_factory (should end in "LSH") */
81
+ void replace_vt(IndexPreTransform* index, bool own = false);
82
+
68
83
  ~IndexIVFSpectralHash() override;
69
84
  };
70
85
 
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexLSH.h>
11
9
 
12
10
  #include <cstdio>
@@ -25,15 +23,13 @@ namespace faiss {
25
23
  ***************************************************************/
26
24
 
27
25
  IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
28
- : Index(d),
26
+ : IndexFlatCodes((nbits + 7) / 8, d),
29
27
  nbits(nbits),
30
28
  rotate_data(rotate_data),
31
29
  train_thresholds(train_thresholds),
32
30
  rrot(d, nbits) {
33
31
  is_trained = !train_thresholds;
34
32
 
35
- bytes_per_vec = (nbits + 7) / 8;
36
-
37
33
  if (rotate_data) {
38
34
  rrot.init(5);
39
35
  } else {
@@ -41,11 +37,7 @@ IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
41
37
  }
42
38
  }
43
39
 
44
- IndexLSH::IndexLSH()
45
- : nbits(0),
46
- bytes_per_vec(0),
47
- rotate_data(false),
48
- train_thresholds(false) {}
40
+ IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {}
49
41
 
50
42
  const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
51
43
  float* xt = nullptr;
@@ -106,15 +98,6 @@ void IndexLSH::train(idx_t n, const float* x) {
106
98
  is_trained = true;
107
99
  }
108
100
 
109
- void IndexLSH::add(idx_t n, const float* x) {
110
- FAISS_THROW_IF_NOT(is_trained);
111
- codes.resize((ntotal + n) * bytes_per_vec);
112
-
113
- sa_encode(n, x, &codes[ntotal * bytes_per_vec]);
114
-
115
- ntotal += n;
116
- }
117
-
118
101
  void IndexLSH::search(
119
102
  idx_t n,
120
103
  const float* x,
@@ -127,7 +110,7 @@ void IndexLSH::search(
127
110
  const float* xt = apply_preprocess(n, x);
128
111
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
129
112
 
130
- uint8_t* qcodes = new uint8_t[n * bytes_per_vec];
113
+ uint8_t* qcodes = new uint8_t[n * code_size];
131
114
  ScopeDeleter<uint8_t> del2(qcodes);
132
115
 
133
116
  fvecs2bitvecs(xt, qcodes, nbits, n);
@@ -137,7 +120,7 @@ void IndexLSH::search(
137
120
 
138
121
  int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances};
139
122
 
140
- hammings_knn_hc(&res, qcodes, codes.data(), ntotal, bytes_per_vec, true);
123
+ hammings_knn_hc(&res, qcodes, codes.data(), ntotal, code_size, true);
141
124
 
142
125
  // convert distances to floats
143
126
  for (int i = 0; i < k * n; i++)
@@ -158,15 +141,6 @@ void IndexLSH::transfer_thresholds(LinearTransform* vt) {
158
141
  thresholds.clear();
159
142
  }
160
143
 
161
- void IndexLSH::reset() {
162
- codes.clear();
163
- ntotal = 0;
164
- }
165
-
166
- size_t IndexLSH::sa_code_size() const {
167
- return bytes_per_vec;
168
- }
169
-
170
144
  void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
171
145
  FAISS_THROW_IF_NOT(is_trained);
172
146
  const float* xt = apply_preprocess(n, x);
@@ -12,17 +12,14 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
  #include <faiss/VectorTransform.h>
17
17
 
18
18
  namespace faiss {
19
19
 
20
20
  /** The sign of each vector component is put in a binary signature */
21
- struct IndexLSH : Index {
22
- typedef unsigned char uint8_t;
23
-
21
+ struct IndexLSH : IndexFlatCodes {
24
22
  int nbits; ///< nb of bits per vector
25
- int bytes_per_vec; ///< nb of 8-bits per encoded vector
26
23
  bool rotate_data; ///< whether to apply a random rotation to input
27
24
  bool train_thresholds; ///< whether we train thresholds or use 0
28
25
 
@@ -30,9 +27,6 @@ struct IndexLSH : Index {
30
27
 
31
28
  std::vector<float> thresholds; ///< thresholds to compare with
32
29
 
33
- /// encoded dataset
34
- std::vector<uint8_t> codes;
35
-
36
30
  IndexLSH(
37
31
  idx_t d,
38
32
  int nbits,
@@ -50,8 +44,6 @@ struct IndexLSH : Index {
50
44
 
51
45
  void train(idx_t n, const float* x) override;
52
46
 
53
- void add(idx_t n, const float* x) override;
54
-
55
47
  void search(
56
48
  idx_t n,
57
49
  const float* x,
@@ -59,8 +51,6 @@ struct IndexLSH : Index {
59
51
  float* distances,
60
52
  idx_t* labels) const override;
61
53
 
62
- void reset() override;
63
-
64
54
  /// transfer the thresholds to a pre-processing stage (and unset
65
55
  /// train_thresholds)
66
56
  void transfer_thresholds(LinearTransform* vt);
@@ -72,9 +62,6 @@ struct IndexLSH : Index {
72
62
  /* standalone codec interface.
73
63
  *
74
64
  * The vectors are decoded to +/- 1 (not 0, 1) */
75
-
76
- size_t sa_code_size() const override;
77
-
78
65
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
79
66
 
80
67
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
@@ -167,9 +167,7 @@ void IndexNNDescent::search(
167
167
  float* simi = distances + i * k;
168
168
  dis->set_query(x + i * d);
169
169
 
170
- maxheap_heapify(k, simi, idxi);
171
170
  nndescent.search(*dis, k, idxi, simi, vt);
172
- maxheap_reorder(k, simi, idxi);
173
171
  }
174
172
  }
175
173
  InterruptCallback::check();
@@ -104,9 +104,7 @@ void IndexNSG::search(
104
104
  float* simi = distances + i * k;
105
105
  dis->set_query(x + i * d);
106
106
 
107
- maxheap_heapify(k, simi, idxi);
108
107
  nsg.search(*dis, k, idxi, simi, vt);
109
- maxheap_reorder(k, simi, idxi);
110
108
 
111
109
  vt.advance();
112
110
  }
@@ -28,12 +28,13 @@ namespace faiss {
28
28
  ********************************************************/
29
29
 
30
30
  IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
31
- : Index(d, metric), pq(d, M, nbits) {
31
+ : IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
32
32
  is_trained = false;
33
33
  do_polysemous_training = false;
34
34
  polysemous_ht = nbits * M + 1;
35
35
  search_type = ST_PQ;
36
36
  encode_signs = false;
37
+ code_size = pq.code_size;
37
38
  }
38
39
 
39
40
  IndexPQ::IndexPQ() {
@@ -69,53 +70,6 @@ void IndexPQ::train(idx_t n, const float* x) {
69
70
  is_trained = true;
70
71
  }
71
72
 
72
- void IndexPQ::add(idx_t n, const float* x) {
73
- FAISS_THROW_IF_NOT(is_trained);
74
- codes.resize((n + ntotal) * pq.code_size);
75
- pq.compute_codes(x, &codes[ntotal * pq.code_size], n);
76
- ntotal += n;
77
- }
78
-
79
- size_t IndexPQ::remove_ids(const IDSelector& sel) {
80
- idx_t j = 0;
81
- for (idx_t i = 0; i < ntotal; i++) {
82
- if (sel.is_member(i)) {
83
- // should be removed
84
- } else {
85
- if (i > j) {
86
- memmove(&codes[pq.code_size * j],
87
- &codes[pq.code_size * i],
88
- pq.code_size);
89
- }
90
- j++;
91
- }
92
- }
93
- size_t nremove = ntotal - j;
94
- if (nremove > 0) {
95
- ntotal = j;
96
- codes.resize(ntotal * pq.code_size);
97
- }
98
- return nremove;
99
- }
100
-
101
- void IndexPQ::reset() {
102
- codes.clear();
103
- ntotal = 0;
104
- }
105
-
106
- void IndexPQ::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
107
- FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
108
- for (idx_t i = 0; i < ni; i++) {
109
- const uint8_t* code = &codes[(i0 + i) * pq.code_size];
110
- pq.decode(code, recons + i * d);
111
- }
112
- }
113
-
114
- void IndexPQ::reconstruct(idx_t key, float* recons) const {
115
- FAISS_THROW_IF_NOT(key >= 0 && key < ntotal);
116
- pq.decode(&codes[key * pq.code_size], recons);
117
- }
118
-
119
73
  namespace {
120
74
 
121
75
  template <class PQDecoder>
@@ -457,9 +411,6 @@ void IndexPQ::search_core_polysemous(
457
411
  }
458
412
 
459
413
  /* The standalone codec interface (just remaps to the PQ functions) */
460
- size_t IndexPQ::sa_code_size() const {
461
- return pq.code_size;
462
- }
463
414
 
464
415
  void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
465
416
  pq.compute_codes(x, bytes, n);
@@ -12,7 +12,7 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
  #include <faiss/impl/PolysemousTraining.h>
17
17
  #include <faiss/impl/ProductQuantizer.h>
18
18
  #include <faiss/impl/platform_macros.h>
@@ -21,13 +21,10 @@ namespace faiss {
21
21
 
22
22
  /** Index based on a product quantizer. Stored vectors are
23
23
  * approximated by PQ codes. */
24
- struct IndexPQ : Index {
24
+ struct IndexPQ : IndexFlatCodes {
25
25
  /// The product quantizer used to encode the vectors
26
26
  ProductQuantizer pq;
27
27
 
28
- /// Codes. Size ntotal * pq.code_size
29
- std::vector<uint8_t> codes;
30
-
31
28
  /** Constructor.
32
29
  *
33
30
  * @param d dimensionality of the input vectors
@@ -43,8 +40,6 @@ struct IndexPQ : Index {
43
40
 
44
41
  void train(idx_t n, const float* x) override;
45
42
 
46
- void add(idx_t n, const float* x) override;
47
-
48
43
  void search(
49
44
  idx_t n,
50
45
  const float* x,
@@ -52,17 +47,7 @@ struct IndexPQ : Index {
52
47
  float* distances,
53
48
  idx_t* labels) const override;
54
49
 
55
- void reset() override;
56
-
57
- void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
58
-
59
- void reconstruct(idx_t key, float* recons) const override;
60
-
61
- size_t remove_ids(const IDSelector& sel) override;
62
-
63
50
  /* The standalone codec interface */
64
- size_t sa_code_size() const override;
65
-
66
51
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
67
52
 
68
53
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
@@ -155,6 +155,34 @@ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
155
  refine_index->reconstruct(key, recons);
156
156
  }
157
157
 
158
+ size_t IndexRefine::sa_code_size() const {
159
+ return base_index->sa_code_size() + refine_index->sa_code_size();
160
+ }
161
+
162
+ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
163
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
164
+ std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
165
+ base_index->sa_encode(n, x, tmp1.get());
166
+ std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
167
+ refine_index->sa_encode(n, x, tmp2.get());
168
+ for (size_t i = 0; i < n; i++) {
169
+ uint8_t* b = bytes + i * (cs1 + cs2);
170
+ memcpy(b, tmp1.get() + cs1 * i, cs1);
171
+ memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
172
+ }
173
+ }
174
+
175
+ void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
176
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
177
+ std::unique_ptr<uint8_t[]> tmp2(
178
+ new uint8_t[n * refine_index->sa_code_size()]);
179
+ for (size_t i = 0; i < n; i++) {
180
+ memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
181
+ }
182
+
183
+ refine_index->sa_decode(n, tmp2.get(), x);
184
+ }
185
+
158
186
  IndexRefine::~IndexRefine() {
159
187
  if (own_fields)
160
188
  delete base_index;
@@ -49,6 +49,16 @@ struct IndexRefine : Index {
49
49
  // reconstruct is routed to the refine_index
50
50
  void reconstruct(idx_t key, float* recons) const override;
51
51
 
52
+ /* standalone codec interface: the base_index codes are interleaved with the
53
+ * refine_index ones */
54
+ size_t sa_code_size() const override;
55
+
56
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
57
+
58
+ /// The sa_decode decodes from the index_refine, which is assumed to be more
59
+ /// accurate
60
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
61
+
52
62
  ~IndexRefine() override;
53
63
  };
54
64
 
@@ -29,7 +29,7 @@ IndexScalarQuantizer::IndexScalarQuantizer(
29
29
  int d,
30
30
  ScalarQuantizer::QuantizerType qtype,
31
31
  MetricType metric)
32
- : Index(d, metric), sq(d, qtype) {
32
+ : IndexFlatCodes(0, d, metric), sq(d, qtype) {
33
33
  is_trained = qtype == ScalarQuantizer::QT_fp16 ||
34
34
  qtype == ScalarQuantizer::QT_8bit_direct;
35
35
  code_size = sq.code_size;
@@ -43,13 +43,6 @@ void IndexScalarQuantizer::train(idx_t n, const float* x) {
43
43
  is_trained = true;
44
44
  }
45
45
 
46
- void IndexScalarQuantizer::add(idx_t n, const float* x) {
47
- FAISS_THROW_IF_NOT(is_trained);
48
- codes.resize((n + ntotal) * code_size);
49
- sq.compute_codes(x, &codes[ntotal * code_size], n);
50
- ntotal += n;
51
- }
52
-
53
46
  void IndexScalarQuantizer::search(
54
47
  idx_t n,
55
48
  const float* x,
@@ -67,6 +60,7 @@ void IndexScalarQuantizer::search(
67
60
  InvertedListScanner* scanner =
68
61
  sq.select_InvertedListScanner(metric_type, nullptr, true);
69
62
  ScopeDeleter1<InvertedListScanner> del(scanner);
63
+ scanner->list_no = 0; // directly the list number
70
64
 
71
65
  #pragma omp for
72
66
  for (idx_t i = 0; i < n; i++) {
@@ -99,27 +93,7 @@ DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
99
93
  return dc;
100
94
  }
101
95
 
102
- void IndexScalarQuantizer::reset() {
103
- codes.clear();
104
- ntotal = 0;
105
- }
106
-
107
- void IndexScalarQuantizer::reconstruct_n(idx_t i0, idx_t ni, float* recons)
108
- const {
109
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
110
- for (size_t i = 0; i < ni; i++) {
111
- squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d);
112
- }
113
- }
114
-
115
- void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const {
116
- reconstruct_n(key, 1, recons);
117
- }
118
-
119
96
  /* Codec interface */
120
- size_t IndexScalarQuantizer::sa_code_size() const {
121
- return sq.code_size;
122
- }
123
97
 
124
98
  void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
125
99
  const {
@@ -13,6 +13,7 @@
13
13
  #include <stdint.h>
14
14
  #include <vector>
15
15
 
16
+ #include <faiss/IndexFlatCodes.h>
16
17
  #include <faiss/IndexIVF.h>
17
18
  #include <faiss/impl/ScalarQuantizer.h>
18
19
 
@@ -24,15 +25,10 @@ namespace faiss {
24
25
  * (default).
25
26
  */
26
27
 
27
- struct IndexScalarQuantizer : Index {
28
+ struct IndexScalarQuantizer : IndexFlatCodes {
28
29
  /// Used to encode the vectors
29
30
  ScalarQuantizer sq;
30
31
 
31
- /// Codes. Size ntotal * pq.code_size
32
- std::vector<uint8_t> codes;
33
-
34
- size_t code_size;
35
-
36
32
  /** Constructor.
37
33
  *
38
34
  * @param d dimensionality of the input vectors
@@ -48,8 +44,6 @@ struct IndexScalarQuantizer : Index {
48
44
 
49
45
  void train(idx_t n, const float* x) override;
50
46
 
51
- void add(idx_t n, const float* x) override;
52
-
53
47
  void search(
54
48
  idx_t n,
55
49
  const float* x,
@@ -57,17 +51,9 @@ struct IndexScalarQuantizer : Index {
57
51
  float* distances,
58
52
  idx_t* labels) const override;
59
53
 
60
- void reset() override;
61
-
62
- void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
-
64
- void reconstruct(idx_t key, float* recons) const override;
65
-
66
54
  DistanceComputer* get_distance_computer() const override;
67
55
 
68
56
  /* standalone codec interface */
69
- size_t sa_code_size() const override;
70
-
71
57
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
72
58
 
73
59
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;