faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -15,6 +15,7 @@
15
15
  #include <memory>
16
16
 
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/DistanceComputer.h>
18
19
  #include <faiss/impl/FaissAssert.h>
19
20
 
20
21
  namespace faiss {
@@ -157,29 +158,42 @@ void IndexPreTransform::add_with_ids(
157
158
  ntotal = index->ntotal;
158
159
  }
159
160
 
161
+ namespace {
162
+
163
+ const SearchParameters* extract_index_search_params(
164
+ const SearchParameters* params_in) {
165
+ auto params = dynamic_cast<const SearchParametersPreTransform*>(params_in);
166
+ return params ? params->index_params : params_in;
167
+ }
168
+
169
+ } // namespace
170
+
160
171
  void IndexPreTransform::search(
161
172
  idx_t n,
162
173
  const float* x,
163
174
  idx_t k,
164
175
  float* distances,
165
- idx_t* labels) const {
176
+ idx_t* labels,
177
+ const SearchParameters* params) const {
166
178
  FAISS_THROW_IF_NOT(k > 0);
167
-
168
179
  FAISS_THROW_IF_NOT(is_trained);
169
180
  const float* xt = apply_chain(n, x);
170
181
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
171
- index->search(n, xt, k, distances, labels);
182
+ index->search(
183
+ n, xt, k, distances, labels, extract_index_search_params(params));
172
184
  }
173
185
 
174
186
  void IndexPreTransform::range_search(
175
187
  idx_t n,
176
188
  const float* x,
177
189
  float radius,
178
- RangeSearchResult* result) const {
190
+ RangeSearchResult* result,
191
+ const SearchParameters* params) const {
179
192
  FAISS_THROW_IF_NOT(is_trained);
180
193
  const float* xt = apply_chain(n, x);
181
194
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
182
- index->range_search(n, xt, radius, result);
195
+ index->range_search(
196
+ n, xt, radius, result, extract_index_search_params(params));
183
197
  }
184
198
 
185
199
  void IndexPreTransform::reset() {
@@ -219,9 +233,9 @@ void IndexPreTransform::search_and_reconstruct(
219
233
  idx_t k,
220
234
  float* distances,
221
235
  idx_t* labels,
222
- float* recons) const {
236
+ float* recons,
237
+ const SearchParameters* params) const {
223
238
  FAISS_THROW_IF_NOT(k > 0);
224
-
225
239
  FAISS_THROW_IF_NOT(is_trained);
226
240
 
227
241
  const float* xt = apply_chain(n, x);
@@ -229,7 +243,14 @@ void IndexPreTransform::search_and_reconstruct(
229
243
 
230
244
  float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
231
245
  ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
232
- index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
246
+ index->search_and_reconstruct(
247
+ n,
248
+ xt,
249
+ k,
250
+ distances,
251
+ labels,
252
+ recons_temp,
253
+ extract_index_search_params(params));
233
254
 
234
255
  // Revert transformations from last to first
235
256
  reverse_chain(n * k, recons_temp, recons);
@@ -262,6 +283,24 @@ void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
262
283
  }
263
284
  }
264
285
 
286
+ void IndexPreTransform::merge_from(Index& otherIndex, idx_t add_id) {
287
+ check_compatible_for_merge(otherIndex);
288
+ auto other = static_cast<const IndexPreTransform*>(&otherIndex);
289
+ index->merge_from(*other->index, add_id);
290
+ ntotal = index->ntotal;
291
+ }
292
+
293
+ void IndexPreTransform::check_compatible_for_merge(
294
+ const Index& otherIndex) const {
295
+ auto other = dynamic_cast<const IndexPreTransform*>(&otherIndex);
296
+ FAISS_THROW_IF_NOT(other);
297
+ FAISS_THROW_IF_NOT(chain.size() == other->chain.size());
298
+ for (int i = 0; i < chain.size(); i++) {
299
+ chain[i]->check_identical(*other->chain[i]);
300
+ }
301
+ index->check_compatible_for_merge(*other->index);
302
+ }
303
+
265
304
  namespace {
266
305
 
267
306
  struct PreTransformDistanceComputer : DistanceComputer {
@@ -14,6 +14,12 @@
14
14
 
15
15
  namespace faiss {
16
16
 
17
+ struct SearchParametersPreTransform : SearchParameters {
18
+ // nothing to add here.
19
+ // as such, encapsulating the search params is considered optional
20
+ SearchParameters* index_params = nullptr;
21
+ };
22
+
17
23
  /** Index that applies a LinearTransform transform on vectors before
18
24
  * handing them over to a sub-index */
19
25
  struct IndexPreTransform : Index {
@@ -48,14 +54,16 @@ struct IndexPreTransform : Index {
48
54
  const float* x,
49
55
  idx_t k,
50
56
  float* distances,
51
- idx_t* labels) const override;
57
+ idx_t* labels,
58
+ const SearchParameters* params = nullptr) const override;
52
59
 
53
60
  /* range search, no attempt is done to change the radius */
54
61
  void range_search(
55
62
  idx_t n,
56
63
  const float* x,
57
64
  float radius,
58
- RangeSearchResult* result) const override;
65
+ RangeSearchResult* result,
66
+ const SearchParameters* params = nullptr) const override;
59
67
 
60
68
  void reconstruct(idx_t key, float* recons) const override;
61
69
 
@@ -67,7 +75,8 @@ struct IndexPreTransform : Index {
67
75
  idx_t k,
68
76
  float* distances,
69
77
  idx_t* labels,
70
- float* recons) const override;
78
+ float* recons,
79
+ const SearchParameters* params = nullptr) const override;
71
80
 
72
81
  /// apply the transforms in the chain. The returned float * may be
73
82
  /// equal to x, otherwise it should be deallocated.
@@ -84,6 +93,9 @@ struct IndexPreTransform : Index {
84
93
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
85
94
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
86
95
 
96
+ void merge_from(Index& otherIndex, idx_t add_id = 0) override;
97
+ void check_compatible_for_merge(const Index& otherIndex) const override;
98
+
87
99
  ~IndexPreTransform() override;
88
100
  };
89
101
 
@@ -95,9 +95,11 @@ void IndexRefine::search(
95
95
  const float* x,
96
96
  idx_t k,
97
97
  float* distances,
98
- idx_t* labels) const {
98
+ idx_t* labels,
99
+ const SearchParameters* params) const {
100
+ FAISS_THROW_IF_NOT_MSG(
101
+ !params, "search params not supported for this index");
99
102
  FAISS_THROW_IF_NOT(k > 0);
100
-
101
103
  FAISS_THROW_IF_NOT(is_trained);
102
104
  idx_t k_base = idx_t(k * k_factor);
103
105
  idx_t* base_labels = labels;
@@ -222,9 +224,11 @@ void IndexRefineFlat::search(
222
224
  const float* x,
223
225
  idx_t k,
224
226
  float* distances,
225
- idx_t* labels) const {
227
+ idx_t* labels,
228
+ const SearchParameters* params) const {
229
+ FAISS_THROW_IF_NOT_MSG(
230
+ !params, "search params not supported for this index");
226
231
  FAISS_THROW_IF_NOT(k > 0);
227
-
228
232
  FAISS_THROW_IF_NOT(is_trained);
229
233
  idx_t k_base = idx_t(k * k_factor);
230
234
  idx_t* base_labels = labels;
@@ -44,7 +44,8 @@ struct IndexRefine : Index {
44
44
  const float* x,
45
45
  idx_t k,
46
46
  float* distances,
47
- idx_t* labels) const override;
47
+ idx_t* labels,
48
+ const SearchParameters* params = nullptr) const override;
48
49
 
49
50
  // reconstruct is routed to the refine_index
50
51
  void reconstruct(idx_t key, float* recons) const override;
@@ -76,7 +77,8 @@ struct IndexRefineFlat : IndexRefine {
76
77
  const float* x,
77
78
  idx_t k,
78
79
  float* distances,
79
- idx_t* labels) const override;
80
+ idx_t* labels,
81
+ const SearchParameters* params = nullptr) const override;
80
82
  };
81
83
 
82
84
  } // namespace faiss
@@ -108,9 +108,11 @@ void IndexReplicasTemplate<IndexT>::search(
108
108
  const component_t* x,
109
109
  idx_t k,
110
110
  distance_t* distances,
111
- idx_t* labels) const {
111
+ idx_t* labels,
112
+ const SearchParameters* params) const {
113
+ FAISS_THROW_IF_NOT_MSG(
114
+ !params, "search params not supported for this index");
112
115
  FAISS_THROW_IF_NOT(k > 0);
113
-
114
116
  FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index");
115
117
 
116
118
  if (n == 0) {
@@ -65,7 +65,8 @@ class IndexReplicasTemplate : public ThreadedIndex<IndexT> {
65
65
  const component_t* x,
66
66
  idx_t k,
67
67
  distance_t* distances,
68
- idx_t* labels) const override;
68
+ idx_t* labels,
69
+ const SearchParameters* params = nullptr) const override;
69
70
 
70
71
  /// reconstructs from the first index
71
72
  void reconstruct(idx_t, component_t* v) const override;
@@ -0,0 +1,438 @@
1
+ #include <faiss/IndexRowwiseMinMax.h>
2
+
3
+ #include <cstdint>
4
+ #include <cstring>
5
+ #include <limits>
6
+
7
+ #include <faiss/impl/FaissAssert.h>
8
+ #include <faiss/utils/fp16.h>
9
+
10
+ namespace faiss {
11
+
12
+ namespace {
13
+
14
+ using idx_t = faiss::Index::idx_t;
15
+
16
+ struct StorageMinMaxFP16 {
17
+ uint16_t scaler;
18
+ uint16_t minv;
19
+
20
+ inline void from_floats(const float float_scaler, const float float_minv) {
21
+ scaler = encode_fp16(float_scaler);
22
+ minv = encode_fp16(float_minv);
23
+ }
24
+
25
+ inline void to_floats(float& float_scaler, float& float_minv) const {
26
+ float_scaler = decode_fp16(scaler);
27
+ float_minv = decode_fp16(minv);
28
+ }
29
+ };
30
+
31
+ struct StorageMinMaxFP32 {
32
+ float scaler;
33
+ float minv;
34
+
35
+ inline void from_floats(const float float_scaler, const float float_minv) {
36
+ scaler = float_scaler;
37
+ minv = float_minv;
38
+ }
39
+
40
+ inline void to_floats(float& float_scaler, float& float_minv) const {
41
+ float_scaler = scaler;
42
+ float_minv = minv;
43
+ }
44
+ };
45
+
46
+ template <typename StorageMinMaxT>
47
+ void sa_encode_impl(
48
+ const IndexRowwiseMinMaxBase* const index,
49
+ const idx_t n_input,
50
+ const float* x_input,
51
+ uint8_t* bytes_output) {
52
+ // process chunks
53
+ const size_t chunk_size = rowwise_minmax_sa_encode_bs;
54
+
55
+ // useful variables
56
+ const Index* const sub_index = index->index;
57
+ const int d = index->d;
58
+
59
+ // the code size of the subindex
60
+ const size_t old_code_size = sub_index->sa_code_size();
61
+ // the code size of the index
62
+ const size_t new_code_size = index->sa_code_size();
63
+
64
+ // allocate tmp buffers
65
+ std::vector<float> tmp(chunk_size * d);
66
+ std::vector<StorageMinMaxT> minmax(chunk_size);
67
+
68
+ // all the elements to process
69
+ size_t n_left = n_input;
70
+
71
+ const float* __restrict x = x_input;
72
+ uint8_t* __restrict bytes = bytes_output;
73
+
74
+ while (n_left > 0) {
75
+ // current portion to be processed
76
+ const idx_t n = std::min(n_left, chunk_size);
77
+
78
+ // allocate a temporary buffer and do the rescale
79
+ for (idx_t i = 0; i < n; i++) {
80
+ // compute min & max values
81
+ float minv = std::numeric_limits<float>::max();
82
+ float maxv = std::numeric_limits<float>::lowest();
83
+
84
+ const float* const vec_in = x + i * d;
85
+ for (idx_t j = 0; j < d; j++) {
86
+ minv = std::min(minv, vec_in[j]);
87
+ maxv = std::max(maxv, vec_in[j]);
88
+ }
89
+
90
+ // save the coefficients
91
+ const float scaler = maxv - minv;
92
+ minmax[i].from_floats(scaler, minv);
93
+
94
+ // and load them back, because the coefficients might
95
+ // be modified.
96
+ float actual_scaler = 0;
97
+ float actual_minv = 0;
98
+ minmax[i].to_floats(actual_scaler, actual_minv);
99
+
100
+ float* const vec_out = tmp.data() + i * d;
101
+ if (actual_scaler == 0) {
102
+ for (idx_t j = 0; j < d; j++) {
103
+ vec_out[j] = 0;
104
+ }
105
+ } else {
106
+ float inv_actual_scaler = 1.0f / actual_scaler;
107
+ for (idx_t j = 0; j < d; j++) {
108
+ vec_out[j] = (vec_in[j] - actual_minv) * inv_actual_scaler;
109
+ }
110
+ }
111
+ }
112
+
113
+ // do the coding
114
+ sub_index->sa_encode(n, tmp.data(), bytes);
115
+
116
+ // rearrange
117
+ for (idx_t i = n; (i--) > 0;) {
118
+ // move a single index
119
+ std::memmove(
120
+ bytes + i * new_code_size + (new_code_size - old_code_size),
121
+ bytes + i * old_code_size,
122
+ old_code_size);
123
+
124
+ // save min & max values
125
+ StorageMinMaxT* fpv = reinterpret_cast<StorageMinMaxT*>(
126
+ bytes + i * new_code_size);
127
+ *fpv = minmax[i];
128
+ }
129
+
130
+ // next chunk
131
+ x += n * d;
132
+ bytes += n * new_code_size;
133
+
134
+ n_left -= n;
135
+ }
136
+ }
137
+
138
+ template <typename StorageMinMaxT>
139
+ void sa_decode_impl(
140
+ const IndexRowwiseMinMaxBase* const index,
141
+ const idx_t n_input,
142
+ const uint8_t* bytes_input,
143
+ float* x_output) {
144
+ // process chunks
145
+ const size_t chunk_size = rowwise_minmax_sa_decode_bs;
146
+
147
+ // useful variables
148
+ const Index* const sub_index = index->index;
149
+ const int d = index->d;
150
+
151
+ // the code size of the subindex
152
+ const size_t old_code_size = sub_index->sa_code_size();
153
+ // the code size of the index
154
+ const size_t new_code_size = index->sa_code_size();
155
+
156
+ // allocate tmp buffers
157
+ std::vector<uint8_t> tmp(
158
+ (chunk_size < n_input ? chunk_size : n_input) * old_code_size);
159
+ std::vector<StorageMinMaxFP16> minmax(
160
+ (chunk_size < n_input ? chunk_size : n_input));
161
+
162
+ // all the elements to process
163
+ size_t n_left = n_input;
164
+
165
+ const uint8_t* __restrict bytes = bytes_input;
166
+ float* __restrict x = x_output;
167
+
168
+ while (n_left > 0) {
169
+ // current portion to be processed
170
+ const idx_t n = std::min(n_left, chunk_size);
171
+
172
+ // rearrange
173
+ for (idx_t i = 0; i < n; i++) {
174
+ std::memcpy(
175
+ tmp.data() + i * old_code_size,
176
+ bytes + i * new_code_size + (new_code_size - old_code_size),
177
+ old_code_size);
178
+ }
179
+
180
+ // decode
181
+ sub_index->sa_decode(n, tmp.data(), x);
182
+
183
+ // scale back
184
+ for (idx_t i = 0; i < n; i++) {
185
+ const uint8_t* const vec_in = bytes + i * new_code_size;
186
+ StorageMinMaxT fpv =
187
+ *(reinterpret_cast<const StorageMinMaxT*>(vec_in));
188
+
189
+ float scaler = 0;
190
+ float minv = 0;
191
+ fpv.to_floats(scaler, minv);
192
+
193
+ float* const __restrict vec = x + d * i;
194
+
195
+ for (idx_t j = 0; j < d; j++) {
196
+ vec[j] = vec[j] * scaler + minv;
197
+ }
198
+ }
199
+
200
+ // next chunk
201
+ bytes += n * new_code_size;
202
+ x += n * d;
203
+
204
+ n_left -= n;
205
+ }
206
+ }
207
+
208
+ //
209
+ template <typename StorageMinMaxT>
210
+ void train_inplace_impl(
211
+ IndexRowwiseMinMaxBase* const index,
212
+ idx_t n,
213
+ float* x) {
214
+ // useful variables
215
+ Index* const sub_index = index->index;
216
+ const int d = index->d;
217
+
218
+ // save normalizing coefficients
219
+ std::vector<StorageMinMaxT> minmax(n);
220
+
221
+ // normalize
222
+ #pragma omp for
223
+ for (idx_t i = 0; i < n; i++) {
224
+ // compute min & max values
225
+ float minv = std::numeric_limits<float>::max();
226
+ float maxv = std::numeric_limits<float>::lowest();
227
+
228
+ float* const vec = x + i * d;
229
+ for (idx_t j = 0; j < d; j++) {
230
+ minv = std::min(minv, vec[j]);
231
+ maxv = std::max(maxv, vec[j]);
232
+ }
233
+
234
+ // save the coefficients
235
+ const float scaler = maxv - minv;
236
+ minmax[i].from_floats(scaler, minv);
237
+
238
+ // and load them back, because the coefficients might
239
+ // be modified.
240
+ float actual_scaler = 0;
241
+ float actual_minv = 0;
242
+ minmax[i].to_floats(actual_scaler, actual_minv);
243
+
244
+ if (actual_scaler == 0) {
245
+ for (idx_t j = 0; j < d; j++) {
246
+ vec[j] = 0;
247
+ }
248
+ } else {
249
+ float inv_actual_scaler = 1.0f / actual_scaler;
250
+ for (idx_t j = 0; j < d; j++) {
251
+ vec[j] = (vec[j] - actual_minv) * inv_actual_scaler;
252
+ }
253
+ }
254
+ }
255
+
256
+ // train the subindex
257
+ sub_index->train(n, x);
258
+
259
+ // rescale data back
260
+ for (idx_t i = 0; i < n; i++) {
261
+ float scaler = 0;
262
+ float minv = 0;
263
+ minmax[i].to_floats(scaler, minv);
264
+
265
+ float* const vec = x + i * d;
266
+
267
+ for (idx_t j = 0; j < d; j++) {
268
+ vec[j] = vec[j] * scaler + minv;
269
+ }
270
+ }
271
+ }
272
+
273
+ //
274
+ template <typename StorageMinMaxT>
275
+ void train_impl(IndexRowwiseMinMaxBase* const index, idx_t n, const float* x) {
276
+ // the default training that creates a copy of the input data
277
+
278
+ // useful variables
279
+ Index* const sub_index = index->index;
280
+ const int d = index->d;
281
+
282
+ // temp buffer
283
+ std::vector<float> tmp(n * d);
284
+
285
+ #pragma omp for
286
+ for (idx_t i = 0; i < n; i++) {
287
+ // compute min & max values
288
+ float minv = std::numeric_limits<float>::max();
289
+ float maxv = std::numeric_limits<float>::lowest();
290
+
291
+ const float* const __restrict vec_in = x + i * d;
292
+ for (idx_t j = 0; j < d; j++) {
293
+ minv = std::min(minv, vec_in[j]);
294
+ maxv = std::max(maxv, vec_in[j]);
295
+ }
296
+
297
+ const float scaler = maxv - minv;
298
+
299
+ // save the coefficients
300
+ StorageMinMaxT storage;
301
+ storage.from_floats(scaler, minv);
302
+
303
+ // and load them back, because the coefficients might
304
+ // be modified.
305
+ float actual_scaler = 0;
306
+ float actual_minv = 0;
307
+ storage.to_floats(actual_scaler, actual_minv);
308
+
309
+ float* const __restrict vec_out = tmp.data() + i * d;
310
+ if (actual_scaler == 0) {
311
+ for (idx_t j = 0; j < d; j++) {
312
+ vec_out[j] = 0;
313
+ }
314
+ } else {
315
+ float inv_actual_scaler = 1.0f / actual_scaler;
316
+ for (idx_t j = 0; j < d; j++) {
317
+ vec_out[j] = (vec_in[j] - actual_minv) * inv_actual_scaler;
318
+ }
319
+ }
320
+ }
321
+
322
+ sub_index->train(n, tmp.data());
323
+ }
324
+
325
+ } // namespace
326
+
327
+ // block size for performing sa_encode and sa_decode
328
+ int rowwise_minmax_sa_encode_bs = 16384;
329
+ int rowwise_minmax_sa_decode_bs = 16384;
330
+
331
+ /*********************************************************
332
+ * IndexRowwiseMinMaxBase implementation
333
+ ********************************************************/
334
+
335
+ IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase(Index* index)
336
+ : Index(index->d, index->metric_type),
337
+ index{index},
338
+ own_fields{false} {}
339
+
340
+ IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase()
341
+ : index{nullptr}, own_fields{false} {}
342
+
343
+ IndexRowwiseMinMaxBase::~IndexRowwiseMinMaxBase() {
344
+ if (own_fields) {
345
+ delete index;
346
+ index = nullptr;
347
+ }
348
+ }
349
+
350
+ void IndexRowwiseMinMaxBase::add(idx_t, const float*) {
351
+ FAISS_THROW_MSG("add not implemented for this type of index");
352
+ }
353
+
354
+ void IndexRowwiseMinMaxBase::search(
355
+ idx_t,
356
+ const float*,
357
+ idx_t,
358
+ float*,
359
+ idx_t*,
360
+ const SearchParameters*) const {
361
+ FAISS_THROW_MSG("search not implemented for this type of index");
362
+ }
363
+
364
+ void IndexRowwiseMinMaxBase::reset() {
365
+ FAISS_THROW_MSG("reset not implemented for this type of index");
366
+ }
367
+
368
+ /*********************************************************
369
+ * IndexRowwiseMinMaxFP16 implementation
370
+ ********************************************************/
371
+
372
+ IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16(Index* index)
373
+ : IndexRowwiseMinMaxBase(index) {}
374
+
375
+ IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16() : IndexRowwiseMinMaxBase() {}
376
+
377
+ size_t IndexRowwiseMinMaxFP16::sa_code_size() const {
378
+ return index->sa_code_size() + 2 * sizeof(uint16_t);
379
+ }
380
+
381
+ void IndexRowwiseMinMaxFP16::sa_encode(
382
+ idx_t n_input,
383
+ const float* x_input,
384
+ uint8_t* bytes_output) const {
385
+ sa_encode_impl<StorageMinMaxFP16>(this, n_input, x_input, bytes_output);
386
+ }
387
+
388
+ void IndexRowwiseMinMaxFP16::sa_decode(
389
+ idx_t n_input,
390
+ const uint8_t* bytes_input,
391
+ float* x_output) const {
392
+ sa_decode_impl<StorageMinMaxFP16>(this, n_input, bytes_input, x_output);
393
+ }
394
+
395
+ void IndexRowwiseMinMaxFP16::train(idx_t n, const float* x) {
396
+ train_impl<StorageMinMaxFP16>(this, n, x);
397
+ }
398
+
399
+ void IndexRowwiseMinMaxFP16::train_inplace(idx_t n, float* x) {
400
+ train_inplace_impl<StorageMinMaxFP16>(this, n, x);
401
+ }
402
+
403
+ /*********************************************************
404
+ * IndexRowwiseMinMax implementation
405
+ ********************************************************/
406
+
407
+ IndexRowwiseMinMax::IndexRowwiseMinMax(Index* index)
408
+ : IndexRowwiseMinMaxBase(index) {}
409
+
410
+ IndexRowwiseMinMax::IndexRowwiseMinMax() : IndexRowwiseMinMaxBase() {}
411
+
412
+ size_t IndexRowwiseMinMax::sa_code_size() const {
413
+ return index->sa_code_size() + 2 * sizeof(float);
414
+ }
415
+
416
+ void IndexRowwiseMinMax::sa_encode(
417
+ idx_t n_input,
418
+ const float* x_input,
419
+ uint8_t* bytes_output) const {
420
+ sa_encode_impl<StorageMinMaxFP32>(this, n_input, x_input, bytes_output);
421
+ }
422
+
423
+ void IndexRowwiseMinMax::sa_decode(
424
+ idx_t n_input,
425
+ const uint8_t* bytes_input,
426
+ float* x_output) const {
427
+ sa_decode_impl<StorageMinMaxFP32>(this, n_input, bytes_input, x_output);
428
+ }
429
+
430
+ void IndexRowwiseMinMax::train(idx_t n, const float* x) {
431
+ train_impl<StorageMinMaxFP32>(this, n, x);
432
+ }
433
+
434
+ void IndexRowwiseMinMax::train_inplace(idx_t n, float* x) {
435
+ train_inplace_impl<StorageMinMaxFP32>(this, n, x);
436
+ }
437
+
438
+ } // namespace faiss