faiss 0.2.3 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -13,6 +13,8 @@
13
13
  #include <algorithm>
14
14
  #include <memory>
15
15
 
16
+ #include <faiss/IndexLSH.h>
17
+ #include <faiss/IndexPreTransform.h>
16
18
  #include <faiss/VectorTransform.h>
17
19
  #include <faiss/impl/AuxIndexStructures.h>
18
20
  #include <faiss/impl/FaissAssert.h>
@@ -31,7 +33,6 @@ IndexIVFSpectralHash::IndexIVFSpectralHash(
31
33
  nbit(nbit),
32
34
  period(period),
33
35
  threshold_type(Thresh_global) {
34
- FAISS_THROW_IF_NOT(code_size % 4 == 0);
35
36
  RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
36
37
  rr->init(1234);
37
38
  vt = rr;
@@ -151,8 +152,8 @@ void binarize_with_freq(
151
152
  memset(codes, 0, (nbit + 7) / 8);
152
153
  for (size_t i = 0; i < nbit; i++) {
153
154
  float xf = (x[i] - c[i]);
154
- int xi = int(floor(xf * freq));
155
- int bit = xi & 1;
155
+ int64_t xi = int64_t(floor(xf * freq));
156
+ int64_t bit = xi & 1;
156
157
  codes[i >> 3] |= bit << (i & 7);
157
158
  }
158
159
  }
@@ -167,35 +168,33 @@ void IndexIVFSpectralHash::encode_vectors(
167
168
  bool include_listnos) const {
168
169
  FAISS_THROW_IF_NOT(is_trained);
169
170
  float freq = 2.0 / period;
170
-
171
- FAISS_THROW_IF_NOT_MSG(!include_listnos, "listnos encoding not supported");
171
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
172
172
 
173
173
  // transform with vt
174
174
  std::unique_ptr<float[]> x(vt->apply(n, x_in));
175
175
 
176
- #pragma omp parallel
177
- {
178
- std::vector<float> zero(nbit);
176
+ std::vector<float> zero(nbit);
179
177
 
180
- // each thread takes care of a subset of lists
181
178
  #pragma omp for
182
- for (idx_t i = 0; i < n; i++) {
183
- int64_t list_no = list_nos[i];
184
-
185
- if (list_no >= 0) {
186
- const float* c;
187
- if (threshold_type == Thresh_global) {
188
- c = zero.data();
189
- } else {
190
- c = trained.data() + list_no * nbit;
191
- }
192
- binarize_with_freq(
193
- nbit,
194
- freq,
195
- x.get() + i * nbit,
196
- c,
197
- codes + i * code_size);
179
+ for (idx_t i = 0; i < n; i++) {
180
+ int64_t list_no = list_nos[i];
181
+ uint8_t* code = codes + i * (code_size + coarse_size);
182
+
183
+ if (list_no >= 0) {
184
+ if (coarse_size) {
185
+ encode_listno(list_no, code);
186
+ }
187
+ const float* c;
188
+
189
+ if (threshold_type == Thresh_global) {
190
+ c = zero.data();
191
+ } else {
192
+ c = trained.data() + list_no * nbit;
198
193
  }
194
+ binarize_with_freq(
195
+ nbit, freq, x.get() + i * nbit, c, code + coarse_size);
196
+ } else {
197
+ memset(code, 0, code_size + coarse_size);
199
198
  }
200
199
  }
201
200
  }
@@ -206,9 +205,7 @@ template <class HammingComputer>
206
205
  struct IVFScanner : InvertedListScanner {
207
206
  // copied from index structure
208
207
  const IndexIVFSpectralHash* index;
209
- size_t code_size;
210
208
  size_t nbit;
211
- bool store_pairs;
212
209
 
213
210
  float period, freq;
214
211
  std::vector<float> q;
@@ -220,15 +217,16 @@ struct IVFScanner : InvertedListScanner {
220
217
 
221
218
  IVFScanner(const IndexIVFSpectralHash* index, bool store_pairs)
222
219
  : index(index),
223
- code_size(index->code_size),
224
220
  nbit(index->nbit),
225
- store_pairs(store_pairs),
226
221
  period(index->period),
227
222
  freq(2.0 / index->period),
228
223
  q(nbit),
229
224
  zero(nbit),
230
- qcode(code_size),
231
- hc(qcode.data(), code_size) {}
225
+ qcode(index->code_size),
226
+ hc(qcode.data(), index->code_size) {
227
+ this->store_pairs = store_pairs;
228
+ this->code_size = index->code_size;
229
+ }
232
230
 
233
231
  void set_query(const float* query) override {
234
232
  FAISS_THROW_IF_NOT(query);
@@ -241,8 +239,6 @@ struct IVFScanner : InvertedListScanner {
241
239
  }
242
240
  }
243
241
 
244
- idx_t list_no;
245
-
246
242
  void set_list(idx_t list_no, float /*coarse_dis*/) override {
247
243
  this->list_no = list_no;
248
244
  if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
@@ -310,13 +306,38 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
310
306
  HANDLE_CODE_SIZE(64);
311
307
  #undef HANDLE_CODE_SIZE
312
308
  default:
313
- if (code_size % 4 == 0) {
314
- return new IVFScanner<HammingComputerDefault>(
315
- this, store_pairs);
316
- } else {
317
- FAISS_THROW_MSG("not supported");
318
- }
309
+ return new IVFScanner<HammingComputerDefault>(this, store_pairs);
310
+ }
311
+ }
312
+
313
+ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
314
+ FAISS_THROW_IF_NOT(vt_in->d_out == nbit);
315
+ FAISS_THROW_IF_NOT(vt_in->d_in == d);
316
+ if (own_fields) {
317
+ delete vt;
319
318
  }
319
+ vt = vt_in;
320
+ threshold_type = Thresh_global;
321
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist &&
322
+ vt->is_trained;
323
+ own_fields = own;
324
+ }
325
+
326
+ /*
327
+ Check that the encoder is a single vector transform followed by a LSH
328
+ that just does thresholding.
329
+ If this is not the case, the linear transform + thresholds of the IndexLSH
330
+ should be merged into the VectorTransform (which is feasible).
331
+ */
332
+
333
+ void IndexIVFSpectralHash::replace_vt(IndexPreTransform* encoder, bool own) {
334
+ FAISS_THROW_IF_NOT(encoder->chain.size() == 1);
335
+ auto sub_index = dynamic_cast<IndexLSH*>(encoder->index);
336
+ FAISS_THROW_IF_NOT_MSG(sub_index, "final index should be LSH");
337
+ FAISS_THROW_IF_NOT(sub_index->nbits == nbit);
338
+ FAISS_THROW_IF_NOT(!sub_index->rotate_data);
339
+ FAISS_THROW_IF_NOT(!sub_index->train_thresholds);
340
+ replace_vt(encoder->chain[0], own);
320
341
  }
321
342
 
322
343
  } // namespace faiss
@@ -17,6 +17,7 @@
17
17
  namespace faiss {
18
18
 
19
19
  struct VectorTransform;
20
+ struct IndexPreTransform;
20
21
 
21
22
  /** Inverted list that stores binary codes of size nbit. Before the
22
23
  * binary conversion, the dimension of the vectors is transformed from
@@ -25,23 +26,29 @@ struct VectorTransform;
25
26
  * Each coordinate is subtracted from a value determined by
26
27
  * threshold_type, and split into intervals of size period. Half of
27
28
  * the interval is a 0 bit, the other half a 1.
29
+ *
28
30
  */
29
31
  struct IndexIVFSpectralHash : IndexIVF {
30
- VectorTransform* vt; // transformation from d to nbit dim
32
+ /// transformation from d to nbit dim
33
+ VectorTransform* vt;
34
+ /// own the vt
31
35
  bool own_fields;
32
36
 
37
+ /// nb of bits of the binary signature
33
38
  int nbit;
39
+ /// interval size for 0s and 1s
34
40
  float period;
35
41
 
36
42
  enum ThresholdType {
37
- Thresh_global,
38
- Thresh_centroid,
39
- Thresh_centroid_half,
40
- Thresh_median
43
+ Thresh_global, ///< global threshold at 0
44
+ Thresh_centroid, ///< compare to centroid
45
+ Thresh_centroid_half, ///< central interval around centroid
46
+ Thresh_median ///< median of training set
41
47
  };
42
48
  ThresholdType threshold_type;
43
49
 
44
- // size nlist * nbit or 0 if Thresh_global
50
+ /// Trained threshold.
51
+ /// size nlist * nbit or 0 if Thresh_global
45
52
  std::vector<float> trained;
46
53
 
47
54
  IndexIVFSpectralHash(
@@ -65,6 +72,14 @@ struct IndexIVFSpectralHash : IndexIVF {
65
72
  InvertedListScanner* get_InvertedListScanner(
66
73
  bool store_pairs) const override;
67
74
 
75
+ /** replace the vector transform for an empty (and possibly untrained) index
76
+ */
77
+ void replace_vt(VectorTransform* vt, bool own = false);
78
+
79
+ /** convenience function to get the VT from an index constructed by an
80
+ * index_factory (should end in "LSH") */
81
+ void replace_vt(IndexPreTransform* index, bool own = false);
82
+
68
83
  ~IndexIVFSpectralHash() override;
69
84
  };
70
85
 
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexLSH.h>
11
9
 
12
10
  #include <cstdio>
@@ -25,15 +23,13 @@ namespace faiss {
25
23
  ***************************************************************/
26
24
 
27
25
  IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
28
- : Index(d),
26
+ : IndexFlatCodes((nbits + 7) / 8, d),
29
27
  nbits(nbits),
30
28
  rotate_data(rotate_data),
31
29
  train_thresholds(train_thresholds),
32
30
  rrot(d, nbits) {
33
31
  is_trained = !train_thresholds;
34
32
 
35
- bytes_per_vec = (nbits + 7) / 8;
36
-
37
33
  if (rotate_data) {
38
34
  rrot.init(5);
39
35
  } else {
@@ -41,11 +37,7 @@ IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
41
37
  }
42
38
  }
43
39
 
44
- IndexLSH::IndexLSH()
45
- : nbits(0),
46
- bytes_per_vec(0),
47
- rotate_data(false),
48
- train_thresholds(false) {}
40
+ IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {}
49
41
 
50
42
  const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
51
43
  float* xt = nullptr;
@@ -106,15 +98,6 @@ void IndexLSH::train(idx_t n, const float* x) {
106
98
  is_trained = true;
107
99
  }
108
100
 
109
- void IndexLSH::add(idx_t n, const float* x) {
110
- FAISS_THROW_IF_NOT(is_trained);
111
- codes.resize((ntotal + n) * bytes_per_vec);
112
-
113
- sa_encode(n, x, &codes[ntotal * bytes_per_vec]);
114
-
115
- ntotal += n;
116
- }
117
-
118
101
  void IndexLSH::search(
119
102
  idx_t n,
120
103
  const float* x,
@@ -127,7 +110,7 @@ void IndexLSH::search(
127
110
  const float* xt = apply_preprocess(n, x);
128
111
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
129
112
 
130
- uint8_t* qcodes = new uint8_t[n * bytes_per_vec];
113
+ uint8_t* qcodes = new uint8_t[n * code_size];
131
114
  ScopeDeleter<uint8_t> del2(qcodes);
132
115
 
133
116
  fvecs2bitvecs(xt, qcodes, nbits, n);
@@ -137,7 +120,7 @@ void IndexLSH::search(
137
120
 
138
121
  int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances};
139
122
 
140
- hammings_knn_hc(&res, qcodes, codes.data(), ntotal, bytes_per_vec, true);
123
+ hammings_knn_hc(&res, qcodes, codes.data(), ntotal, code_size, true);
141
124
 
142
125
  // convert distances to floats
143
126
  for (int i = 0; i < k * n; i++)
@@ -158,15 +141,6 @@ void IndexLSH::transfer_thresholds(LinearTransform* vt) {
158
141
  thresholds.clear();
159
142
  }
160
143
 
161
- void IndexLSH::reset() {
162
- codes.clear();
163
- ntotal = 0;
164
- }
165
-
166
- size_t IndexLSH::sa_code_size() const {
167
- return bytes_per_vec;
168
- }
169
-
170
144
  void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
171
145
  FAISS_THROW_IF_NOT(is_trained);
172
146
  const float* xt = apply_preprocess(n, x);
@@ -12,17 +12,14 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
  #include <faiss/VectorTransform.h>
17
17
 
18
18
  namespace faiss {
19
19
 
20
20
  /** The sign of each vector component is put in a binary signature */
21
- struct IndexLSH : Index {
22
- typedef unsigned char uint8_t;
23
-
21
+ struct IndexLSH : IndexFlatCodes {
24
22
  int nbits; ///< nb of bits per vector
25
- int bytes_per_vec; ///< nb of 8-bits per encoded vector
26
23
  bool rotate_data; ///< whether to apply a random rotation to input
27
24
  bool train_thresholds; ///< whether we train thresholds or use 0
28
25
 
@@ -30,9 +27,6 @@ struct IndexLSH : Index {
30
27
 
31
28
  std::vector<float> thresholds; ///< thresholds to compare with
32
29
 
33
- /// encoded dataset
34
- std::vector<uint8_t> codes;
35
-
36
30
  IndexLSH(
37
31
  idx_t d,
38
32
  int nbits,
@@ -50,8 +44,6 @@ struct IndexLSH : Index {
50
44
 
51
45
  void train(idx_t n, const float* x) override;
52
46
 
53
- void add(idx_t n, const float* x) override;
54
-
55
47
  void search(
56
48
  idx_t n,
57
49
  const float* x,
@@ -59,8 +51,6 @@ struct IndexLSH : Index {
59
51
  float* distances,
60
52
  idx_t* labels) const override;
61
53
 
62
- void reset() override;
63
-
64
54
  /// transfer the thresholds to a pre-processing stage (and unset
65
55
  /// train_thresholds)
66
56
  void transfer_thresholds(LinearTransform* vt);
@@ -72,9 +62,6 @@ struct IndexLSH : Index {
72
62
  /* standalone codec interface.
73
63
  *
74
64
  * The vectors are decoded to +/- 1 (not 0, 1) */
75
-
76
- size_t sa_code_size() const override;
77
-
78
65
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
79
66
 
80
67
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
@@ -167,9 +167,7 @@ void IndexNNDescent::search(
167
167
  float* simi = distances + i * k;
168
168
  dis->set_query(x + i * d);
169
169
 
170
- maxheap_heapify(k, simi, idxi);
171
170
  nndescent.search(*dis, k, idxi, simi, vt);
172
- maxheap_reorder(k, simi, idxi);
173
171
  }
174
172
  }
175
173
  InterruptCallback::check();
@@ -104,9 +104,7 @@ void IndexNSG::search(
104
104
  float* simi = distances + i * k;
105
105
  dis->set_query(x + i * d);
106
106
 
107
- maxheap_heapify(k, simi, idxi);
108
107
  nsg.search(*dis, k, idxi, simi, vt);
109
- maxheap_reorder(k, simi, idxi);
110
108
 
111
109
  vt.advance();
112
110
  }
@@ -28,12 +28,13 @@ namespace faiss {
28
28
  ********************************************************/
29
29
 
30
30
  IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
31
- : Index(d, metric), pq(d, M, nbits) {
31
+ : IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
32
32
  is_trained = false;
33
33
  do_polysemous_training = false;
34
34
  polysemous_ht = nbits * M + 1;
35
35
  search_type = ST_PQ;
36
36
  encode_signs = false;
37
+ code_size = pq.code_size;
37
38
  }
38
39
 
39
40
  IndexPQ::IndexPQ() {
@@ -69,53 +70,6 @@ void IndexPQ::train(idx_t n, const float* x) {
69
70
  is_trained = true;
70
71
  }
71
72
 
72
- void IndexPQ::add(idx_t n, const float* x) {
73
- FAISS_THROW_IF_NOT(is_trained);
74
- codes.resize((n + ntotal) * pq.code_size);
75
- pq.compute_codes(x, &codes[ntotal * pq.code_size], n);
76
- ntotal += n;
77
- }
78
-
79
- size_t IndexPQ::remove_ids(const IDSelector& sel) {
80
- idx_t j = 0;
81
- for (idx_t i = 0; i < ntotal; i++) {
82
- if (sel.is_member(i)) {
83
- // should be removed
84
- } else {
85
- if (i > j) {
86
- memmove(&codes[pq.code_size * j],
87
- &codes[pq.code_size * i],
88
- pq.code_size);
89
- }
90
- j++;
91
- }
92
- }
93
- size_t nremove = ntotal - j;
94
- if (nremove > 0) {
95
- ntotal = j;
96
- codes.resize(ntotal * pq.code_size);
97
- }
98
- return nremove;
99
- }
100
-
101
- void IndexPQ::reset() {
102
- codes.clear();
103
- ntotal = 0;
104
- }
105
-
106
- void IndexPQ::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
107
- FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
108
- for (idx_t i = 0; i < ni; i++) {
109
- const uint8_t* code = &codes[(i0 + i) * pq.code_size];
110
- pq.decode(code, recons + i * d);
111
- }
112
- }
113
-
114
- void IndexPQ::reconstruct(idx_t key, float* recons) const {
115
- FAISS_THROW_IF_NOT(key >= 0 && key < ntotal);
116
- pq.decode(&codes[key * pq.code_size], recons);
117
- }
118
-
119
73
  namespace {
120
74
 
121
75
  template <class PQDecoder>
@@ -457,9 +411,6 @@ void IndexPQ::search_core_polysemous(
457
411
  }
458
412
 
459
413
  /* The standalone codec interface (just remaps to the PQ functions) */
460
- size_t IndexPQ::sa_code_size() const {
461
- return pq.code_size;
462
- }
463
414
 
464
415
  void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
465
416
  pq.compute_codes(x, bytes, n);
@@ -12,7 +12,7 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
  #include <faiss/impl/PolysemousTraining.h>
17
17
  #include <faiss/impl/ProductQuantizer.h>
18
18
  #include <faiss/impl/platform_macros.h>
@@ -21,13 +21,10 @@ namespace faiss {
21
21
 
22
22
  /** Index based on a product quantizer. Stored vectors are
23
23
  * approximated by PQ codes. */
24
- struct IndexPQ : Index {
24
+ struct IndexPQ : IndexFlatCodes {
25
25
  /// The product quantizer used to encode the vectors
26
26
  ProductQuantizer pq;
27
27
 
28
- /// Codes. Size ntotal * pq.code_size
29
- std::vector<uint8_t> codes;
30
-
31
28
  /** Constructor.
32
29
  *
33
30
  * @param d dimensionality of the input vectors
@@ -43,8 +40,6 @@ struct IndexPQ : Index {
43
40
 
44
41
  void train(idx_t n, const float* x) override;
45
42
 
46
- void add(idx_t n, const float* x) override;
47
-
48
43
  void search(
49
44
  idx_t n,
50
45
  const float* x,
@@ -52,17 +47,7 @@ struct IndexPQ : Index {
52
47
  float* distances,
53
48
  idx_t* labels) const override;
54
49
 
55
- void reset() override;
56
-
57
- void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
58
-
59
- void reconstruct(idx_t key, float* recons) const override;
60
-
61
- size_t remove_ids(const IDSelector& sel) override;
62
-
63
50
  /* The standalone codec interface */
64
- size_t sa_code_size() const override;
65
-
66
51
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
67
52
 
68
53
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
@@ -155,6 +155,34 @@ void IndexRefine::reconstruct(idx_t key, float* recons) const {
155
155
  refine_index->reconstruct(key, recons);
156
156
  }
157
157
 
158
+ size_t IndexRefine::sa_code_size() const {
159
+ return base_index->sa_code_size() + refine_index->sa_code_size();
160
+ }
161
+
162
+ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
163
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
164
+ std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
165
+ base_index->sa_encode(n, x, tmp1.get());
166
+ std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
167
+ refine_index->sa_encode(n, x, tmp2.get());
168
+ for (size_t i = 0; i < n; i++) {
169
+ uint8_t* b = bytes + i * (cs1 + cs2);
170
+ memcpy(b, tmp1.get() + cs1 * i, cs1);
171
+ memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
172
+ }
173
+ }
174
+
175
+ void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
176
+ size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
177
+ std::unique_ptr<uint8_t[]> tmp2(
178
+ new uint8_t[n * refine_index->sa_code_size()]);
179
+ for (size_t i = 0; i < n; i++) {
180
+ memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
181
+ }
182
+
183
+ refine_index->sa_decode(n, tmp2.get(), x);
184
+ }
185
+
158
186
  IndexRefine::~IndexRefine() {
159
187
  if (own_fields)
160
188
  delete base_index;
@@ -49,6 +49,16 @@ struct IndexRefine : Index {
49
49
  // reconstruct is routed to the refine_index
50
50
  void reconstruct(idx_t key, float* recons) const override;
51
51
 
52
+ /* standalone codec interface: the base_index codes are interleaved with the
53
+ * refine_index ones */
54
+ size_t sa_code_size() const override;
55
+
56
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
57
+
58
+ /// The sa_decode decodes from the refine_index, which is assumed to be more
59
+ /// accurate
60
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
61
+
52
62
  ~IndexRefine() override;
53
63
  };
54
64
 
@@ -29,7 +29,7 @@ IndexScalarQuantizer::IndexScalarQuantizer(
29
29
  int d,
30
30
  ScalarQuantizer::QuantizerType qtype,
31
31
  MetricType metric)
32
- : Index(d, metric), sq(d, qtype) {
32
+ : IndexFlatCodes(0, d, metric), sq(d, qtype) {
33
33
  is_trained = qtype == ScalarQuantizer::QT_fp16 ||
34
34
  qtype == ScalarQuantizer::QT_8bit_direct;
35
35
  code_size = sq.code_size;
@@ -43,13 +43,6 @@ void IndexScalarQuantizer::train(idx_t n, const float* x) {
43
43
  is_trained = true;
44
44
  }
45
45
 
46
- void IndexScalarQuantizer::add(idx_t n, const float* x) {
47
- FAISS_THROW_IF_NOT(is_trained);
48
- codes.resize((n + ntotal) * code_size);
49
- sq.compute_codes(x, &codes[ntotal * code_size], n);
50
- ntotal += n;
51
- }
52
-
53
46
  void IndexScalarQuantizer::search(
54
47
  idx_t n,
55
48
  const float* x,
@@ -67,6 +60,7 @@ void IndexScalarQuantizer::search(
67
60
  InvertedListScanner* scanner =
68
61
  sq.select_InvertedListScanner(metric_type, nullptr, true);
69
62
  ScopeDeleter1<InvertedListScanner> del(scanner);
63
+ scanner->list_no = 0; // directly the list number
70
64
 
71
65
  #pragma omp for
72
66
  for (idx_t i = 0; i < n; i++) {
@@ -99,27 +93,7 @@ DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
99
93
  return dc;
100
94
  }
101
95
 
102
- void IndexScalarQuantizer::reset() {
103
- codes.clear();
104
- ntotal = 0;
105
- }
106
-
107
- void IndexScalarQuantizer::reconstruct_n(idx_t i0, idx_t ni, float* recons)
108
- const {
109
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
110
- for (size_t i = 0; i < ni; i++) {
111
- squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d);
112
- }
113
- }
114
-
115
- void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const {
116
- reconstruct_n(key, 1, recons);
117
- }
118
-
119
96
  /* Codec interface */
120
- size_t IndexScalarQuantizer::sa_code_size() const {
121
- return sq.code_size;
122
- }
123
97
 
124
98
  void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
125
99
  const {
@@ -13,6 +13,7 @@
13
13
  #include <stdint.h>
14
14
  #include <vector>
15
15
 
16
+ #include <faiss/IndexFlatCodes.h>
16
17
  #include <faiss/IndexIVF.h>
17
18
  #include <faiss/impl/ScalarQuantizer.h>
18
19
 
@@ -24,15 +25,10 @@ namespace faiss {
24
25
  * (default).
25
26
  */
26
27
 
27
- struct IndexScalarQuantizer : Index {
28
+ struct IndexScalarQuantizer : IndexFlatCodes {
28
29
  /// Used to encode the vectors
29
30
  ScalarQuantizer sq;
30
31
 
31
- /// Codes. Size ntotal * pq.code_size
32
- std::vector<uint8_t> codes;
33
-
34
- size_t code_size;
35
-
36
32
  /** Constructor.
37
33
  *
38
34
  * @param d dimensionality of the input vectors
@@ -48,8 +44,6 @@ struct IndexScalarQuantizer : Index {
48
44
 
49
45
  void train(idx_t n, const float* x) override;
50
46
 
51
- void add(idx_t n, const float* x) override;
52
-
53
47
  void search(
54
48
  idx_t n,
55
49
  const float* x,
@@ -57,17 +51,9 @@ struct IndexScalarQuantizer : Index {
57
51
  float* distances,
58
52
  idx_t* labels) const override;
59
53
 
60
- void reset() override;
61
-
62
- void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
63
-
64
- void reconstruct(idx_t key, float* recons) const override;
65
-
66
54
  DistanceComputer* get_distance_computer() const override;
67
55
 
68
56
  /* standalone codec interface */
69
- size_t sa_code_size() const override;
70
-
71
57
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
72
58
 
73
59
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;