faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -5,9 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #include <faiss/impl/FaissAssert.h>
11
8
  #include <faiss/impl/LocalSearchQuantizer.h>
12
9
 
13
10
  #include <cstddef>
@@ -18,6 +15,8 @@
18
15
 
19
16
  #include <algorithm>
20
17
 
18
+ #include <faiss/impl/AuxIndexStructures.h>
19
+ #include <faiss/impl/FaissAssert.h>
21
20
  #include <faiss/utils/distances.h>
22
21
  #include <faiss/utils/hamming.h> // BitstringWriter
23
22
  #include <faiss/utils/utils.h>
@@ -42,18 +41,6 @@ void sgetri_(
42
41
  FINTEGER* lwork,
43
42
  FINTEGER* info);
44
43
 
45
- // solves a system of linear equations
46
- void sgetrs_(
47
- const char* trans,
48
- FINTEGER* n,
49
- FINTEGER* nrhs,
50
- float* A,
51
- FINTEGER* lda,
52
- FINTEGER* ipiv,
53
- float* b,
54
- FINTEGER* ldb,
55
- FINTEGER* info);
56
-
57
44
  // general matrix multiplication
58
45
  int sgemm_(
59
46
  const char* transa,
@@ -69,26 +56,73 @@ int sgemm_(
69
56
  float* beta,
70
57
  float* c,
71
58
  FINTEGER* ldc);
59
+
60
+ // LU decomoposition of a general matrix
61
+ void dgetrf_(
62
+ FINTEGER* m,
63
+ FINTEGER* n,
64
+ double* a,
65
+ FINTEGER* lda,
66
+ FINTEGER* ipiv,
67
+ FINTEGER* info);
68
+
69
+ // generate inverse of a matrix given its LU decomposition
70
+ void dgetri_(
71
+ FINTEGER* n,
72
+ double* a,
73
+ FINTEGER* lda,
74
+ FINTEGER* ipiv,
75
+ double* work,
76
+ FINTEGER* lwork,
77
+ FINTEGER* info);
78
+
79
+ // general matrix multiplication
80
+ int dgemm_(
81
+ const char* transa,
82
+ const char* transb,
83
+ FINTEGER* m,
84
+ FINTEGER* n,
85
+ FINTEGER* k,
86
+ const double* alpha,
87
+ const double* a,
88
+ FINTEGER* lda,
89
+ const double* b,
90
+ FINTEGER* ldb,
91
+ double* beta,
92
+ double* c,
93
+ FINTEGER* ldc);
72
94
  }
73
95
 
