faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -72,6 +72,8 @@ struct SearchParametersIVF : SearchParameters {
72
72
  size_t nprobe = 1; ///< number of probes at query time
73
73
  size_t max_codes = 0; ///< max nb of codes to visit to do a query
74
74
  SearchParameters* quantizer_params = nullptr;
75
+ /// context object to pass to InvertedLists
76
+ void* inverted_list_context = nullptr;
75
77
 
76
78
  virtual ~SearchParametersIVF() {}
77
79
  };
@@ -177,6 +179,7 @@ struct IndexIVF : Index, IndexIVFInterface {
177
179
  bool own_invlists = false;
178
180
 
179
181
  size_t code_size = 0; ///< code size per vector in bytes
182
+
180
183
  /** Parallel mode determines how queries are parallelized with OpenMP
181
184
  *
182
185
  * 0 (default): split over queries
@@ -194,6 +197,10 @@ struct IndexIVF : Index, IndexIVFInterface {
194
197
  * enables reconstruct() */
195
198
  DirectMap direct_map;
196
199
 
200
+ /// do the codes in the invlists encode the vectors relative to the
201
+ /// centroids?
202
+ bool by_residual = true;
203
+
197
204
  /** The Inverted file takes a quantizer (an Index) on input,
198
205
  * which implements the function mapping a vector to a list
199
206
  * identifier.
@@ -207,7 +214,7 @@ struct IndexIVF : Index, IndexIVFInterface {
207
214
 
208
215
  void reset() override;
209
216
 
210
- /// Trains the quantizer and calls train_residual to train sub-quantizers
217
+ /// Trains the quantizer and calls train_encoder to train sub-quantizers
211
218
  void train(idx_t n, const float* x) override;
212
219
 
213
220
  /// Calls add_with_ids with NULL ids
@@ -227,7 +234,8 @@ struct IndexIVF : Index, IndexIVFInterface {
227
234
  idx_t n,
228
235
  const float* x,
229
236
  const idx_t* xids,
230
- const idx_t* precomputed_idx);
237
+ const idx_t* precomputed_idx,
238
+ void* inverted_list_context = nullptr);
231
239
 
232
240
  /** Encodes a set of vectors as they would appear in the inverted lists
233
241
  *
@@ -252,9 +260,15 @@ struct IndexIVF : Index, IndexIVFInterface {
252
260
  */
253
261
  void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
254
262
 
255
- /// Sub-classes that encode the residuals can train their encoders here
256
- /// does nothing by default
257
- virtual void train_residual(idx_t n, const float* x);
263
+ /** Train the encoder for the vectors.
264
+ *
265
+ * If by_residual then it is called with residuals and corresponding assign
266
+ * array, otherwise x is the raw training vectors and assign=nullptr */
267
+ virtual void train_encoder(idx_t n, const float* x, const idx_t* assign);
268
+
269
+ /// can be redefined by subclasses to indicate how many training vectors
270
+ /// they need
271
+ virtual idx_t train_encoder_num_vectors() const;
258
272
 
259
273
  void search_preassigned(
260
274
  idx_t n,
@@ -346,6 +360,24 @@ struct IndexIVF : Index, IndexIVFInterface {
346
360
  float* recons,
347
361
  const SearchParameters* params = nullptr) const override;
348
362
 
363
+ /** Similar to search, but also returns the codes corresponding to the
364
+ * stored vectors for the search results.
365
+ *
366
+ * @param codes codes (n, k, code_size)
367
+ * @param include_listno
368
+ * include the list ids in the code (in this case add
369
+ * ceil(log8(nlist)) to the code size)
370
+ */
371
+ void search_and_return_codes(
372
+ idx_t n,
373
+ const float* x,
374
+ idx_t k,
375
+ float* distances,
376
+ idx_t* labels,
377
+ uint8_t* recons,
378
+ bool include_listno = false,
379
+ const SearchParameters* params = nullptr) const;
380
+
349
381
  /** Reconstruct a vector given the location in terms of (inv list index +
350
382
  * inv list offset) instead of the id.
351
383
  *
@@ -37,30 +37,20 @@ IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(
37
37
  IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq)
38
38
  : IndexIVF(), aq(aq) {}
39
39
 
40
- void IndexIVFAdditiveQuantizer::train_residual(idx_t n, const float* x) {
41
- const float* x_in = x;
40
+ void IndexIVFAdditiveQuantizer::train_encoder(
41
+ idx_t n,
42
+ const float* x,
43
+ const idx_t* assign) {
44
+ aq->train(n, x);
45
+ }
42
46
 
47
+ idx_t IndexIVFAdditiveQuantizer::train_encoder_num_vectors() const {
43
48
  size_t max_train_points = 1024 * ((size_t)1 << aq->nbits[0]);
44
49
  // we need more data to train LSQ
45
50
  if (dynamic_cast<LocalSearchQuantizer*>(aq)) {
46
51
  max_train_points = 1024 * aq->M * ((size_t)1 << aq->nbits[0]);
47
52
  }
48
-
49
- x = fvecs_maybe_subsample(
50
- d, (size_t*)&n, max_train_points, x, verbose, 1234);
51
- ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
52
-
53
- if (by_residual) {
54
- std::vector<idx_t> idx(n);
55
- quantizer->assign(n, x, idx.data());
56
-
57
- std::vector<float> residuals(n * d);
58
- quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
59
-
60
- aq->train(n, residuals.data());
61
- } else {
62
- aq->train(n, x);
63
- }
53
+ return max_train_points;
64
54
  }
65
55
 
66
56
  void IndexIVFAdditiveQuantizer::encode_vectors(
@@ -126,7 +116,7 @@ void IndexIVFAdditiveQuantizer::sa_decode(
126
116
  }
127
117
  }
128
118
 
129
- IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() {}
119
+ IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() = default;
130
120
 
131
121
  /*********************************************
132
122
  * AQInvertedListScanner
@@ -159,6 +149,7 @@ struct AQInvertedListScanner : InvertedListScanner {
159
149
  const float* q;
160
150
  /// following codes come from this inverted list
161
151
  void set_list(idx_t list_no, float coarse_dis) override {
152
+ this->list_no = list_no;
162
153
  if (ia.metric_type == METRIC_L2 && ia.by_residual) {
163
154
  ia.quantizer->compute_residual(q0, tmp.data(), list_no);
164
155
  q = tmp.data();
@@ -167,7 +158,7 @@ struct AQInvertedListScanner : InvertedListScanner {
167
158
  }
168
159
  }
169
160
 
170
- ~AQInvertedListScanner() {}
161
+ ~AQInvertedListScanner() = default;
171
162
  };
172
163
 
173
164
  template <bool is_IP>
@@ -198,7 +189,7 @@ struct AQInvertedListScannerDecompress : AQInvertedListScanner {
198
189
  : fvec_L2sqr(q, b.data(), aq.d);
199
190
  }
200
191
 
201
- ~AQInvertedListScannerDecompress() override {}
192
+ ~AQInvertedListScannerDecompress() override = default;
202
193
  };
203
194
 
204
195
  template <bool is_IP, Search_type_t search_type>
@@ -241,7 +232,7 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
241
232
  aq.compute_1_distance_LUT<is_IP, search_type>(code, LUT.data());
242
233
  }
243
234
 
244
- ~AQInvertedListScannerLUT() override {}
235
+ ~AQInvertedListScannerLUT() override = default;
245
236
  };
246
237
 
247
238
  } // anonymous namespace
@@ -320,7 +311,7 @@ IndexIVFResidualQuantizer::IndexIVFResidualQuantizer(
320
311
  metric,
321
312
  search_type) {}
322
313
 
323
- IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer() {}
314
+ IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer() = default;
324
315
 
325
316
  /**************************************************************************************
326
317
  * IndexIVFLocalSearchQuantizer
@@ -342,7 +333,7 @@ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer(
342
333
  IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer()
343
334
  : IndexIVFAdditiveQuantizer(&lsq) {}
344
335
 
345
- IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() {}
336
+ IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() = default;
346
337
 
347
338
  /**************************************************************************************
348
339
  * IndexIVFProductResidualQuantizer
@@ -365,7 +356,7 @@ IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer(
365
356
  IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer()
366
357
  : IndexIVFAdditiveQuantizer(&prq) {}
367
358
 
368
- IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() {}
359
+ IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() = default;
369
360
 
370
361
  /**************************************************************************************
371
362
  * IndexIVFProductLocalSearchQuantizer
@@ -388,6 +379,7 @@ IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer(
388
379
  IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer()
389
380
  : IndexIVFAdditiveQuantizer(&plsq) {}
390
381
 
391
- IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() {}
382
+ IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() =
383
+ default;
392
384
 
393
385
  } // namespace faiss
@@ -26,7 +26,6 @@ namespace faiss {
26
26
  struct IndexIVFAdditiveQuantizer : IndexIVF {
27
27
  // the quantizer
28
28
  AdditiveQuantizer* aq;
29
- bool by_residual = true;
30
29
  int use_precomputed_table = 0; // for future use
31
30
 
32
31
  using Search_type_t = AdditiveQuantizer::Search_type_t;
@@ -40,7 +39,9 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
40
39
 
41
40
  explicit IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq);
42
41
 
43
- void train_residual(idx_t n, const float* x) override;
42
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
43
+
44
+ idx_t train_encoder_num_vectors() const override;
44
45
 
45
46
  void encode_vectors(
46
47
  idx_t n,
@@ -125,51 +125,27 @@ IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan() {
125
125
  is_trained = false;
126
126
  }
127
127
 
128
- IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() {}
128
+ IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() =
129
+ default;
129
130
 
130
131
  /*********************************************************
131
132
  * Training
132
133
  *********************************************************/
133
134
 
134
- void IndexIVFAdditiveQuantizerFastScan::train_residual(
135
+ idx_t IndexIVFAdditiveQuantizerFastScan::train_encoder_num_vectors() const {
136
+ return max_train_points;
137
+ }
138
+
139
+ void IndexIVFAdditiveQuantizerFastScan::train_encoder(
135
140
  idx_t n,
136
- const float* x_in) {
141
+ const float* x,
142
+ const idx_t* assign) {
137
143
  if (aq->is_trained) {
138
144
  return;
139
145
  }
140
146
 
141
- const int seed = 0x12345;
142
- size_t nt = n;
143
- const float* x = fvecs_maybe_subsample(
144
- d, &nt, max_train_points, x_in, verbose, seed);
145
- n = nt;
146
147
  if (verbose) {
147
- printf("training additive quantizer on %zd vectors\n", nt);
148
- }
149
- aq->verbose = verbose;
150
-
151
- std::unique_ptr<float[]> del_x;
152
- if (x != x_in) {
153
- del_x.reset((float*)x);
154
- }
155
-
156
- const float* trainset;
157
- std::vector<float> residuals(n * d);
158
- std::vector<idx_t> assign(n);
159
-
160
- if (by_residual) {
161
- if (verbose) {
162
- printf("computing residuals\n");
163
- }
164
- quantizer->assign(n, x, assign.data());
165
- residuals.resize(n * d);
166
- for (idx_t i = 0; i < n; i++) {
167
- quantizer->compute_residual(
168
- x + i * d, residuals.data() + i * d, assign[i]);
169
- }
170
- trainset = residuals.data();
171
- } else {
172
- trainset = x;
148
+ printf("training additive quantizer on %d vectors\n", int(n));
173
149
  }
174
150
 
175
151
  if (verbose) {
@@ -181,17 +157,16 @@ void IndexIVFAdditiveQuantizerFastScan::train_residual(
181
157
  d);
182
158
  }
183
159
  aq->verbose = verbose;
184
- aq->train(n, trainset);
160
+ aq->train(n, x);
185
161
 
186
162
  // train norm quantizer
187
163
  if (by_residual && metric_type == METRIC_L2) {
188
164
  std::vector<float> decoded_x(n * d);
189
165
  std::vector<uint8_t> x_codes(n * aq->code_size);
190
- aq->compute_codes(residuals.data(), x_codes.data(), n);
166
+ aq->compute_codes(x, x_codes.data(), n);
191
167
  aq->decode(x_codes.data(), decoded_x.data(), n);
192
168
 
193
169
  // add coarse centroids
194
- FAISS_THROW_IF_NOT(assign.size() == n);
195
170
  std::vector<float> centroid(d);
196
171
  for (idx_t i = 0; i < n; i++) {
197
172
  auto xi = decoded_x.data() + i * d;
@@ -236,7 +211,8 @@ void IndexIVFAdditiveQuantizerFastScan::estimate_norm_scale(
236
211
 
237
212
  size_t index_nprobe = nprobe;
238
213
  nprobe = 1;
239
- compute_LUT(n, x, coarse_ids.data(), coarse_dis.data(), dis_tables, biases);
214
+ CoarseQuantized cq{index_nprobe, coarse_dis.data(), coarse_ids.data()};
215
+ compute_LUT(n, x, cq, dis_tables, biases);
240
216
  nprobe = index_nprobe;
241
217
 
242
218
  float scale = 0;
@@ -338,11 +314,8 @@ void IndexIVFAdditiveQuantizerFastScan::search(
338
314
  }
339
315
 
340
316
  NormTableScaler scaler(norm_scale);
341
- if (metric_type == METRIC_L2) {
342
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
343
- } else {
344
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
345
- }
317
+ IndexIVFFastScan::CoarseQuantized cq{nprobe};
318
+ search_dispatch_implem(n, x, k, distances, labels, cq, &scaler);
346
319
  }
347
320
 
348
321
  /*********************************************************
@@ -408,12 +381,12 @@ bool IndexIVFAdditiveQuantizerFastScan::lookup_table_is_3d() const {
408
381
  void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
409
382
  size_t n,
410
383
  const float* x,
411
- const idx_t* coarse_ids,
412
- const float*,
384
+ const CoarseQuantized& cq,
413
385
  AlignedTable<float>& dis_tables,
414
386
  AlignedTable<float>& biases) const {
415
387
  const size_t dim12 = ksub * M;
416
388
  const size_t ip_dim12 = aq->M * ksub;
389
+ const size_t nprobe = cq.nprobe;
417
390
 
418
391
  dis_tables.resize(n * dim12);
419
392
 
@@ -434,7 +407,7 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
434
407
  #pragma omp for
435
408
  for (idx_t ij = 0; ij < n * nprobe; ij++) {
436
409
  int i = ij / nprobe;
437
- quantizer->reconstruct(coarse_ids[ij], c);
410
+ quantizer->reconstruct(cq.ids[ij], c);
438
411
  biases[ij] = coef * fvec_inner_product(c, x + i * d, d);
439
412
  }
440
413
  }
@@ -63,7 +63,9 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
63
63
  const IndexIVFAdditiveQuantizer& orig,
64
64
  int bbs = 32);
65
65
 
66
- void train_residual(idx_t n, const float* x) override;
66
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
67
+
68
+ idx_t train_encoder_num_vectors() const override;
67
69
 
68
70
  void estimate_norm_scale(idx_t n, const float* x);
69
71
 
@@ -91,8 +93,7 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
91
93
  void compute_LUT(
92
94
  size_t n,
93
95
  const float* x,
94
- const idx_t* coarse_ids,
95
- const float* coarse_dis,
96
+ const CoarseQuantized& cq,
96
97
  AlignedTable<float>& dis_tables,
97
98
  AlignedTable<float>& biases) const override;
98
99