faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -5,20 +5,15 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #include "faiss/impl/ResidualQuantizer.h"
11
- #include <faiss/impl/FaissAssert.h>
12
8
  #include <faiss/impl/ResidualQuantizer.h>
13
- #include "faiss/utils/utils.h"
14
9
 
10
+ #include <algorithm>
11
+ #include <cmath>
15
12
  #include <cstddef>
16
13
  #include <cstdio>
17
14
  #include <cstring>
18
15
  #include <memory>
19
16
 
20
- #include <algorithm>
21
-
22
17
  #include <faiss/IndexFlat.h>
23
18
  #include <faiss/VectorTransform.h>
24
19
  #include <faiss/impl/AuxIndexStructures.h>
@@ -28,39 +23,109 @@
28
23
  #include <faiss/utils/hamming.h>
29
24
  #include <faiss/utils/utils.h>
30
25
 
26
+ extern "C" {
27
+
28
+ // general matrix multiplication
29
+ int sgemm_(
30
+ const char* transa,
31
+ const char* transb,
32
+ FINTEGER* m,
33
+ FINTEGER* n,
34
+ FINTEGER* k,
35
+ const float* alpha,
36
+ const float* a,
37
+ FINTEGER* lda,
38
+ const float* b,
39
+ FINTEGER* ldb,
40
+ float* beta,
41
+ float* c,
42
+ FINTEGER* ldc);
43
+
44
+ // http://www.netlib.org/clapack/old/single/sgels.c
45
+ // solve least squares
46
+
47
+ int sgelsd_(
48
+ FINTEGER* m,
49
+ FINTEGER* n,
50
+ FINTEGER* nrhs,
51
+ float* a,
52
+ FINTEGER* lda,
53
+ float* b,
54
+ FINTEGER* ldb,
55
+ float* s,
56
+ float* rcond,
57
+ FINTEGER* rank,
58
+ float* work,
59
+ FINTEGER* lwork,
60
+ FINTEGER* iwork,
61
+ FINTEGER* info);
62
+ }
63
+
31
64
  namespace faiss {
32
65
 
33
66
  ResidualQuantizer::ResidualQuantizer()
34
67
  : train_type(Train_progressive_dim),
35
- max_beam_size(30),
36
- max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
68
+ niter_codebook_refine(5),
69
+ max_beam_size(5),
70
+ use_beam_LUT(0),
37
71
  assign_index_factory(nullptr) {
38
72
  d = 0;
39
73
  M = 0;
40
74
  verbose = false;
41
75
  }
42
76
 
43
- ResidualQuantizer::ResidualQuantizer(size_t d, const std::vector<size_t>& nbits)
77
+ ResidualQuantizer::ResidualQuantizer(
78
+ size_t d,
79
+ const std::vector<size_t>& nbits,
80
+ Search_type_t search_type)
44
81
  : ResidualQuantizer() {
82
+ this->search_type = search_type;
45
83
  this->d = d;
46
84
  M = nbits.size();
47
85
  this->nbits = nbits;
48
86
  set_derived_values();
49
87
  }
50
88
 
51
- ResidualQuantizer::ResidualQuantizer(size_t d, size_t M, size_t nbits)
52
- : ResidualQuantizer(d, std::vector<size_t>(M, nbits)) {}
89
+ ResidualQuantizer::ResidualQuantizer(
90
+ size_t d,
91
+ size_t M,
92
+ size_t nbits,
93
+ Search_type_t search_type)
94
+ : ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
95
+
96
+ void ResidualQuantizer::initialize_from(
97
+ const ResidualQuantizer& other,
98
+ int skip_M) {
99
+ FAISS_THROW_IF_NOT(M + skip_M <= other.M);
100
+ FAISS_THROW_IF_NOT(skip_M >= 0);
101
+
102
+ Search_type_t this_search_type = search_type;
103
+ int this_M = M;
104
+
105
+ // a first good approximation: override everything
106
+ *this = other;
107
+
108
+ // adjust derived values
109
+ M = this_M;
110
+ search_type = this_search_type;
111
+ nbits.resize(M);
112
+ memcpy(nbits.data(),
113
+ other.nbits.data() + skip_M,
114
+ nbits.size() * sizeof(nbits[0]));
53
115
 
54
- namespace {
116
+ set_derived_values();
55
117
 
56
- void fvec_sub(size_t d, const float* a, const float* b, float* c) {
57
- for (size_t i = 0; i < d; i++) {
58
- c[i] = a[i] - b[i];
118
+ // resize codebooks if trained
119
+ if (codebooks.size() > 0) {
120
+ FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
121
+ codebooks.resize(total_codebook_size * d);
122
+ memcpy(codebooks.data(),
123
+ other.codebooks.data() + other.codebook_offsets[skip_M] * d,
124
+ codebooks.size() * sizeof(codebooks[0]));
125
+ // TODO: norm_tabs?
59
126
  }
60
127
  }
61
128
 
62
- } // anonymous namespace
63
-
64
129
  void beam_search_encode_step(
65
130
  size_t d,
66
131
  size_t K,
@@ -90,7 +155,7 @@ void beam_search_encode_step(
90
155
  cent_ids.resize(n * beam_size * new_beam_size);
91
156
  if (assign_index->ntotal != 0) {
92
157
  // then we assume the codebooks are already added to the index
93
- FAISS_THROW_IF_NOT(assign_index->ntotal != K);
158
+ FAISS_THROW_IF_NOT(assign_index->ntotal == K);
94
159
  } else {
95
160
  assign_index->add(K, cent);
96
161
  }
@@ -208,6 +273,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
208
273
  std::vector<int32_t> codes;
209
274
  std::vector<float> distances;
210
275
  double t0 = getmillisecs();
276
+ double clustering_time = 0;
211
277
 
212
278
  for (int m = 0; m < M; m++) {
213
279
  int K = 1 << nbits[m];
@@ -224,8 +290,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
224
290
  }
225
291
  train_residuals = residuals1;
226
292
  }
227
- train_type_t tt = train_type_t(train_type & ~Train_top_beam);
228
-
229
293
  std::vector<float> codebooks;
230
294
  float obj = 0;
231
295
 
@@ -235,7 +299,10 @@ void ResidualQuantizer::train(size_t n, const float* x) {
235
299
  } else {
236
300
  assign_index.reset(new IndexFlatL2(d));
237
301
  }
238
- if (tt == Train_default) {
302
+
303
+ double t1 = getmillisecs();
304
+
305
+ if (!(train_type & Train_progressive_dim)) { // regular kmeans
239
306
  Clustering clus(d, K, cp);
240
307
  clus.train(
241
308
  train_residuals.size() / d,
@@ -244,7 +311,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
244
311
  codebooks.swap(clus.centroids);
245
312
  assign_index->reset();
246
313
  obj = clus.iteration_stats.back().obj;
247
- } else if (tt == Train_progressive_dim) {
314
+ } else { // progressive dim clustering
248
315
  ProgressiveDimClustering clus(d, K, cp);
249
316
  ProgressiveDimIndexFactory default_fac;
250
317
  clus.train(
@@ -253,9 +320,8 @@ void ResidualQuantizer::train(size_t n, const float* x) {
253
320
  assign_index_factory ? *assign_index_factory : default_fac);
254
321
  codebooks.swap(clus.centroids);
255
322
  obj = clus.iteration_stats.back().obj;
256
- } else {
257
- FAISS_THROW_MSG("train type not supported");
258
323
  }
324
+ clustering_time += (getmillisecs() - t1) / 1000;
259
325
 
260
326
  memcpy(this->codebooks.data() + codebook_offsets[m] * d,
261
327
  codebooks.data(),
@@ -268,21 +334,38 @@ void ResidualQuantizer::train(size_t n, const float* x) {
268
334
  std::vector<float> new_residuals(n * new_beam_size * d);
269
335
  std::vector<float> new_distances(n * new_beam_size);
270
336
 
271
- beam_search_encode_step(
272
- d,
273
- K,
274
- codebooks.data(),
275
- n,
276
- cur_beam_size,
277
- residuals.data(),
278
- m,
279
- codes.data(),
280
- new_beam_size,
281
- new_codes.data(),
282
- new_residuals.data(),
283
- new_distances.data(),
284
- assign_index.get());
337
+ size_t bs;
338
+ { // determine batch size
339
+ size_t mem = memory_per_point();
340
+ if (n > 1 && mem * n > max_mem_distances) {
341
+ // then split queries to reduce temp memory
342
+ bs = std::max(max_mem_distances / mem, size_t(1));
343
+ } else {
344
+ bs = n;
345
+ }
346
+ }
285
347
 
348
+ for (size_t i0 = 0; i0 < n; i0 += bs) {
349
+ size_t i1 = std::min(i0 + bs, n);
350
+
351
+ /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
352
+ i0, i1, K, assign_index->ntotal); */
353
+
354
+ beam_search_encode_step(
355
+ d,
356
+ K,
357
+ codebooks.data(),
358
+ i1 - i0,
359
+ cur_beam_size,
360
+ residuals.data() + i0 * cur_beam_size * d,
361
+ m,
362
+ codes.data() + i0 * cur_beam_size * m,
363
+ new_beam_size,
364
+ new_codes.data() + i0 * new_beam_size * (m + 1),
365
+ new_residuals.data() + i0 * new_beam_size * d,
366
+ new_distances.data() + i0 * new_beam_size,
367
+ assign_index.get());
368
+ }
286
369
  codes.swap(new_codes);
287
370
  residuals.swap(new_residuals);
288
371
  distances.swap(new_distances);
@@ -293,20 +376,165 @@ void ResidualQuantizer::train(size_t n, const float* x) {
293
376
  }
294
377
 
295
378
  if (verbose) {
296
- printf("[%.3f s] train stage %d, %d bits, kmeans objective %g, "
297
- "total distance %g, beam_size %d->%d\n",
379
+ printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
380
+ "total distance %g, beam_size %d->%d (batch size %zd)\n",
298
381
  (getmillisecs() - t0) / 1000,
382
+ clustering_time,
299
383
  m,
300
384
  int(nbits[m]),
301
385
  obj,
302
386
  sum_distances,
303
387
  cur_beam_size,
304
- new_beam_size);
388
+ new_beam_size,
389
+ bs);
305
390
  }
306
391
  cur_beam_size = new_beam_size;
307
392
  }
308
393
 
309
394
  is_trained = true;
395
+
396
+ if (train_type & Train_refine_codebook) {
397
+ for (int iter = 0; iter < niter_codebook_refine; iter++) {
398
+ if (verbose) {
399
+ printf("re-estimating the codebooks to minimize "
400
+ "quantization errors (iter %d).\n",
401
+ iter);
402
+ }
403
+ retrain_AQ_codebook(n, x);
404
+ }
405
+ }
406
+
407
+ // find min and max norms
408
+ std::vector<float> norms(n);
409
+
410
+ for (size_t i = 0; i < n; i++) {
411
+ norms[i] = fvec_L2sqr(
412
+ x + i * d, residuals.data() + i * cur_beam_size * d, d);
413
+ }
414
+
415
+ // fvec_norms_L2sqr(norms.data(), x, d, n);
416
+ train_norm(n, norms.data());
417
+
418
+ if (!(train_type & Skip_codebook_tables)) {
419
+ compute_codebook_tables();
420
+ }
421
+ }
422
+
423
+ float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
424
+ FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
425
+
426
+ if (verbose) {
427
+ printf(" encoding %zd training vectors\n", n);
428
+ }
429
+ std::vector<uint8_t> codes(n * code_size);
430
+ compute_codes(x, codes.data(), n);
431
+
432
+ // compute reconstruction error
433
+ float input_recons_error;
434
+ {
435
+ std::vector<float> x_recons(n * d);
436
+ decode(codes.data(), x_recons.data(), n);
437
+ input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
438
+ if (verbose) {
439
+ printf(" input quantization error %g\n", input_recons_error);
440
+ }
441
+ }
442
+
443
+ // build matrix of the linear system
444
+ std::vector<float> C(n * total_codebook_size);
445
+ for (size_t i = 0; i < n; i++) {
446
+ BitstringReader bsr(codes.data() + i * code_size, code_size);
447
+ for (int m = 0; m < M; m++) {
448
+ int idx = bsr.read(nbits[m]);
449
+ C[i + (codebook_offsets[m] + idx) * n] = 1;
450
+ }
451
+ }
452
+
453
+ // transpose training vectors
454
+ std::vector<float> xt(n * d);
455
+
456
+ for (size_t i = 0; i < n; i++) {
457
+ for (size_t j = 0; j < d; j++) {
458
+ xt[j * n + i] = x[i * d + j];
459
+ }
460
+ }
461
+
462
+ { // solve least squares
463
+ FINTEGER lwork = -1;
464
+ FINTEGER di = d, ni = n, tcsi = total_codebook_size;
465
+ FINTEGER info = -1, rank = -1;
466
+
467
+ float rcond = 1e-4; // this is an important parameter because the code
468
+ // matrix can be rank deficient for small problems,
469
+ // the default rcond=-1 does not work
470
+ float worksize;
471
+ std::vector<float> sing_vals(total_codebook_size);
472
+ FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
473
+ // upper bound
474
+ std::vector<FINTEGER> iwork(
475
+ 3 * total_codebook_size * nlvl + 11 * total_codebook_size);
476
+
477
+ // worksize query
478
+ sgelsd_(&ni,
479
+ &tcsi,
480
+ &di,
481
+ C.data(),
482
+ &ni,
483
+ xt.data(),
484
+ &ni,
485
+ sing_vals.data(),
486
+ &rcond,
487
+ &rank,
488
+ &worksize,
489
+ &lwork,
490
+ iwork.data(),
491
+ &info);
492
+ FAISS_THROW_IF_NOT(info == 0);
493
+
494
+ lwork = worksize;
495
+ std::vector<float> work(lwork);
496
+ // actual call
497
+ sgelsd_(&ni,
498
+ &tcsi,
499
+ &di,
500
+ C.data(),
501
+ &ni,
502
+ xt.data(),
503
+ &ni,
504
+ sing_vals.data(),
505
+ &rcond,
506
+ &rank,
507
+ work.data(),
508
+ &lwork,
509
+ iwork.data(),
510
+ &info);
511
+ FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
512
+ if (verbose) {
513
+ printf(" sgelsd rank=%d/%d\n",
514
+ int(rank),
515
+ int(total_codebook_size));
516
+ }
517
+ }
518
+
519
+ // result is in xt, re-transpose to codebook
520
+
521
+ for (size_t i = 0; i < total_codebook_size; i++) {
522
+ for (size_t j = 0; j < d; j++) {
523
+ codebooks[i * d + j] = xt[j * n + i];
524
+ FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
525
+ }
526
+ }
527
+
528
+ float output_recons_error = 0;
529
+ for (size_t j = 0; j < d; j++) {
530
+ output_recons_error += fvec_norm_L2sqr(
531
+ xt.data() + total_codebook_size + n * j,
532
+ n - total_codebook_size);
533
+ }
534
+ if (verbose) {
535
+ printf(" output quantization error %g\n", output_recons_error);
536
+ }
537
+ return output_recons_error;
310
538
  }
311
539
 
312
540
  size_t ResidualQuantizer::memory_per_point(int beam_size) const {
@@ -321,10 +549,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
321
549
  return mem;
322
550
  }
323
551
 
324
- void ResidualQuantizer::compute_codes(
552
+ void ResidualQuantizer::compute_codes_add_centroids(
325
553
  const float* x,
326
554
  uint8_t* codes_out,
327
- size_t n) const {
555
+ size_t n,
556
+ const float* centroids) const {
328
557
  FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
329
558
 
330
559
  size_t mem = memory_per_point();
@@ -336,27 +565,87 @@ void ResidualQuantizer::compute_codes(
336
565
  }
337
566
  for (size_t i0 = 0; i0 < n; i0 += bs) {
338
567
  size_t i1 = std::min(n, i0 + bs);
339
- compute_codes(x + i0 * d, codes_out + i0 * code_size, i1 - i0);
568
+ const float* cent = nullptr;
569
+ if (centroids != nullptr) {
570
+ cent = centroids + i0 * d;
571
+ }
572
+ compute_codes_add_centroids(
573
+ x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
340
574
  }
341
575
  return;
342
576
  }
343
577
 
344
- std::vector<float> residuals(max_beam_size * n * d);
345
578
  std::vector<int32_t> codes(max_beam_size * M * n);
579
+ std::vector<float> norms;
346
580
  std::vector<float> distances(max_beam_size * n);
347
581
 
348
- refine_beam(
349
- n,
350
- 1,
351
- x,
352
- max_beam_size,
353
- codes.data(),
354
- residuals.data(),
355
- distances.data());
582
+ if (use_beam_LUT == 0) {
583
+ std::vector<float> residuals(max_beam_size * n * d);
356
584
 
585
+ refine_beam(
586
+ n,
587
+ 1,
588
+ x,
589
+ max_beam_size,
590
+ codes.data(),
591
+ residuals.data(),
592
+ distances.data());
593
+
594
+ if (search_type == ST_norm_float || search_type == ST_norm_qint8 ||
595
+ search_type == ST_norm_qint4) {
596
+ norms.resize(n);
597
+ // recover the norms of reconstruction as
598
+ // || original_vector - residual ||^2
599
+ for (size_t i = 0; i < n; i++) {
600
+ norms[i] = fvec_L2sqr(
601
+ x + i * d, residuals.data() + i * max_beam_size * d, d);
602
+ }
603
+ }
604
+ } else if (use_beam_LUT == 1) {
605
+ FAISS_THROW_IF_NOT_MSG(
606
+ codebook_cross_products.size() ==
607
+ total_codebook_size * total_codebook_size,
608
+ "call compute_codebook_tables first");
609
+
610
+ std::vector<float> query_norms(n);
611
+ fvec_norms_L2sqr(query_norms.data(), x, d, n);
612
+
613
+ std::vector<float> query_cp(n * total_codebook_size);
614
+ {
615
+ FINTEGER ti = total_codebook_size, di = d, ni = n;
616
+ float zero = 0, one = 1;
617
+ sgemm_("Transposed",
618
+ "Not transposed",
619
+ &ti,
620
+ &ni,
621
+ &di,
622
+ &one,
623
+ codebooks.data(),
624
+ &di,
625
+ x,
626
+ &di,
627
+ &zero,
628
+ query_cp.data(),
629
+ &ti);
630
+ }
631
+
632
+ refine_beam_LUT(
633
+ n,
634
+ query_norms.data(),
635
+ query_cp.data(),
636
+ max_beam_size,
637
+ codes.data(),
638
+ distances.data());
639
+ }
357
640
  // pack only the first code of the beam (hence the ld_codes=M *
358
641
  // max_beam_size)
359
- pack_codes(n, codes.data(), codes_out, M * max_beam_size);
642
+ pack_codes(
643
+ n,
644
+ codes.data(),
645
+ codes_out,
646
+ M * max_beam_size,
647
+ norms.size() > 0 ? norms.data() : nullptr,
648
+ centroids);
360
649
  }
361
650
 
362
651
  void ResidualQuantizer::refine_beam(
@@ -445,4 +734,181 @@ void ResidualQuantizer::refine_beam(
445
734
  }
446
735
  }
447
736
 
737
+ /*******************************************************************
738
+ * Functions using the dot products between codebook entries
739
+ *******************************************************************/
740
+
741
+ void ResidualQuantizer::compute_codebook_tables() {
742
+ codebook_cross_products.resize(total_codebook_size * total_codebook_size);
743
+ cent_norms.resize(total_codebook_size);
744
+ // strictly speaking we could use ssyrk
745
+ {
746
+ FINTEGER ni = total_codebook_size;
747
+ FINTEGER di = d;
748
+ float zero = 0, one = 1;
749
+ sgemm_("Transposed",
750
+ "Not transposed",
751
+ &ni,
752
+ &ni,
753
+ &di,
754
+ &one,
755
+ codebooks.data(),
756
+ &di,
757
+ codebooks.data(),
758
+ &di,
759
+ &zero,
760
+ codebook_cross_products.data(),
761
+ &ni);
762
+ }
763
+ for (size_t i = 0; i < total_codebook_size; i++) {
764
+ cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
765
+ }
766
+ }
767
+
768
+ void beam_search_encode_step_tab(
769
+ size_t K,
770
+ size_t n,
771
+ size_t beam_size, // input sizes
772
+ const float* codebook_cross_norms, // size K * ldc
773
+ size_t ldc, // >= K
774
+ const uint64_t* codebook_offsets, // m
775
+ const float* query_cp, // size n * ldqc
776
+ size_t ldqc, // >= K
777
+ const float* cent_norms_i, // size K
778
+ size_t m,
779
+ const int32_t* codes, // n * beam_size * m
780
+ const float* distances, // n * beam_size
781
+ size_t new_beam_size,
782
+ int32_t* new_codes, // n * new_beam_size * (m + 1)
783
+ float* new_distances) // n * new_beam_size
784
+ {
785
+ FAISS_THROW_IF_NOT(ldc >= K);
786
+
787
+ #pragma omp parallel for if (n > 100)
788
+ for (int64_t i = 0; i < n; i++) {
789
+ std::vector<float> cent_distances(beam_size * K);
790
+ std::vector<float> cd_common(K);
791
+
792
+ const int32_t* codes_i = codes + i * m * beam_size;
793
+ const float* query_cp_i = query_cp + i * ldqc;
794
+ const float* distances_i = distances + i * beam_size;
795
+
796
+ for (size_t k = 0; k < K; k++) {
797
+ cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
798
+ }
799
+
800
+ for (size_t b = 0; b < beam_size; b++) {
801
+ std::vector<float> dp(K);
802
+
803
+ for (size_t m1 = 0; m1 < m; m1++) {
804
+ size_t c = codes_i[b * m + m1];
805
+ const float* cb =
806
+ &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
807
+ fvec_add(K, cb, dp.data(), dp.data());
808
+ }
809
+
810
+ for (size_t k = 0; k < K; k++) {
811
+ cent_distances[b * K + k] =
812
+ distances_i[b] + cd_common[k] + 2 * dp[k];
813
+ }
814
+ }
815
+
816
+ using C = CMax<float, int>;
817
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
818
+ float* new_distances_i = new_distances + i * new_beam_size;
819
+
820
+ const float* cent_distances_i = cent_distances.data();
821
+
822
+ // then we have to select the best results
823
+ for (int i = 0; i < new_beam_size; i++) {
824
+ new_distances_i[i] = C::neutral();
825
+ }
826
+ std::vector<int> perm(new_beam_size, -1);
827
+ heap_addn<C>(
828
+ new_beam_size,
829
+ new_distances_i,
830
+ perm.data(),
831
+ cent_distances_i,
832
+ nullptr,
833
+ beam_size * K);
834
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
835
+
836
+ for (int j = 0; j < new_beam_size; j++) {
837
+ int js = perm[j] / K;
838
+ int ls = perm[j] % K;
839
+ if (m > 0) {
840
+ memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
841
+ }
842
+ new_codes_i[m] = ls;
843
+ new_codes_i += m + 1;
844
+ }
845
+ }
846
+ }
847
+
848
+ void ResidualQuantizer::refine_beam_LUT(
849
+ size_t n,
850
+ const float* query_norms, // size n
851
+ const float* query_cp, //
852
+ int out_beam_size,
853
+ int32_t* out_codes,
854
+ float* out_distances) const {
855
+ int beam_size = 1;
856
+
857
+ std::vector<int32_t> codes;
858
+ std::vector<float> distances(query_norms, query_norms + n);
859
+ double t0 = getmillisecs();
860
+
861
+ for (int m = 0; m < M; m++) {
862
+ int K = 1 << nbits[m];
863
+
864
+ int new_beam_size = std::min(beam_size * K, out_beam_size);
865
+ std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
866
+ std::vector<float> new_distances(n * new_beam_size);
867
+
868
+ beam_search_encode_step_tab(
869
+ K,
870
+ n,
871
+ beam_size,
872
+ codebook_cross_products.data() + codebook_offsets[m],
873
+ total_codebook_size,
874
+ codebook_offsets.data(),
875
+ query_cp + codebook_offsets[m],
876
+ total_codebook_size,
877
+ cent_norms.data() + codebook_offsets[m],
878
+ m,
879
+ codes.data(),
880
+ distances.data(),
881
+ new_beam_size,
882
+ new_codes.data(),
883
+ new_distances.data());
884
+
885
+ codes.swap(new_codes);
886
+ distances.swap(new_distances);
887
+ beam_size = new_beam_size;
888
+
889
+ if (verbose) {
890
+ float sum_distances = 0;
891
+ for (int j = 0; j < distances.size(); j++) {
892
+ sum_distances += distances[j];
893
+ }
894
+ printf("[%.3f s] encode stage %d, %d bits, "
895
+ "total error %g, beam_size %d\n",
896
+ (getmillisecs() - t0) / 1000,
897
+ m,
898
+ int(nbits[m]),
899
+ sum_distances,
900
+ beam_size);
901
+ }
902
+ }
903
+
904
+ if (out_codes) {
905
+ memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
906
+ }
907
+ if (out_distances) {
908
+ memcpy(out_distances,
909
+ distances.data(),
910
+ distances.size() * sizeof(distances[0]));
911
+ }
912
+ }
913
+
448
914
  } // namespace faiss