faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,92 @@
1
+ #pragma once
2
+
3
+ #include <cstdint>
4
+ #include <vector>
5
+
6
+ #include <faiss/Index.h>
7
+ #include <faiss/impl/platform_macros.h>
8
+
9
+ namespace faiss {
10
+
11
+ /// Index wrapper that performs rowwise normalization to [0,1], preserving
12
+ /// the coefficients. This is a vector codec index only.
13
+ ///
14
+ /// Basically, this index performs a rowwise scaling to [0,1] of every row
15
+ /// in an input dataset before calling subindex::train() and
16
+ /// subindex::sa_encode(). sa_encode() call stores the scaling coefficients
17
+ /// (scaler and minv) in the very beginning of every output code. The format:
18
+ /// [scaler][minv][subindex::sa_encode() output]
19
+ /// The de-scaling in sa_decode() is done using:
20
+ /// output_rescaled = scaler * output + minv
21
+ ///
22
+ /// An additional ::train_inplace() function is provided in order to do
23
+ /// an inplace scaling before calling subindex::train() and, thus, avoiding
24
+ /// the cloning of the input dataset, but modifying the input dataset because
25
+ /// of the scaling and the scaling back. It is up to the user to call
26
+ /// this function instead of ::train()
27
+ ///
28
+ /// Derived classes provide different data types for scaling coefficients.
29
+ /// Currently, versions with fp16 and fp32 scaling coefficients are available.
30
+ /// * fp16 version adds 4 extra bytes per encoded vector
31
+ /// * fp32 version adds 8 extra bytes per encoded vector
32
+
33
+ /// Provides base functions for rowwise normalizing indices.
34
+ struct IndexRowwiseMinMaxBase : Index {
35
+ /// sub-index
36
+ Index* index;
37
+
38
+ /// whether the subindex needs to be freed in the destructor.
39
+ bool own_fields;
40
+
41
+ explicit IndexRowwiseMinMaxBase(Index* index);
42
+
43
+ IndexRowwiseMinMaxBase();
44
+ ~IndexRowwiseMinMaxBase() override;
45
+
46
+ void add(idx_t n, const float* x) override;
47
+ void search(
48
+ idx_t n,
49
+ const float* x,
50
+ idx_t k,
51
+ float* distances,
52
+ idx_t* labels,
53
+ const SearchParameters* params = nullptr) const override;
54
+
55
+ void reset() override;
56
+
57
+ virtual void train_inplace(idx_t n, float* x) = 0;
58
+ };
59
+
60
+ /// Stores scaling coefficients as fp16 values.
61
+ struct IndexRowwiseMinMaxFP16 : IndexRowwiseMinMaxBase {
62
+ explicit IndexRowwiseMinMaxFP16(Index* index);
63
+
64
+ IndexRowwiseMinMaxFP16();
65
+
66
+ void train(idx_t n, const float* x) override;
67
+ void train_inplace(idx_t n, float* x) override;
68
+
69
+ size_t sa_code_size() const override;
70
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
71
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
72
+ };
73
+
74
+ /// Stores scaling coefficients as fp32 values.
75
+ struct IndexRowwiseMinMax : IndexRowwiseMinMaxBase {
76
+ explicit IndexRowwiseMinMax(Index* index);
77
+
78
+ IndexRowwiseMinMax();
79
+
80
+ void train(idx_t n, const float* x) override;
81
+ void train_inplace(idx_t n, float* x) override;
82
+
83
+ size_t sa_code_size() const override;
84
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
86
+ };
87
+
88
+ /// block size for performing sa_encode and sa_decode
89
+ FAISS_API extern int rowwise_minmax_sa_encode_bs;
90
+ FAISS_API extern int rowwise_minmax_sa_decode_bs;
91
+
92
+ } // namespace faiss
@@ -16,6 +16,7 @@
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/IDSelector.h>
19
20
  #include <faiss/impl/ScalarQuantizer.h>
20
21
  #include <faiss/utils/utils.h>
21
22
 
