faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -72,6 +72,8 @@ struct SearchParametersIVF : SearchParameters {
72
72
  size_t nprobe = 1; ///< number of probes at query time
73
73
  size_t max_codes = 0; ///< max nb of codes to visit to do a query
74
74
  SearchParameters* quantizer_params = nullptr;
75
+ /// context object to pass to InvertedLists
76
+ void* inverted_list_context = nullptr;
75
77
 
76
78
  virtual ~SearchParametersIVF() {}
77
79
  };
@@ -177,6 +179,7 @@ struct IndexIVF : Index, IndexIVFInterface {
177
179
  bool own_invlists = false;
178
180
 
179
181
  size_t code_size = 0; ///< code size per vector in bytes
182
+
180
183
  /** Parallel mode determines how queries are parallelized with OpenMP
181
184
  *
182
185
  * 0 (default): split over queries
@@ -194,6 +197,10 @@ struct IndexIVF : Index, IndexIVFInterface {
194
197
  * enables reconstruct() */
195
198
  DirectMap direct_map;
196
199
 
200
+ /// do the codes in the invlists encode the vectors relative to the
201
+ /// centroids?
202
+ bool by_residual = true;
203
+
197
204
  /** The Inverted file takes a quantizer (an Index) on input,
198
205
  * which implements the function mapping a vector to a list
199
206
  * identifier.
@@ -207,7 +214,7 @@ struct IndexIVF : Index, IndexIVFInterface {
207
214
 
208
215
  void reset() override;
209
216
 
210
- /// Trains the quantizer and calls train_residual to train sub-quantizers
217
+ /// Trains the quantizer and calls train_encoder to train sub-quantizers
211
218
  void train(idx_t n, const float* x) override;
212
219
 
213
220
  /// Calls add_with_ids with NULL ids
@@ -227,7 +234,8 @@ struct IndexIVF : Index, IndexIVFInterface {
227
234
  idx_t n,
228
235
  const float* x,
229
236
  const idx_t* xids,
230
- const idx_t* precomputed_idx);
237
+ const idx_t* precomputed_idx,
238
+ void* inverted_list_context = nullptr);
231
239
 
232
240
  /** Encodes a set of vectors as they would appear in the inverted lists
233
241
  *
@@ -252,9 +260,15 @@ struct IndexIVF : Index, IndexIVFInterface {
252
260
  */
253
261
  void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
254
262
 
255
- /// Sub-classes that encode the residuals can train their encoders here
256
- /// does nothing by default
257
- virtual void train_residual(idx_t n, const float* x);
263
+ /** Train the encoder for the vectors.
264
+ *
265
+ * If by_residual then it is called with residuals and corresponding assign
266
+ * array, otherwise x is the raw training vectors and assign=nullptr */
267
+ virtual void train_encoder(idx_t n, const float* x, const idx_t* assign);
268
+
269
+ /// can be redefined by subclasses to indicate how many training vectors
270
+ /// they need
271
+ virtual idx_t train_encoder_num_vectors() const;
258
272
 
259
273
  void search_preassigned(
260
274
  idx_t n,
@@ -346,6 +360,24 @@ struct IndexIVF : Index, IndexIVFInterface {
346
360
  float* recons,
347
361
  const SearchParameters* params = nullptr) const override;
348
362
 
363
+ /** Similar to search, but also returns the codes corresponding to the
364
+ * stored vectors for the search results.
365
+ *
366
+ * @param codes codes (n, k, code_size)
367
+ * @param include_listno
368
+ * include the list ids in the code (in this case add
369
+ * ceil(log8(nlist)) to the code size)
370
+ */
371
+ void search_and_return_codes(
372
+ idx_t n,
373
+ const float* x,
374
+ idx_t k,
375
+ float* distances,
376
+ idx_t* labels,
377
+ uint8_t* recons,
378
+ bool include_listno = false,
379
+ const SearchParameters* params = nullptr) const;
380
+
349
381
  /** Reconstruct a vector given the location in terms of (inv list index +
350
382
  * inv list offset) instead of the id.
351
383
  *
@@ -37,30 +37,20 @@ IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(
37
37
  IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq)
38
38
  : IndexIVF(), aq(aq) {}
39
39
 
40
- void IndexIVFAdditiveQuantizer::train_residual(idx_t n, const float* x) {
41
- const float* x_in = x;
40
+ void IndexIVFAdditiveQuantizer::train_encoder(
41
+ idx_t n,
42
+ const float* x,
43
+ const idx_t* assign) {
44
+ aq->train(n, x);
45
+ }
42
46
 
47
+ idx_t IndexIVFAdditiveQuantizer::train_encoder_num_vectors() const {
43
48
  size_t max_train_points = 1024 * ((size_t)1 << aq->nbits[0]);
44
49
  // we need more data to train LSQ
45
50
  if (dynamic_cast<LocalSearchQuantizer*>(aq)) {
46
51
  max_train_points = 1024 * aq->M * ((size_t)1 << aq->nbits[0]);
47
52
  }
48
-
49
- x = fvecs_maybe_subsample(
50
- d, (size_t*)&n, max_train_points, x, verbose, 1234);
51
- ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
52
-
53
- if (by_residual) {
54
- std::vector<idx_t> idx(n);
55
- quantizer->assign(n, x, idx.data());
56
-
57
- std::vector<float> residuals(n * d);
58
- quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
59
-
60
- aq->train(n, residuals.data());
61
- } else {
62
- aq->train(n, x);
63
- }
53
+ return max_train_points;
64
54
  }
65
55
 
66
56
  void IndexIVFAdditiveQuantizer::encode_vectors(
@@ -126,7 +116,7 @@ void IndexIVFAdditiveQuantizer::sa_decode(
126
116
  }
127
117
  }
128
118
 
129
- IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() {}
119
+ IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() = default;
130
120
 
131
121
  /*********************************************
132
122
  * AQInvertedListScanner
@@ -159,6 +149,7 @@ struct AQInvertedListScanner : InvertedListScanner {
159
149
  const float* q;
160
150
  /// following codes come from this inverted list
161
151
  void set_list(idx_t list_no, float coarse_dis) override {
152
+ this->list_no = list_no;
162
153
  if (ia.metric_type == METRIC_L2 && ia.by_residual) {
163
154
  ia.quantizer->compute_residual(q0, tmp.data(), list_no);
164
155
  q = tmp.data();
@@ -167,7 +158,7 @@ struct AQInvertedListScanner : InvertedListScanner {
167
158
  }
168
159
  }
169
160
 
170
- ~AQInvertedListScanner() {}
161
+ ~AQInvertedListScanner() = default;
171
162
  };
172
163
 
173
164
  template <bool is_IP>
@@ -198,7 +189,7 @@ struct AQInvertedListScannerDecompress : AQInvertedListScanner {
198
189
  : fvec_L2sqr(q, b.data(), aq.d);
199
190
  }
200
191
 
201
- ~AQInvertedListScannerDecompress() override {}
192
+ ~AQInvertedListScannerDecompress() override = default;
202
193
  };
203
194
 
204
195
  template <bool is_IP, Search_type_t search_type>
@@ -241,7 +232,7 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
241
232
  aq.compute_1_distance_LUT<is_IP, search_type>(code, LUT.data());
242
233
  }
243
234
 
244
- ~AQInvertedListScannerLUT() override {}
235
+ ~AQInvertedListScannerLUT() override = default;
245
236
  };
246
237
 
247
238
  } // anonymous namespace
@@ -320,7 +311,7 @@ IndexIVFResidualQuantizer::IndexIVFResidualQuantizer(
320
311
  metric,
321
312
  search_type) {}
322
313
 
323
- IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer() {}
314
+ IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer() = default;
324
315
 
325
316
  /**************************************************************************************
326
317
  * IndexIVFLocalSearchQuantizer
@@ -342,7 +333,7 @@ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer(
342
333
  IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer()
343
334
  : IndexIVFAdditiveQuantizer(&lsq) {}
344
335
 
345
- IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() {}
336
+ IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() = default;
346
337
 
347
338
  /**************************************************************************************
348
339
  * IndexIVFProductResidualQuantizer
@@ -365,7 +356,7 @@ IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer(
365
356
  IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer()
366
357
  : IndexIVFAdditiveQuantizer(&prq) {}
367
358
 
368
- IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() {}
359
+ IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() = default;
369
360
 
370
361
  /**************************************************************************************
371
362
  * IndexIVFProductLocalSearchQuantizer
@@ -388,6 +379,7 @@ IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer(
388
379
  IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer()
389
380
  : IndexIVFAdditiveQuantizer(&plsq) {}
390
381
 
391
- IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() {}
382
+ IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() =
383
+ default;
392
384
 
393
385
  } // namespace faiss
@@ -26,7 +26,6 @@ namespace faiss {
26
26
  struct IndexIVFAdditiveQuantizer : IndexIVF {
27
27
  // the quantizer
28
28
  AdditiveQuantizer* aq;
29
- bool by_residual = true;
30
29
  int use_precomputed_table = 0; // for future use
31
30
 
32
31
  using Search_type_t = AdditiveQuantizer::Search_type_t;
@@ -40,7 +39,9 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
40
39
 
41
40
  explicit IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq);
42
41
 
43
- void train_residual(idx_t n, const float* x) override;
42
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
43
+
44
+ idx_t train_encoder_num_vectors() const override;
44
45
 
45
46
  void encode_vectors(
46
47
  idx_t n,
@@ -125,51 +125,27 @@ IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan() {
125
125
  is_trained = false;
126
126
  }
127
127
 
128
- IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() {}
128
+ IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() =
129
+ default;
129
130
 
130
131
  /*********************************************************
131
132
  * Training
132
133
  *********************************************************/
133
134
 
134
- void IndexIVFAdditiveQuantizerFastScan::train_residual(
135
+ idx_t IndexIVFAdditiveQuantizerFastScan::train_encoder_num_vectors() const {
136
+ return max_train_points;
137
+ }
138
+
139
+ void IndexIVFAdditiveQuantizerFastScan::train_encoder(
135
140
  idx_t n,
136
- const float* x_in) {
141
+ const float* x,
142
+ const idx_t* assign) {
137
143
  if (aq->is_trained) {
138
144
  return;
139
145
  }
140
146
 
141
- const int seed = 0x12345;
142
- size_t nt = n;
143
- const float* x = fvecs_maybe_subsample(
144
- d, &nt, max_train_points, x_in, verbose, seed);
145
- n = nt;
146
147
  if (verbose) {
147
- printf("training additive quantizer on %zd vectors\n", nt);
148
- }
149
- aq->verbose = verbose;
150
-
151
- std::unique_ptr<float[]> del_x;
152
- if (x != x_in) {
153
- del_x.reset((float*)x);
154
- }
155
-
156
- const float* trainset;
157
- std::vector<float> residuals(n * d);
158
- std::vector<idx_t> assign(n);
159
-
160
- if (by_residual) {
161
- if (verbose) {
162
- printf("computing residuals\n");
163
- }
164
- quantizer->assign(n, x, assign.data());
165
- residuals.resize(n * d);
166
- for (idx_t i = 0; i < n; i++) {
167
- quantizer->compute_residual(
168
- x + i * d, residuals.data() + i * d, assign[i]);
169
- }
170
- trainset = residuals.data();
171
- } else {
172
- trainset = x;
148
+ printf("training additive quantizer on %d vectors\n", int(n));
173
149
  }
174
150
 
175
151
  if (verbose) {
@@ -181,17 +157,16 @@ void IndexIVFAdditiveQuantizerFastScan::train_residual(
181
157
  d);
182
158
  }
183
159
  aq->verbose = verbose;
184
- aq->train(n, trainset);
160
+ aq->train(n, x);
185
161
 
186
162
  // train norm quantizer
187
163
  if (by_residual && metric_type == METRIC_L2) {
188
164
  std::vector<float> decoded_x(n * d);
189
165
  std::vector<uint8_t> x_codes(n * aq->code_size);
190
- aq->compute_codes(residuals.data(), x_codes.data(), n);
166
+ aq->compute_codes(x, x_codes.data(), n);
191
167
  aq->decode(x_codes.data(), decoded_x.data(), n);
192
168
 
193
169
  // add coarse centroids
194
- FAISS_THROW_IF_NOT(assign.size() == n);
195
170
  std::vector<float> centroid(d);
196
171
  for (idx_t i = 0; i < n; i++) {
197
172
  auto xi = decoded_x.data() + i * d;
@@ -236,7 +211,8 @@ void IndexIVFAdditiveQuantizerFastScan::estimate_norm_scale(
236
211
 
237
212
  size_t index_nprobe = nprobe;
238
213
  nprobe = 1;
239
- compute_LUT(n, x, coarse_ids.data(), coarse_dis.data(), dis_tables, biases);
214
+ CoarseQuantized cq{index_nprobe, coarse_dis.data(), coarse_ids.data()};
215
+ compute_LUT(n, x, cq, dis_tables, biases);
240
216
  nprobe = index_nprobe;
241
217
 
242
218
  float scale = 0;
@@ -338,11 +314,8 @@ void IndexIVFAdditiveQuantizerFastScan::search(
338
314
  }
339
315
 
340
316
  NormTableScaler scaler(norm_scale);
341
- if (metric_type == METRIC_L2) {
342
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
343
- } else {
344
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
345
- }
317
+ IndexIVFFastScan::CoarseQuantized cq{nprobe};
318
+ search_dispatch_implem(n, x, k, distances, labels, cq, &scaler);
346
319
  }
347
320
 
348
321
  /*********************************************************
@@ -408,12 +381,12 @@ bool IndexIVFAdditiveQuantizerFastScan::lookup_table_is_3d() const {
408
381
  void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
409
382
  size_t n,
410
383
  const float* x,
411
- const idx_t* coarse_ids,
412
- const float*,
384
+ const CoarseQuantized& cq,
413
385
  AlignedTable<float>& dis_tables,
414
386
  AlignedTable<float>& biases) const {
415
387
  const size_t dim12 = ksub * M;
416
388
  const size_t ip_dim12 = aq->M * ksub;
389
+ const size_t nprobe = cq.nprobe;
417
390
 
418
391
  dis_tables.resize(n * dim12);
419
392
 
@@ -434,7 +407,7 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
434
407
  #pragma omp for
435
408
  for (idx_t ij = 0; ij < n * nprobe; ij++) {
436
409
  int i = ij / nprobe;
437
- quantizer->reconstruct(coarse_ids[ij], c);
410
+ quantizer->reconstruct(cq.ids[ij], c);
438
411
  biases[ij] = coef * fvec_inner_product(c, x + i * d, d);
439
412
  }
440
413
  }
@@ -63,7 +63,9 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
63
63
  const IndexIVFAdditiveQuantizer& orig,
64
64
  int bbs = 32);
65
65
 
66
- void train_residual(idx_t n, const float* x) override;
66
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
67
+
68
+ idx_t train_encoder_num_vectors() const override;
67
69
 
68
70
  void estimate_norm_scale(idx_t n, const float* x);
69
71
 
@@ -91,8 +93,7 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
91
93
  void compute_LUT(
92
94
  size_t n,
93
95
  const float* x,
94
- const idx_t* coarse_ids,
95
- const float* coarse_dis,
96
+ const CoarseQuantized& cq,
96
97
  AlignedTable<float>& dis_tables,
97
98
  AlignedTable<float>& biases) const override;
98
99