74
96
  namespace {
75
97
 
98
+ void fmat_inverse(float* a, int n) {
99
+ int info;
100
+ int lwork = n * n;
101
+ std::vector<int> ipiv(n);
102
+ std::vector<float> workspace(lwork);
103
+
104
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
105
+ FAISS_THROW_IF_NOT(info == 0);
106
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
107
+ FAISS_THROW_IF_NOT(info == 0);
108
+ }
109
+
76
110
  // c and a and b can overlap
77
- void fvec_add(size_t d, const float* a, const float* b, float* c) {
111
+ void dfvec_add(size_t d, const double* a, const float* b, double* c) {
78
112
  for (size_t i = 0; i < d; i++) {
79
113
  c[i] = a[i] + b[i];
80
114
  }
81
115
  }
82
116
 
83
- void fmat_inverse(float* a, int n) {
117
+ void dmat_inverse(double* a, int n) {
84
118
  int info;
85
119
  int lwork = n * n;
86
120
  std::vector<int> ipiv(n);
87
- std::vector<float> workspace(lwork);
121
+ std::vector<double> workspace(lwork);
88
122
 
89
- sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
123
+ dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
124
  FAISS_THROW_IF_NOT(info == 0);
91
- sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
125
+ dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
126
  FAISS_THROW_IF_NOT(info == 0);
93
127
  }
94
128
 
@@ -107,21 +141,15 @@ void random_int32(
107
141
 
108
142
  namespace faiss {
109
143
 
110
- LSQTimer lsq_timer;
111
-
112
- LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
- FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
-
115
- this->d = d;
116
- this->M = M;
117
- this->nbits = std::vector<size_t>(M, nbits);
118
-
119
- // set derived values
120
- set_derived_values();
121
-
122
- is_trained = false;
123
- verbose = false;
144
+ lsq::LSQTimer lsq_timer;
145
+ using lsq::LSQTimerScope;
124
146
 
147
+ LocalSearchQuantizer::LocalSearchQuantizer(
148
+ size_t d,
149
+ size_t M,
150
+ size_t nbits,
151
+ Search_type_t search_type)
152
+ : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
125
153
  K = (1 << nbits);
126
154
 
127
155
  train_iters = 25;
@@ -138,15 +166,23 @@ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
138
166
 
139
167
  random_seed = 0x12345;
140
168
  std::srand(random_seed);
169
+
170
+ icm_encoder_factory = nullptr;
171
+ }
172
+
173
+ LocalSearchQuantizer::~LocalSearchQuantizer() {
174
+ delete icm_encoder_factory;
141
175
  }
142
176
 
177
+ LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
178
+
143
179
  void LocalSearchQuantizer::train(size_t n, const float* x) {
144
180
  FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
- FAISS_THROW_IF_NOT(nperts <= M);
181
+ nperts = std::min(nperts, M);
146
182
 
147
183
  lsq_timer.reset();
184
+ LSQTimerScope scope(&lsq_timer, "train");
148
185
  if (verbose) {
149
- lsq_timer.start("train");
150
186
  printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
187
  M,
152
188
  n,
@@ -209,7 +245,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
209
245
  }
210
246
 
211
247
  // refine codes
212
- icm_encode(x, codes.data(), n, train_ils_iters, gen);
248
+ icm_encode(codes.data(), x, n, train_ils_iters, gen);
213
249
 
214
250
  if (verbose) {
215
251
  float obj = evaluate(codes.data(), x, n);
@@ -217,25 +253,33 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
217
253
  }
218
254
  }
219
255
 
256
+ is_trained = true;
257
+ {
258
+ std::vector<float> x_recons(n * d);
259
+ std::vector<float> norms(n);
260
+ decode_unpacked(codes.data(), x_recons.data(), n);
261
+ fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
262
+
263
+ train_norm(n, norms.data());
264
+ }
265
+
220
266
  if (verbose) {
221
- lsq_timer.end("train");
222
267
  float obj = evaluate(codes.data(), x, n);
268
+ scope.finish();
223
269
  printf("After training: obj = %lf\n", obj);
224
270
 
225
271
  printf("Time statistic:\n");
226
- for (const auto& it : lsq_timer.duration) {
227
- printf("\t%s time: %lf s\n", it.first.data(), it.second);
272
+ for (const auto& it : lsq_timer.t) {
273
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
228
274
  }
229
275
  }
230
-
231
- is_trained = true;
232
276
  }
233
277
 
234
278
  void LocalSearchQuantizer::perturb_codebooks(
235
279
  float T,
236
280
  const std::vector<float>& stddev,
237
281
  std::mt19937& gen) {
238
- lsq_timer.start("perturb_codebooks");
282
+ LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
239
283
 
240
284
  std::vector<std::normal_distribution<float>> distribs;
241
285
  for (size_t i = 0; i < d; i++) {
@@ -249,32 +293,34 @@ void LocalSearchQuantizer::perturb_codebooks(
249
293
  }
250
294
  }
251
295
  }
252
-
253
- lsq_timer.end("perturb_codebooks");
254
296
  }
255
297
 
256
- void LocalSearchQuantizer::compute_codes(
298
+ void LocalSearchQuantizer::compute_codes_add_centroids(
257
299
  const float* x,
258
300
  uint8_t* codes_out,
259
- size_t n) const {
301
+ size_t n,
302
+ const float* centroids) const {
260
303
  FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
304
+
305
+ lsq_timer.reset();
306
+ LSQTimerScope scope(&lsq_timer, "encode");
261
307
  if (verbose) {
262
- lsq_timer.reset();
263
308
  printf("Encoding %zd vectors...\n", n);
264
- lsq_timer.start("encode");
265
309
  }
266
310
 
267
311
  std::vector<int32_t> codes(n * M);
268
312
  std::mt19937 gen(random_seed);
269
313
  random_int32(codes, 0, K - 1, gen);
270
314
 
271
- icm_encode(x, codes.data(), n, encode_ils_iters, gen);
272
- pack_codes(n, codes.data(), codes_out);
315
+ icm_encode(codes.data(), x, n, encode_ils_iters, gen);
316
+ pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids);
273
317
 
274
318
  if (verbose) {
275
- lsq_timer.end("encode");
276
- double t = lsq_timer.get("encode");
277
- printf("Time to encode %zd vectors: %lf s\n", n, t);
319
+ scope.finish();
320
+ printf("Time statistic:\n");
321
+ for (const auto& it : lsq_timer.t) {
322
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
323
+ }
278
324
  }
279
325
  }
280
326
 
@@ -298,73 +344,144 @@ void LocalSearchQuantizer::update_codebooks(
298
344
  const float* x,
299
345
  const int32_t* codes,
300
346
  size_t n) {
301
- lsq_timer.start("update_codebooks");
347
+ LSQTimerScope scope(&lsq_timer, "update_codebooks");
348
+
349
+ if (!update_codebooks_with_double) {
350
+ // allocate memory
351
+ // bb = B'B, bx = BX
352
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
353
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
354
+
355
+ // compute B'B
356
+ for (size_t i = 0; i < n; i++) {
357
+ for (size_t m = 0; m < M; m++) {
358
+ int32_t code1 = codes[i * M + m];
359
+ int32_t idx1 = m * K + code1;
360
+ bb[idx1 * M * K + idx1] += 1;
361
+
362
+ for (size_t m2 = m + 1; m2 < M; m2++) {
363
+ int32_t code2 = codes[i * M + m2];
364
+ int32_t idx2 = m2 * K + code2;
365
+ bb[idx1 * M * K + idx2] += 1;
366
+ bb[idx2 * M * K + idx1] += 1;
367
+ }
368
+ }
369
+ }
302
370
 
303
- // allocate memory
304
- // bb = B'B, bx = BX
305
- std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
- std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
371
+ // add a regularization term to B'B
372
+ for (int64_t i = 0; i < M * K; i++) {
373
+ bb[i * (M * K) + i] += lambd;
374
+ }
307
375
 
308
- // compute B'B
309
- for (size_t i = 0; i < n; i++) {
310
- for (size_t m = 0; m < M; m++) {
311
- int32_t code1 = codes[i * M + m];
312
- int32_t idx1 = m * K + code1;
313
- bb[idx1 * M * K + idx1] += 1;
314
-
315
- for (size_t m2 = m + 1; m2 < M; m2++) {
316
- int32_t code2 = codes[i * M + m2];
317
- int32_t idx2 = m2 * K + code2;
318
- bb[idx1 * M * K + idx2] += 1;
319
- bb[idx2 * M * K + idx1] += 1;
376
+ // compute (B'B)^(-1)
377
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
378
+
379
+ // compute BX
380
+ for (size_t i = 0; i < n; i++) {
381
+ for (size_t m = 0; m < M; m++) {
382
+ int32_t code = codes[i * M + m];
383
+ float* data = bx.data() + (m * K + code) * d;
384
+ fvec_add(d, data, x + i * d, data);
320
385
  }
321
386
  }
322
- }
323
387
 
324
- // add a regularization term to B'B
325
- for (int64_t i = 0; i < M * K; i++) {
326
- bb[i * (M * K) + i] += lambd;
327
- }
388
+ // compute C = (B'B)^(-1) @ BX
389
+ //
390
+ // NOTE: LAPACK use column major order
391
+ // out = alpha * op(A) * op(B) + beta * C
392
+ FINTEGER nrows_A = d;
393
+ FINTEGER ncols_A = M * K;
394
+
395
+ FINTEGER nrows_B = M * K;
396
+ FINTEGER ncols_B = M * K;
397
+
398
+ float alpha = 1.0f;
399
+ float beta = 0.0f;
400
+ sgemm_("Not Transposed",
401
+ "Not Transposed",
402
+ &nrows_A, // nrows of op(A)
403
+ &ncols_B, // ncols of op(B)
404
+ &ncols_A, // ncols of op(A)
405
+ &alpha,
406
+ bx.data(),
407
+ &nrows_A, // nrows of A
408
+ bb.data(),
409
+ &nrows_B, // nrows of B
410
+ &beta,
411
+ codebooks.data(),
412
+ &nrows_A); // nrows of output
413
+
414
+ } else {
415
+ // allocate memory
416
+ // bb = B'B, bx = BX
417
+ std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
418
+ std::vector<double> bx(M * K * d, 0.0f); // [M * K, d]
419
+
420
+ // compute B'B
421
+ for (size_t i = 0; i < n; i++) {
422
+ for (size_t m = 0; m < M; m++) {
423
+ int32_t code1 = codes[i * M + m];
424
+ int32_t idx1 = m * K + code1;
425
+ bb[idx1 * M * K + idx1] += 1;
426
+
427
+ for (size_t m2 = m + 1; m2 < M; m2++) {
428
+ int32_t code2 = codes[i * M + m2];
429
+ int32_t idx2 = m2 * K + code2;
430
+ bb[idx1 * M * K + idx2] += 1;
431
+ bb[idx2 * M * K + idx1] += 1;
432
+ }
433
+ }
434
+ }
328
435
 
329
- // compute (B'B)^(-1)
330
- fmat_inverse(bb.data(), M * K); // [M*K, M*K]
436
+ // add a regularization term to B'B
437
+ for (int64_t i = 0; i < M * K; i++) {
438
+ bb[i * (M * K) + i] += lambd;
439
+ }
331
440
 
332
- // compute BX
333
- for (size_t i = 0; i < n; i++) {
334
- for (size_t m = 0; m < M; m++) {
335
- int32_t code = codes[i * M + m];
336
- float* data = bx.data() + (m * K + code) * d;
337
- fvec_add(d, data, x + i * d, data);
441
+ // compute (B'B)^(-1)
442
+ dmat_inverse(bb.data(), M * K); // [M*K, M*K]
443
+
444
+ // compute BX
445
+ for (size_t i = 0; i < n; i++) {
446
+ for (size_t m = 0; m < M; m++) {
447
+ int32_t code = codes[i * M + m];
448
+ double* data = bx.data() + (m * K + code) * d;
449
+ dfvec_add(d, data, x + i * d, data);
450
+ }
338
451
  }
339
- }
340
452
 
341
- // compute C = (B'B)^(-1) @ BX
342
- //
343
- // NOTE: LAPACK use column major order
344
- // out = alpha * op(A) * op(B) + beta * C
345
- FINTEGER nrows_A = d;
346
- FINTEGER ncols_A = M * K;
347
-
348
- FINTEGER nrows_B = M * K;
349
- FINTEGER ncols_B = M * K;
350
-
351
- float alpha = 1.0f;
352
- float beta = 0.0f;
353
- sgemm_("Not Transposed",
354
- "Not Transposed",
355
- &nrows_A, // nrows of op(A)
356
- &ncols_B, // ncols of op(B)
357
- &ncols_A, // ncols of op(A)
358
- &alpha,
359
- bx.data(),
360
- &nrows_A, // nrows of A
361
- bb.data(),
362
- &nrows_B, // nrows of B
363
- &beta,
364
- codebooks.data(),
365
- &nrows_A); // nrows of output
366
-
367
- lsq_timer.end("update_codebooks");
453
+ // compute C = (B'B)^(-1) @ BX
454
+ //
455
+ // NOTE: LAPACK use column major order
456
+ // out = alpha * op(A) * op(B) + beta * C
457
+ FINTEGER nrows_A = d;
458
+ FINTEGER ncols_A = M * K;
459
+
460
+ FINTEGER nrows_B = M * K;
461
+ FINTEGER ncols_B = M * K;
462
+
463
+ std::vector<double> d_codebooks(M * K * d);
464
+
465
+ double alpha = 1.0f;
466
+ double beta = 0.0f;
467
+ dgemm_("Not Transposed",
468
+ "Not Transposed",
469
+ &nrows_A, // nrows of op(A)
470
+ &ncols_B, // ncols of op(B)
471
+ &ncols_A, // ncols of op(A)
472
+ &alpha,
473
+ bx.data(),
474
+ &nrows_A, // nrows of A
475
+ bb.data(),
476
+ &nrows_B, // nrows of B
477
+ &beta,
478
+ d_codebooks.data(),
479
+ &nrows_A); // nrows of output
480
+
481
+ for (size_t i = 0; i < M * K * d; i++) {
482
+ codebooks[i] = (float)d_codebooks[i];
483
+ }
484
+ }
368
485
  }
369
486
 
370
487
  /** encode using iterative conditional mode
@@ -386,15 +503,23 @@ void LocalSearchQuantizer::update_codebooks(
386
503
  * These two terms can be precomputed and store in a look up table.
387
504
  */
388
505
  void LocalSearchQuantizer::icm_encode(
389
- const float* x,
390
506
  int32_t* codes,
507
+ const float* x,
391
508
  size_t n,
392
509
  size_t ils_iters,
393
510
  std::mt19937& gen) const {
394
- lsq_timer.start("icm_encode");
511
+ LSQTimerScope scope(&lsq_timer, "icm_encode");
512
+
513
+ auto factory = icm_encoder_factory;
514
+ std::unique_ptr<lsq::IcmEncoder> icm_encoder;
515
+ if (factory == nullptr) {
516
+ icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
517
+ } else {
518
+ icm_encoder.reset(factory->get(this));
519
+ }
395
520
 
396
- std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
- compute_binary_terms(binaries.data());
521
+ // precompute binary terms for all chunks
522
+ icm_encoder->set_binary_term();
398
523
 
399
524
  const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
525
  for (size_t i = 0; i < n_chunks; i++) {
@@ -410,21 +535,20 @@ void LocalSearchQuantizer::icm_encode(
410
535
 
411
536
  const float* xi = x + i * chunk_size * d;
412
537
  int32_t* codesi = codes + i * chunk_size * M;
413
- icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
538
+ icm_encoder->verbose = (verbose && i == 0);
539
+ icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
414
540
  }
415
-
416
- lsq_timer.end("icm_encode");
417
541
  }
418
542
 
419
- void LocalSearchQuantizer::icm_encode_partial(
420
- size_t index,
421
- const float* x,
543
+ void LocalSearchQuantizer::icm_encode_impl(
422
544
  int32_t* codes,
423
- size_t n,
545
+ const float* x,
424
546
  const float* binaries,
547
+ std::mt19937& gen,
548
+ size_t n,
425
549
  size_t ils_iters,
426
- std::mt19937& gen) const {
427
- std::vector<float> unaries(n * M * K); // [n, M, K]
550
+ bool verbose) const {
551
+ std::vector<float> unaries(n * M * K); // [M, n, K]
428
552
  compute_unary_terms(x, unaries.data(), n);
429
553
 
430
554
  std::vector<int32_t> best_codes;
@@ -438,9 +562,7 @@ void LocalSearchQuantizer::icm_encode_partial(
438
562
  // add perturbation to codes
439
563
  perturb_codes(codes, n, gen);
440
564
 
441
- for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
- icm_encode_step(unaries.data(), binaries, codes, n);
443
- }
565
+ icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
444
566
 
445
567
  std::vector<float> icm_objs(n, 0.0f);
446
568
  evaluate(codes, x, n, icm_objs.data());
@@ -463,7 +585,7 @@ void LocalSearchQuantizer::icm_encode_partial(
463
585
 
464
586
  memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
587
 
466
- if (verbose && index == 0) {
588
+ if (verbose) {
467
589
  printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
590
  iter1,
469
591
  mean_obj,
@@ -474,61 +596,67 @@ void LocalSearchQuantizer::icm_encode_partial(
474
596
  }
475
597
 
476
598
  void LocalSearchQuantizer::icm_encode_step(
599
+ int32_t* codes,
477
600
  const float* unaries,
478
601
  const float* binaries,
479
- int32_t* codes,
480
- size_t n) const {
481
- // condition on the m-th subcode
482
- for (size_t m = 0; m < M; m++) {
483
- std::vector<float> objs(n * K);
484
- #pragma omp parallel for
485
- for (int64_t i = 0; i < n; i++) {
486
- auto u = unaries + i * (M * K) + m * K;
487
- memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
- }
602
+ size_t n,
603
+ size_t n_iters) const {
604
+ FAISS_THROW_IF_NOT(M != 0 && K != 0);
605
+ FAISS_THROW_IF_NOT(binaries != nullptr);
489
606
 
490
- // compute objective function by adding unary
491
- // and binary terms together
492
- for (size_t other_m = 0; other_m < M; other_m++) {
493
- if (other_m == m) {
494
- continue;
607
+ for (size_t iter = 0; iter < n_iters; iter++) {
608
+ // condition on the m-th subcode
609
+ for (size_t m = 0; m < M; m++) {
610
+ std::vector<float> objs(n * K);
611
+ #pragma omp parallel for
612
+ for (int64_t i = 0; i < n; i++) {
613
+ auto u = unaries + m * n * K + i * K;
614
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
495
615
  }
496
616
 
617
+ // compute objective function by adding unary
618
+ // and binary terms together
619
+ for (size_t other_m = 0; other_m < M; other_m++) {
620
+ if (other_m == m) {
621
+ continue;
622
+ }
623
+
497
624
  #pragma omp parallel for
498
- for (int64_t i = 0; i < n; i++) {
499
- for (int32_t code = 0; code < K; code++) {
500
- int32_t code2 = codes[i * M + other_m];
501
- size_t binary_idx =
502
- m * M * K * K + other_m * K * K + code * K + code2;
503
- // binaries[m, other_m, code, code2]
504
- objs[i * K + code] += binaries[binary_idx];
625
+ for (int64_t i = 0; i < n; i++) {
626
+ for (int32_t code = 0; code < K; code++) {
627
+ int32_t code2 = codes[i * M + other_m];
628
+ size_t binary_idx = m * M * K * K + other_m * K * K +
629
+ code * K + code2;
630
+ // binaries[m, other_m, code, code2]
631
+ objs[i * K + code] += binaries[binary_idx];
632
+ }
505
633
  }
506
634
  }
507
- }
508
635
 
509
- // find the optimal value of the m-th subcode
636
+ // find the optimal value of the m-th subcode
510
637
  #pragma omp parallel for
511
- for (int64_t i = 0; i < n; i++) {
512
- float best_obj = HUGE_VALF;
513
- int32_t best_code = 0;
514
- for (size_t code = 0; code < K; code++) {
515
- float obj = objs[i * K + code];
516
- if (obj < best_obj) {
517
- best_obj = obj;
518
- best_code = code;
638
+ for (int64_t i = 0; i < n; i++) {
639
+ float best_obj = HUGE_VALF;
640
+ int32_t best_code = 0;
641
+ for (size_t code = 0; code < K; code++) {
642
+ float obj = objs[i * K + code];
643
+ if (obj < best_obj) {
644
+ best_obj = obj;
645
+ best_code = code;
646
+ }
519
647
  }
648
+ codes[i * M + m] = best_code;
520
649
  }
521
- codes[i * M + m] = best_code;
522
- }
523
650
 
524
- } // loop M
651
+ } // loop M
652
+ }
525
653
  }
526
654
 
527
655
  void LocalSearchQuantizer::perturb_codes(
528
656
  int32_t* codes,
529
657
  size_t n,
530
658
  std::mt19937& gen) const {
531
- lsq_timer.start("perturb_codes");
659
+ LSQTimerScope scope(&lsq_timer, "perturb_codes");
532
660
 
533
661
  std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
662
  std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
@@ -539,12 +667,10 @@ void LocalSearchQuantizer::perturb_codes(
539
667
  codes[i * M + m] = k_distrib(gen);
540
668
  }
541
669
  }
542
-
543
- lsq_timer.end("perturb_codes");
544
670
  }
545
671
 
546
672
  void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
- lsq_timer.start("compute_binary_terms");
673
+ LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
548
674
 
549
675
  #pragma omp parallel for
550
676
  for (int64_t m12 = 0; m12 < M * M; m12++) {
@@ -562,52 +688,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
562
688
  }
563
689
  }
564
690
  }
565
-
566
- lsq_timer.end("compute_binary_terms");
567
691
  }
568
692
 
569
693
  void LocalSearchQuantizer::compute_unary_terms(
570
694
  const float* x,
571
- float* unaries,
695
+ float* unaries, // [M, n, K]
572
696
  size_t n) const {
573
- lsq_timer.start("compute_unary_terms");
697
+ LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
574
698
 
575
- // compute x * codebooks^T
699
+ // compute x * codebook^T for each codebook
576
700
  //
577
701
  // NOTE: LAPACK use column major order
578
702
  // out = alpha * op(A) * op(B) + beta * C
579
- FINTEGER nrows_A = M * K;
580
- FINTEGER ncols_A = d;
581
-
582
- FINTEGER nrows_B = d;
583
- FINTEGER ncols_B = n;
584
-
585
- float alpha = -2.0f;
586
- float beta = 0.0f;
587
- sgemm_("Transposed",
588
- "Not Transposed",
589
- &nrows_A, // nrows of op(A)
590
- &ncols_B, // ncols of op(B)
591
- &ncols_A, // ncols of op(A)
592
- &alpha,
593
- codebooks.data(),
594
- &ncols_A, // nrows of A
595
- x,
596
- &nrows_B, // nrows of B
597
- &beta,
598
- unaries,
599
- &nrows_A); // nrows of output
703
+
704
+ for (size_t m = 0; m < M; m++) {
705
+ FINTEGER nrows_A = K;
706
+ FINTEGER ncols_A = d;
707
+
708
+ FINTEGER nrows_B = d;
709
+ FINTEGER ncols_B = n;
710
+
711
+ float alpha = -2.0f;
712
+ float beta = 0.0f;
713
+ sgemm_("Transposed",
714
+ "Not Transposed",
715
+ &nrows_A, // nrows of op(A)
716
+ &ncols_B, // ncols of op(B)
717
+ &ncols_A, // ncols of op(A)
718
+ &alpha,
719
+ codebooks.data() + m * K * d,
720
+ &ncols_A, // nrows of A
721
+ x,
722
+ &nrows_B, // nrows of B
723
+ &beta,
724
+ unaries + m * n * K,
725
+ &nrows_A); // nrows of output
726
+ }
600
727
 
601
728
  std::vector<float> norms(M * K);
602
729
  fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
730
 
604
731
  #pragma omp parallel for
605
732
  for (int64_t i = 0; i < n; i++) {
606
- float* u = unaries + i * (M * K);
607
- fvec_add(M * K, u, norms.data(), u);
733
+ for (size_t m = 0; m < M; m++) {
734
+ float* u = unaries + m * n * K + i * K;
735
+ fvec_add(K, u, norms.data() + m * K, u);
736
+ }
608
737
  }
609
-
610
- lsq_timer.end("compute_unary_terms");
611
738
  }
612
739
 
613
740
  float LocalSearchQuantizer::evaluate(
@@ -615,7 +742,7 @@ float LocalSearchQuantizer::evaluate(
615
742
  const float* x,
616
743
  size_t n,
617
744
  float* objs) const {
618
- lsq_timer.start("evaluate");
745
+ LSQTimerScope scope(&lsq_timer, "evaluate");
619
746
 
620
747
  // decode
621
748
  std::vector<float> decoded_x(n * d, 0.0f);
@@ -631,7 +758,7 @@ float LocalSearchQuantizer::evaluate(
631
758
  fvec_add(d, decoded_i, c, decoded_i);
632
759
  }
633
760
 
634
- float err = fvec_L2sqr(x + i * d, decoded_i, d);
761
+ float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
635
762
  obj += err;
636
763
 
637
764
  if (objs) {
@@ -639,34 +766,68 @@ float LocalSearchQuantizer::evaluate(
639
766
  }
640
767
  }
641
768
 
642
- lsq_timer.end("evaluate");
643
-
644
769
  obj = obj / n;
645
770
  return obj;
646
771
  }
647
772
 
648
- double LSQTimer::get(const std::string& name) {
649
- return duration[name];
773
+ namespace lsq {
774
+
775
+ IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
776
+ : verbose(false), lsq(lsq) {}
777
+
778
+ void IcmEncoder::set_binary_term() {
779
+ auto M = lsq->M;
780
+ auto K = lsq->K;
781
+ binaries.resize(M * M * K * K);
782
+ lsq->compute_binary_terms(binaries.data());
650
783
  }
651
784
 
652
- void LSQTimer::start(const std::string& name) {
653
- FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
- started[name] = true;
655
- t0[name] = getmillisecs();
785
+ void IcmEncoder::encode(
786
+ int32_t* codes,
787
+ const float* x,
788
+ std::mt19937& gen,
789
+ size_t n,
790
+ size_t ils_iters) const {
791
+ lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
656
792
  }
657
793
 
658
- void LSQTimer::end(const std::string& name) {
659
- FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
- double t1 = getmillisecs();
661
- double sec = (t1 - t0[name]) / 1000;
662
- duration[name] += sec;
663
- started[name] = false;
794
+ double LSQTimer::get(const std::string& name) {
795
+ if (t.count(name) == 0) {
796
+ return 0.0;
797
+ } else {
798
+ return t[name];
799
+ }
800
+ }
801
+
802
+ void LSQTimer::add(const std::string& name, double delta) {
803
+ if (t.count(name) == 0) {
804
+ t[name] = delta;
805
+ } else {
806
+ t[name] += delta;
807
+ }
664
808
  }
665
809
 
666
810
  void LSQTimer::reset() {
667
- duration.clear();
668
- t0.clear();
669
- started.clear();
811
+ t.clear();
812
+ }
813
+
814
+ LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
815
+ : timer(timer), name(name), finished(false) {
816
+ t0 = getmillisecs();
670
817
  }
671
818
 
819
+ void LSQTimerScope::finish() {
820
+ if (!finished) {
821
+ auto delta = getmillisecs() - t0;
822
+ timer->add(name, delta);
823
+ finished = true;
824
+ }
825
+ }
826
+
827
+ LSQTimerScope::~LSQTimerScope() {
828
+ finish();
829
+ }
830
+
831
+ } // namespace lsq
832
+
672
833
  } // namespace faiss