@@ -48,9 +49,11 @@ void IndexScalarQuantizer::search(
48
49
  const float* x,
49
50
  idx_t k,
50
51
  float* distances,
51
- idx_t* labels) const {
52
- FAISS_THROW_IF_NOT(k > 0);
52
+ idx_t* labels,
53
+ const SearchParameters* params) const {
54
+ const IDSelector* sel = params ? params->sel : nullptr;
53
55
 
56
+ FAISS_THROW_IF_NOT(k > 0);
54
57
  FAISS_THROW_IF_NOT(is_trained);
55
58
  FAISS_THROW_IF_NOT(
56
59
  metric_type == METRIC_L2 || metric_type == METRIC_INNER_PRODUCT);
@@ -58,7 +61,8 @@ void IndexScalarQuantizer::search(
58
61
  #pragma omp parallel
59
62
  {
60
63
  InvertedListScanner* scanner =
61
- sq.select_InvertedListScanner(metric_type, nullptr, true);
64
+ sq.select_InvertedListScanner(metric_type, nullptr, true, sel);
65
+
62
66
  ScopeDeleter1<InvertedListScanner> del(scanner);
63
67
  scanner->list_no = 0; // directly the list number
64
68
 
@@ -85,7 +89,8 @@ void IndexScalarQuantizer::search(
85
89
  }
86
90
  }
87
91
 
88
- DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
92
+ FlatCodesDistanceComputer* IndexScalarQuantizer::get_FlatCodesDistanceComputer()
93
+ const {
89
94
  ScalarQuantizer::SQDistanceComputer* dc =
90
95
  sq.get_distance_computer(metric_type);
91
96
  dc->code_size = sq.code_size;
@@ -140,7 +145,7 @@ void IndexIVFScalarQuantizer::encode_vectors(
140
145
  const idx_t* list_nos,
141
146
  uint8_t* codes,
142
147
  bool include_listnos) const {
143
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
148
+ std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
144
149
  size_t coarse_size = include_listnos ? coarse_code_size() : 0;
145
150
  memset(codes, 0, (code_size + coarse_size) * n);
146
151
 
@@ -169,7 +174,7 @@ void IndexIVFScalarQuantizer::encode_vectors(
169
174
 
170
175
  void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
171
176
  const {
172
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
177
+ std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
173
178
  size_t coarse_size = coarse_code_size();
174
179
 
175
180
  #pragma omp parallel if (n > 1000)
@@ -200,7 +205,7 @@ void IndexIVFScalarQuantizer::add_core(
200
205
  FAISS_THROW_IF_NOT(is_trained);
201
206
 
202
207
  size_t nadd = 0;
203
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
208
+ std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
204
209
 
205
210
  DirectMapAdd dm_add(direct_map, n, xids);
206
211
 
@@ -241,22 +246,28 @@ void IndexIVFScalarQuantizer::add_core(
241
246
  }
242
247
 
243
248
  InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
244
- bool store_pairs) const {
249
+ bool store_pairs,
250
+ const IDSelector* sel) const {
245
251
  return sq.select_InvertedListScanner(
246
- metric_type, quantizer, store_pairs, by_residual);
252
+ metric_type, quantizer, store_pairs, sel, by_residual);
247
253
  }
248
254
 
249
255
  void IndexIVFScalarQuantizer::reconstruct_from_offset(
250
256
  int64_t list_no,
251
257
  int64_t offset,
252
258
  float* recons) const {
253
- std::vector<float> centroid(d);
254
- quantizer->reconstruct(list_no, centroid.data());
255
-
256
259
  const uint8_t* code = invlists->get_single_code(list_no, offset);
257
- sq.decode(code, recons, 1);
258
- for (int i = 0; i < d; ++i) {
259
- recons[i] += centroid[i];
260
+
261
+ if (by_residual) {
262
+ std::vector<float> centroid(d);
263
+ quantizer->reconstruct(list_no, centroid.data());
264
+
265
+ sq.decode(code, recons, 1);
266
+ for (int i = 0; i < d; ++i) {
267
+ recons[i] += centroid[i];
268
+ }
269
+ } else {
270
+ sq.decode(code, recons, 1);
260
271
  }
261
272
  }
262
273
 
@@ -20,11 +20,8 @@
20
20
  namespace faiss {
21
21
 
22
22
  /**
23
- * The uniform quantizer has a range [vmin, vmax]. The range can be
24
- * the same for all dimensions (uniform) or specific per dimension
25
- * (default).
23
+ * Flat index built on a scalar quantizer.
26
24
  */
27
-
28
25
  struct IndexScalarQuantizer : IndexFlatCodes {
29
26
  /// Used to encode the vectors
30
27
  ScalarQuantizer sq;
@@ -49,9 +46,10 @@ struct IndexScalarQuantizer : IndexFlatCodes {
49
46
  const float* x,
50
47
  idx_t k,
51
48
  float* distances,
52
- idx_t* labels) const override;
49
+ idx_t* labels,
50
+ const SearchParameters* params = nullptr) const override;
53
51
 
54
- DistanceComputer* get_distance_computer() const override;
52
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
55
53
 
56
54
  /* standalone codec interface */
57
55
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
@@ -95,7 +93,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
95
93
  const idx_t* precomputed_idx) override;
96
94
 
97
95
  InvertedListScanner* get_InvertedListScanner(
98
- bool store_pairs) const override;
96
+ bool store_pairs,
97
+ const IDSelector* sel) const override;
99
98
 
100
99
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
101
100
  const override;
@@ -288,7 +288,10 @@ void IndexShardsTemplate<IndexT>::search(
288
288
  const component_t* x,
289
289
  idx_t k,
290
290
  distance_t* distances,
291
- idx_t* labels) const {
291
+ idx_t* labels,
292
+ const SearchParameters* params) const {
293
+ FAISS_THROW_IF_NOT_MSG(
294
+ !params, "search params not supported for this index");
292
295
  FAISS_THROW_IF_NOT(k > 0);
293
296
 
294
297
  long nshard = this->count();
@@ -87,7 +87,8 @@ struct IndexShardsTemplate : public ThreadedIndex<IndexT> {
87
87
  const component_t* x,
88
88
  idx_t k,
89
89
  distance_t* distances,
90
- idx_t* labels) const override;
90
+ idx_t* labels,
91
+ const SearchParameters* params = nullptr) const override;
91
92
 
92
93
  void train(idx_t n, const component_t* x) override;
93
94
 
@@ -16,188 +16,12 @@
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/IDSelector.h>
19
20
  #include <faiss/utils/Heap.h>
20
21
  #include <faiss/utils/WorkerThread.h>
21
22
 
22
23
  namespace faiss {
23
24
 
24
- namespace {} // namespace
25
-
26
- /*****************************************************
27
- * IndexIDMap implementation
28
- *******************************************************/
29
-
30
- template <typename IndexT>
31
- IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
32
- : index(index), own_fields(false) {
33
- FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
34
- this->is_trained = index->is_trained;
35
- this->metric_type = index->metric_type;
36
- this->verbose = index->verbose;
37
- this->d = index->d;
38
- }
39
-
40
- template <typename IndexT>
41
- void IndexIDMapTemplate<IndexT>::add(
42
- idx_t,
43
- const typename IndexT::component_t*) {
44
- FAISS_THROW_MSG(
45
- "add does not make sense with IndexIDMap, "
46
- "use add_with_ids");
47
- }
48
-
49
- template <typename IndexT>
50
- void IndexIDMapTemplate<IndexT>::train(
51
- idx_t n,
52
- const typename IndexT::component_t* x) {
53
- index->train(n, x);
54
- this->is_trained = index->is_trained;
55
- }
56
-
57
- template <typename IndexT>
58
- void IndexIDMapTemplate<IndexT>::reset() {
59
- index->reset();
60
- id_map.clear();
61
- this->ntotal = 0;
62
- }
63
-
64
- template <typename IndexT>
65
- void IndexIDMapTemplate<IndexT>::add_with_ids(
66
- idx_t n,
67
- const typename IndexT::component_t* x,
68
- const typename IndexT::idx_t* xids) {
69
- index->add(n, x);
70
- for (idx_t i = 0; i < n; i++)
71
- id_map.push_back(xids[i]);
72
- this->ntotal = index->ntotal;
73
- }
74
-
75
- template <typename IndexT>
76
- void IndexIDMapTemplate<IndexT>::search(
77
- idx_t n,
78
- const typename IndexT::component_t* x,
79
- idx_t k,
80
- typename IndexT::distance_t* distances,
81
- typename IndexT::idx_t* labels) const {
82
- index->search(n, x, k, distances, labels);
83
- idx_t* li = labels;
84
- #pragma omp parallel for
85
- for (idx_t i = 0; i < n * k; i++) {
86
- li[i] = li[i] < 0 ? li[i] : id_map[li[i]];
87
- }
88
- }
89
-
90
- template <typename IndexT>
91
- void IndexIDMapTemplate<IndexT>::range_search(
92
- typename IndexT::idx_t n,
93
- const typename IndexT::component_t* x,
94
- typename IndexT::distance_t radius,
95
- RangeSearchResult* result) const {
96
- index->range_search(n, x, radius, result);
97
- #pragma omp parallel for
98
- for (idx_t i = 0; i < result->lims[result->nq]; i++) {
99
- result->labels[i] = result->labels[i] < 0 ? result->labels[i]
100
- : id_map[result->labels[i]];
101
- }
102
- }
103
-
104
- namespace {
105
-
106
- struct IDTranslatedSelector : IDSelector {
107
- const std::vector<int64_t>& id_map;
108
- const IDSelector& sel;
109
- IDTranslatedSelector(
110
- const std::vector<int64_t>& id_map,
111
- const IDSelector& sel)
112
- : id_map(id_map), sel(sel) {}
113
- bool is_member(idx_t id) const override {
114
- return sel.is_member(id_map[id]);
115
- }
116
- };
117
-
118
- } // namespace
119
-
120
- template <typename IndexT>
121
- size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
122
- // remove in sub-index first
123
- IDTranslatedSelector sel2(id_map, sel);
124
- size_t nremove = index->remove_ids(sel2);
125
-
126
- int64_t j = 0;
127
- for (idx_t i = 0; i < this->ntotal; i++) {
128
- if (sel.is_member(id_map[i])) {
129
- // remove
130
- } else {
131
- id_map[j] = id_map[i];
132
- j++;
133
- }
134
- }
135
- FAISS_ASSERT(j == index->ntotal);
136
- this->ntotal = j;
137
- id_map.resize(this->ntotal);
138
- return nremove;
139
- }
140
-
141
- template <typename IndexT>
142
- IndexIDMapTemplate<IndexT>::~IndexIDMapTemplate() {
143
- if (own_fields)
144
- delete index;
145
- }
146
-
147
- /*****************************************************
148
- * IndexIDMap2 implementation
149
- *******************************************************/
150
-
151
- template <typename IndexT>
152
- IndexIDMap2Template<IndexT>::IndexIDMap2Template(IndexT* index)
153
- : IndexIDMapTemplate<IndexT>(index) {}
154
-
155
- template <typename IndexT>
156
- void IndexIDMap2Template<IndexT>::add_with_ids(
157
- idx_t n,
158
- const typename IndexT::component_t* x,
159
- const typename IndexT::idx_t* xids) {
160
- size_t prev_ntotal = this->ntotal;
161
- IndexIDMapTemplate<IndexT>::add_with_ids(n, x, xids);
162
- for (size_t i = prev_ntotal; i < this->ntotal; i++) {
163
- rev_map[this->id_map[i]] = i;
164
- }
165
- }
166
-
167
- template <typename IndexT>
168
- void IndexIDMap2Template<IndexT>::construct_rev_map() {
169
- rev_map.clear();
170
- for (size_t i = 0; i < this->ntotal; i++) {
171
- rev_map[this->id_map[i]] = i;
172
- }
173
- }
174
-
175
- template <typename IndexT>
176
- size_t IndexIDMap2Template<IndexT>::remove_ids(const IDSelector& sel) {
177
- // This is quite inefficient
178
- size_t nremove = IndexIDMapTemplate<IndexT>::remove_ids(sel);
179
- construct_rev_map();
180
- return nremove;
181
- }
182
-
183
- template <typename IndexT>
184
- void IndexIDMap2Template<IndexT>::reconstruct(
185
- idx_t key,
186
- typename IndexT::component_t* recons) const {
187
- try {
188
- this->index->reconstruct(rev_map.at(key), recons);
189
- } catch (const std::out_of_range& e) {
190
- FAISS_THROW_FMT("key %" PRId64 " not found", key);
191
- }
192
- }
193
-
194
- // explicit template instantiations
195
-
196
- template struct IndexIDMapTemplate<Index>;
197
- template struct IndexIDMapTemplate<IndexBinary>;
198
- template struct IndexIDMap2Template<Index>;
199
- template struct IndexIDMap2Template<IndexBinary>;
200
-
201
25
  /*****************************************************
202
26
  * IndexSplitVectors implementation
203
27
  *******************************************************/
@@ -235,7 +59,10 @@ void IndexSplitVectors::search(
235
59
  const float* x,
236
60
  idx_t k,
237
61
  float* distances,
238
- idx_t* labels) const {
62
+ idx_t* labels,
63
+ const SearchParameters* params) const {
64
+ FAISS_THROW_IF_NOT_MSG(
65
+ !params, "search params not supported for this index");
239
66
  FAISS_THROW_IF_NOT_MSG(k == 1, "search implemented only for k=1");
240
67
  FAISS_THROW_IF_NOT_MSG(
241
68
  sum_d == d, "not enough indexes compared to # dimensions");
@@ -11,92 +11,13 @@
11
11
  #define META_INDEXES_H
12
12
 
13
13
  #include <faiss/Index.h>
14
+ #include <faiss/IndexIDMap.h>
14
15
  #include <faiss/IndexReplicas.h>
15
16
  #include <faiss/IndexShards.h>
16
- #include <unordered_map>
17
17
  #include <vector>
18
18
 
19
19
  namespace faiss {
20
20
 
21
- /** Index that translates search results to ids */
22
- template <typename IndexT>
23
- struct IndexIDMapTemplate : IndexT {
24
- using idx_t = typename IndexT::idx_t;
25
- using component_t = typename IndexT::component_t;
26
- using distance_t = typename IndexT::distance_t;
27
-
28
- IndexT* index; ///! the sub-index
29
- bool own_fields; ///! whether pointers are deleted in destructo
30
- std::vector<idx_t> id_map;
31
-
32
- explicit IndexIDMapTemplate(IndexT* index);
33
-
34
- /// @param xids if non-null, ids to store for the vectors (size n)
35
- void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
36
- override;
37
-
38
- /// this will fail. Use add_with_ids
39
- void add(idx_t n, const component_t* x) override;
40
-
41
- void search(
42
- idx_t n,
43
- const component_t* x,
44
- idx_t k,
45
- distance_t* distances,
46
- idx_t* labels) const override;
47
-
48
- void train(idx_t n, const component_t* x) override;
49
-
50
- void reset() override;
51
-
52
- /// remove ids adapted to IndexFlat
53
- size_t remove_ids(const IDSelector& sel) override;
54
-
55
- void range_search(
56
- idx_t n,
57
- const component_t* x,
58
- distance_t radius,
59
- RangeSearchResult* result) const override;
60
-
61
- ~IndexIDMapTemplate() override;
62
- IndexIDMapTemplate() {
63
- own_fields = false;
64
- index = nullptr;
65
- }
66
- };
67
-
68
- using IndexIDMap = IndexIDMapTemplate<Index>;
69
- using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;
70
-
71
- /** same as IndexIDMap but also provides an efficient reconstruction
72
- * implementation via a 2-way index */
73
- template <typename IndexT>
74
- struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
75
- using idx_t = typename IndexT::idx_t;
76
- using component_t = typename IndexT::component_t;
77
- using distance_t = typename IndexT::distance_t;
78
-
79
- std::unordered_map<idx_t, idx_t> rev_map;
80
-
81
- explicit IndexIDMap2Template(IndexT* index);
82
-
83
- /// make the rev_map from scratch
84
- void construct_rev_map();
85
-
86
- void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
87
- override;
88
-
89
- size_t remove_ids(const IDSelector& sel) override;
90
-
91
- void reconstruct(idx_t key, component_t* recons) const override;
92
-
93
- ~IndexIDMap2Template() override {}
94
- IndexIDMap2Template() {}
95
- };
96
-
97
- using IndexIDMap2 = IndexIDMap2Template<Index>;
98
- using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
99
-
100
21
  /** splits input vectors in segments and assigns each segment to a sub-index
101
22
  * used to distribute a MultiIndexQuantizer
102
23
  */
@@ -118,7 +39,8 @@ struct IndexSplitVectors : Index {
118
39
  const float* x,
119
40
  idx_t k,
120
41
  float* distances,
121
- idx_t* labels) const override;
42
+ idx_t* labels,
43
+ const SearchParameters* params = nullptr) const override;
122
44
 
123
45
  void train(idx_t n, const float* x) override;
124
46
 
@@ -149,6 +149,10 @@ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
149
149
  FAISS_THROW_MSG("reverse transform not implemented");
150
150
  }
151
151
 
152
+ void VectorTransform::check_identical(const VectorTransform& other) const {
153
+ // Throws unless both the input and the output dimensionality match; the original compared d_in twice and never checked d_out.
+ FAISS_THROW_IF_NOT(other.d_in == d_in && other.d_out == d_out);
154
+ }
155
+
152
156
  /*********************************************
153
157
  * LinearTransform
154
158
  *********************************************/
@@ -308,6 +312,13 @@ void LinearTransform::print_if_verbose(
308
312
  printf("]\n");
309
313
  }
310
314
 
315
+ void LinearTransform::check_identical(const VectorTransform& other_in) const {
316
+ VectorTransform::check_identical(other_in);
317
+ auto other = dynamic_cast<const LinearTransform*>(&other_in);
318
+ FAISS_THROW_IF_NOT(other);
319
+ FAISS_THROW_IF_NOT(other->A == A && other->b == b);
320
+ }
321
+
311
322
  /*********************************************
312
323
  * RandomRotationMatrix
313
324
  *********************************************/
@@ -966,6 +977,14 @@ void ITQTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
966
977
  pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
967
978
  }
968
979
 
980
+ void ITQTransform::check_identical(const VectorTransform& other_in) const {
981
+ VectorTransform::check_identical(other_in);
982
+ auto other = dynamic_cast<const ITQTransform*>(&other_in);
983
+ FAISS_THROW_IF_NOT(other);
984
+ pca_then_itq.check_identical(other->pca_then_itq);
985
+ FAISS_THROW_IF_NOT(other->mean == mean);
986
+ }
987
+
969
988
  /*********************************************
970
989
  * OPQMatrix
971
990
  *********************************************/
@@ -1226,6 +1245,14 @@ void NormalizationTransform::reverse_transform(
1226
1245
  memcpy(x, xt, sizeof(xt[0]) * n * d_in);
1227
1246
  }
1228
1247
 
1248
+ void NormalizationTransform::check_identical(
1249
+ const VectorTransform& other_in) const {
1250
+ VectorTransform::check_identical(other_in);
1251
+ auto other = dynamic_cast<const NormalizationTransform*>(&other_in);
1252
+ FAISS_THROW_IF_NOT(other);
1253
+ FAISS_THROW_IF_NOT(other->norm == norm);
1254
+ }
1255
+
1229
1256
  /*********************************************
1230
1257
  * CenteringTransform
1231
1258
  *********************************************/
@@ -1271,6 +1298,14 @@ void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
1271
1298
  }
1272
1299
  }
1273
1300
 
1301
+ void CenteringTransform::check_identical(
1302
+ const VectorTransform& other_in) const {
1303
+ VectorTransform::check_identical(other_in);
1304
+ auto other = dynamic_cast<const CenteringTransform*>(&other_in);
1305
+ FAISS_THROW_IF_NOT(other);
1306
+ FAISS_THROW_IF_NOT(other->mean == mean);
1307
+ }
1308
+
1274
1309
  /*********************************************
1275
1310
  * RemapDimensionsTransform
1276
1311
  *********************************************/
@@ -1335,3 +1370,11 @@ void RemapDimensionsTransform::reverse_transform(
1335
1370
  xt += d_out;
1336
1371
  }
1337
1372
  }
1373
+
1374
+ void RemapDimensionsTransform::check_identical(
1375
+ const VectorTransform& other_in) const {
1376
+ VectorTransform::check_identical(other_in);
1377
+ auto other = dynamic_cast<const RemapDimensionsTransform*>(&other_in);
1378
+ FAISS_THROW_IF_NOT(other);
1379
+ FAISS_THROW_IF_NOT(other->map == map);
1380
+ }