faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -5,9 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #include <faiss/impl/FaissAssert.h>
11
8
  #include <faiss/impl/LocalSearchQuantizer.h>
12
9
 
13
10
  #include <cstddef>
@@ -18,6 +15,8 @@
18
15
 
19
16
  #include <algorithm>
20
17
 
18
+ #include <faiss/impl/AuxIndexStructures.h>
19
+ #include <faiss/impl/FaissAssert.h>
21
20
  #include <faiss/utils/distances.h>
22
21
  #include <faiss/utils/hamming.h> // BitstringWriter
23
22
  #include <faiss/utils/utils.h>
@@ -42,18 +41,6 @@ void sgetri_(
42
41
  FINTEGER* lwork,
43
42
  FINTEGER* info);
44
43
 
45
- // solves a system of linear equations
46
- void sgetrs_(
47
- const char* trans,
48
- FINTEGER* n,
49
- FINTEGER* nrhs,
50
- float* A,
51
- FINTEGER* lda,
52
- FINTEGER* ipiv,
53
- float* b,
54
- FINTEGER* ldb,
55
- FINTEGER* info);
56
-
57
44
  // general matrix multiplication
58
45
  int sgemm_(
59
46
  const char* transa,
@@ -69,26 +56,73 @@ int sgemm_(
69
56
  float* beta,
70
57
  float* c,
71
58
  FINTEGER* ldc);
59
+
60
+ // LU decomoposition of a general matrix
61
+ void dgetrf_(
62
+ FINTEGER* m,
63
+ FINTEGER* n,
64
+ double* a,
65
+ FINTEGER* lda,
66
+ FINTEGER* ipiv,
67
+ FINTEGER* info);
68
+
69
+ // generate inverse of a matrix given its LU decomposition
70
+ void dgetri_(
71
+ FINTEGER* n,
72
+ double* a,
73
+ FINTEGER* lda,
74
+ FINTEGER* ipiv,
75
+ double* work,
76
+ FINTEGER* lwork,
77
+ FINTEGER* info);
78
+
79
+ // general matrix multiplication
80
+ int dgemm_(
81
+ const char* transa,
82
+ const char* transb,
83
+ FINTEGER* m,
84
+ FINTEGER* n,
85
+ FINTEGER* k,
86
+ const double* alpha,
87
+ const double* a,
88
+ FINTEGER* lda,
89
+ const double* b,
90
+ FINTEGER* ldb,
91
+ double* beta,
92
+ double* c,
93
+ FINTEGER* ldc);
72
94
  }
73
95
 
74
96
  namespace {
75
97
 
98
+ void fmat_inverse(float* a, int n) {
99
+ int info;
100
+ int lwork = n * n;
101
+ std::vector<int> ipiv(n);
102
+ std::vector<float> workspace(lwork);
103
+
104
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
105
+ FAISS_THROW_IF_NOT(info == 0);
106
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
107
+ FAISS_THROW_IF_NOT(info == 0);
108
+ }
109
+
76
110
  // c and a and b can overlap
77
- void fvec_add(size_t d, const float* a, const float* b, float* c) {
111
+ void dfvec_add(size_t d, const double* a, const float* b, double* c) {
78
112
  for (size_t i = 0; i < d; i++) {
79
113
  c[i] = a[i] + b[i];
80
114
  }
81
115
  }
82
116
 
83
- void fmat_inverse(float* a, int n) {
117
+ void dmat_inverse(double* a, int n) {
84
118
  int info;
85
119
  int lwork = n * n;
86
120
  std::vector<int> ipiv(n);
87
- std::vector<float> workspace(lwork);
121
+ std::vector<double> workspace(lwork);
88
122
 
89
- sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
123
+ dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
124
  FAISS_THROW_IF_NOT(info == 0);
91
- sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
125
+ dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
126
  FAISS_THROW_IF_NOT(info == 0);
93
127
  }
94
128
 
@@ -107,21 +141,15 @@ void random_int32(
107
141
 
108
142
  namespace faiss {
109
143
 
110
- LSQTimer lsq_timer;
111
-
112
- LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
- FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
-
115
- this->d = d;
116
- this->M = M;
117
- this->nbits = std::vector<size_t>(M, nbits);
118
-
119
- // set derived values
120
- set_derived_values();
121
-
122
- is_trained = false;
123
- verbose = false;
144
+ lsq::LSQTimer lsq_timer;
145
+ using lsq::LSQTimerScope;
124
146
 
147
+ LocalSearchQuantizer::LocalSearchQuantizer(
148
+ size_t d,
149
+ size_t M,
150
+ size_t nbits,
151
+ Search_type_t search_type)
152
+ : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
125
153
  K = (1 << nbits);
126
154
 
127
155
  train_iters = 25;
@@ -138,15 +166,23 @@ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
138
166
 
139
167
  random_seed = 0x12345;
140
168
  std::srand(random_seed);
169
+
170
+ icm_encoder_factory = nullptr;
171
+ }
172
+
173
+ LocalSearchQuantizer::~LocalSearchQuantizer() {
174
+ delete icm_encoder_factory;
141
175
  }
142
176
 
177
+ LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
178
+
143
179
  void LocalSearchQuantizer::train(size_t n, const float* x) {
144
180
  FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
- FAISS_THROW_IF_NOT(nperts <= M);
181
+ nperts = std::min(nperts, M);
146
182
 
147
183
  lsq_timer.reset();
184
+ LSQTimerScope scope(&lsq_timer, "train");
148
185
  if (verbose) {
149
- lsq_timer.start("train");
150
186
  printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
187
  M,
152
188
  n,
@@ -209,7 +245,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
209
245
  }
210
246
 
211
247
  // refine codes
212
- icm_encode(x, codes.data(), n, train_ils_iters, gen);
248
+ icm_encode(codes.data(), x, n, train_ils_iters, gen);
213
249
 
214
250
  if (verbose) {
215
251
  float obj = evaluate(codes.data(), x, n);
@@ -217,25 +253,33 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
217
253
  }
218
254
  }
219
255
 
256
+ is_trained = true;
257
+ {
258
+ std::vector<float> x_recons(n * d);
259
+ std::vector<float> norms(n);
260
+ decode_unpacked(codes.data(), x_recons.data(), n);
261
+ fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
262
+
263
+ train_norm(n, norms.data());
264
+ }
265
+
220
266
  if (verbose) {
221
- lsq_timer.end("train");
222
267
  float obj = evaluate(codes.data(), x, n);
268
+ scope.finish();
223
269
  printf("After training: obj = %lf\n", obj);
224
270
 
225
271
  printf("Time statistic:\n");
226
- for (const auto& it : lsq_timer.duration) {
227
- printf("\t%s time: %lf s\n", it.first.data(), it.second);
272
+ for (const auto& it : lsq_timer.t) {
273
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
228
274
  }
229
275
  }
230
-
231
- is_trained = true;
232
276
  }
233
277
 
234
278
  void LocalSearchQuantizer::perturb_codebooks(
235
279
  float T,
236
280
  const std::vector<float>& stddev,
237
281
  std::mt19937& gen) {
238
- lsq_timer.start("perturb_codebooks");
282
+ LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
239
283
 
240
284
  std::vector<std::normal_distribution<float>> distribs;
241
285
  for (size_t i = 0; i < d; i++) {
@@ -249,32 +293,34 @@ void LocalSearchQuantizer::perturb_codebooks(
249
293
  }
250
294
  }
251
295
  }
252
-
253
- lsq_timer.end("perturb_codebooks");
254
296
  }
255
297
 
256
- void LocalSearchQuantizer::compute_codes(
298
+ void LocalSearchQuantizer::compute_codes_add_centroids(
257
299
  const float* x,
258
300
  uint8_t* codes_out,
259
- size_t n) const {
301
+ size_t n,
302
+ const float* centroids) const {
260
303
  FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
304
+
305
+ lsq_timer.reset();
306
+ LSQTimerScope scope(&lsq_timer, "encode");
261
307
  if (verbose) {
262
- lsq_timer.reset();
263
308
  printf("Encoding %zd vectors...\n", n);
264
- lsq_timer.start("encode");
265
309
  }
266
310
 
267
311
  std::vector<int32_t> codes(n * M);
268
312
  std::mt19937 gen(random_seed);
269
313
  random_int32(codes, 0, K - 1, gen);
270
314
 
271
- icm_encode(x, codes.data(), n, encode_ils_iters, gen);
272
- pack_codes(n, codes.data(), codes_out);
315
+ icm_encode(codes.data(), x, n, encode_ils_iters, gen);
316
+ pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids);
273
317
 
274
318
  if (verbose) {
275
- lsq_timer.end("encode");
276
- double t = lsq_timer.get("encode");
277
- printf("Time to encode %zd vectors: %lf s\n", n, t);
319
+ scope.finish();
320
+ printf("Time statistic:\n");
321
+ for (const auto& it : lsq_timer.t) {
322
+ printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
323
+ }
278
324
  }
279
325
  }
280
326
 
@@ -298,73 +344,144 @@ void LocalSearchQuantizer::update_codebooks(
298
344
  const float* x,
299
345
  const int32_t* codes,
300
346
  size_t n) {
301
- lsq_timer.start("update_codebooks");
347
+ LSQTimerScope scope(&lsq_timer, "update_codebooks");
348
+
349
+ if (!update_codebooks_with_double) {
350
+ // allocate memory
351
+ // bb = B'B, bx = BX
352
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
353
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
354
+
355
+ // compute B'B
356
+ for (size_t i = 0; i < n; i++) {
357
+ for (size_t m = 0; m < M; m++) {
358
+ int32_t code1 = codes[i * M + m];
359
+ int32_t idx1 = m * K + code1;
360
+ bb[idx1 * M * K + idx1] += 1;
361
+
362
+ for (size_t m2 = m + 1; m2 < M; m2++) {
363
+ int32_t code2 = codes[i * M + m2];
364
+ int32_t idx2 = m2 * K + code2;
365
+ bb[idx1 * M * K + idx2] += 1;
366
+ bb[idx2 * M * K + idx1] += 1;
367
+ }
368
+ }
369
+ }
302
370
 
303
- // allocate memory
304
- // bb = B'B, bx = BX
305
- std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
- std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
371
+ // add a regularization term to B'B
372
+ for (int64_t i = 0; i < M * K; i++) {
373
+ bb[i * (M * K) + i] += lambd;
374
+ }
307
375
 
308
- // compute B'B
309
- for (size_t i = 0; i < n; i++) {
310
- for (size_t m = 0; m < M; m++) {
311
- int32_t code1 = codes[i * M + m];
312
- int32_t idx1 = m * K + code1;
313
- bb[idx1 * M * K + idx1] += 1;
314
-
315
- for (size_t m2 = m + 1; m2 < M; m2++) {
316
- int32_t code2 = codes[i * M + m2];
317
- int32_t idx2 = m2 * K + code2;
318
- bb[idx1 * M * K + idx2] += 1;
319
- bb[idx2 * M * K + idx1] += 1;
376
+ // compute (B'B)^(-1)
377
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
378
+
379
+ // compute BX
380
+ for (size_t i = 0; i < n; i++) {
381
+ for (size_t m = 0; m < M; m++) {
382
+ int32_t code = codes[i * M + m];
383
+ float* data = bx.data() + (m * K + code) * d;
384
+ fvec_add(d, data, x + i * d, data);
320
385
  }
321
386
  }
322
- }
323
387
 
324
- // add a regularization term to B'B
325
- for (int64_t i = 0; i < M * K; i++) {
326
- bb[i * (M * K) + i] += lambd;
327
- }
388
+ // compute C = (B'B)^(-1) @ BX
389
+ //
390
+ // NOTE: LAPACK use column major order
391
+ // out = alpha * op(A) * op(B) + beta * C
392
+ FINTEGER nrows_A = d;
393
+ FINTEGER ncols_A = M * K;
394
+
395
+ FINTEGER nrows_B = M * K;
396
+ FINTEGER ncols_B = M * K;
397
+
398
+ float alpha = 1.0f;
399
+ float beta = 0.0f;
400
+ sgemm_("Not Transposed",
401
+ "Not Transposed",
402
+ &nrows_A, // nrows of op(A)
403
+ &ncols_B, // ncols of op(B)
404
+ &ncols_A, // ncols of op(A)
405
+ &alpha,
406
+ bx.data(),
407
+ &nrows_A, // nrows of A
408
+ bb.data(),
409
+ &nrows_B, // nrows of B
410
+ &beta,
411
+ codebooks.data(),
412
+ &nrows_A); // nrows of output
413
+
414
+ } else {
415
+ // allocate memory
416
+ // bb = B'B, bx = BX
417
+ std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
418
+ std::vector<double> bx(M * K * d, 0.0f); // [M * K, d]
419
+
420
+ // compute B'B
421
+ for (size_t i = 0; i < n; i++) {
422
+ for (size_t m = 0; m < M; m++) {
423
+ int32_t code1 = codes[i * M + m];
424
+ int32_t idx1 = m * K + code1;
425
+ bb[idx1 * M * K + idx1] += 1;
426
+
427
+ for (size_t m2 = m + 1; m2 < M; m2++) {
428
+ int32_t code2 = codes[i * M + m2];
429
+ int32_t idx2 = m2 * K + code2;
430
+ bb[idx1 * M * K + idx2] += 1;
431
+ bb[idx2 * M * K + idx1] += 1;
432
+ }
433
+ }
434
+ }
328
435
 
329
- // compute (B'B)^(-1)
330
- fmat_inverse(bb.data(), M * K); // [M*K, M*K]
436
+ // add a regularization term to B'B
437
+ for (int64_t i = 0; i < M * K; i++) {
438
+ bb[i * (M * K) + i] += lambd;
439
+ }
331
440
 
332
- // compute BX
333
- for (size_t i = 0; i < n; i++) {
334
- for (size_t m = 0; m < M; m++) {
335
- int32_t code = codes[i * M + m];
336
- float* data = bx.data() + (m * K + code) * d;
337
- fvec_add(d, data, x + i * d, data);
441
+ // compute (B'B)^(-1)
442
+ dmat_inverse(bb.data(), M * K); // [M*K, M*K]
443
+
444
+ // compute BX
445
+ for (size_t i = 0; i < n; i++) {
446
+ for (size_t m = 0; m < M; m++) {
447
+ int32_t code = codes[i * M + m];
448
+ double* data = bx.data() + (m * K + code) * d;
449
+ dfvec_add(d, data, x + i * d, data);
450
+ }
338
451
  }
339
- }
340
452
 
341
- // compute C = (B'B)^(-1) @ BX
342
- //
343
- // NOTE: LAPACK use column major order
344
- // out = alpha * op(A) * op(B) + beta * C
345
- FINTEGER nrows_A = d;
346
- FINTEGER ncols_A = M * K;
347
-
348
- FINTEGER nrows_B = M * K;
349
- FINTEGER ncols_B = M * K;
350
-
351
- float alpha = 1.0f;
352
- float beta = 0.0f;
353
- sgemm_("Not Transposed",
354
- "Not Transposed",
355
- &nrows_A, // nrows of op(A)
356
- &ncols_B, // ncols of op(B)
357
- &ncols_A, // ncols of op(A)
358
- &alpha,
359
- bx.data(),
360
- &nrows_A, // nrows of A
361
- bb.data(),
362
- &nrows_B, // nrows of B
363
- &beta,
364
- codebooks.data(),
365
- &nrows_A); // nrows of output
366
-
367
- lsq_timer.end("update_codebooks");
453
+ // compute C = (B'B)^(-1) @ BX
454
+ //
455
+ // NOTE: LAPACK use column major order
456
+ // out = alpha * op(A) * op(B) + beta * C
457
+ FINTEGER nrows_A = d;
458
+ FINTEGER ncols_A = M * K;
459
+
460
+ FINTEGER nrows_B = M * K;
461
+ FINTEGER ncols_B = M * K;
462
+
463
+ std::vector<double> d_codebooks(M * K * d);
464
+
465
+ double alpha = 1.0f;
466
+ double beta = 0.0f;
467
+ dgemm_("Not Transposed",
468
+ "Not Transposed",
469
+ &nrows_A, // nrows of op(A)
470
+ &ncols_B, // ncols of op(B)
471
+ &ncols_A, // ncols of op(A)
472
+ &alpha,
473
+ bx.data(),
474
+ &nrows_A, // nrows of A
475
+ bb.data(),
476
+ &nrows_B, // nrows of B
477
+ &beta,
478
+ d_codebooks.data(),
479
+ &nrows_A); // nrows of output
480
+
481
+ for (size_t i = 0; i < M * K * d; i++) {
482
+ codebooks[i] = (float)d_codebooks[i];
483
+ }
484
+ }
368
485
  }
369
486
 
370
487
  /** encode using iterative conditional mode
@@ -386,15 +503,23 @@ void LocalSearchQuantizer::update_codebooks(
386
503
  * These two terms can be precomputed and store in a look up table.
387
504
  */
388
505
  void LocalSearchQuantizer::icm_encode(
389
- const float* x,
390
506
  int32_t* codes,
507
+ const float* x,
391
508
  size_t n,
392
509
  size_t ils_iters,
393
510
  std::mt19937& gen) const {
394
- lsq_timer.start("icm_encode");
511
+ LSQTimerScope scope(&lsq_timer, "icm_encode");
512
+
513
+ auto factory = icm_encoder_factory;
514
+ std::unique_ptr<lsq::IcmEncoder> icm_encoder;
515
+ if (factory == nullptr) {
516
+ icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
517
+ } else {
518
+ icm_encoder.reset(factory->get(this));
519
+ }
395
520
 
396
- std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
- compute_binary_terms(binaries.data());
521
+ // precompute binary terms for all chunks
522
+ icm_encoder->set_binary_term();
398
523
 
399
524
  const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
525
  for (size_t i = 0; i < n_chunks; i++) {
@@ -410,21 +535,20 @@ void LocalSearchQuantizer::icm_encode(
410
535
 
411
536
  const float* xi = x + i * chunk_size * d;
412
537
  int32_t* codesi = codes + i * chunk_size * M;
413
- icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
538
+ icm_encoder->verbose = (verbose && i == 0);
539
+ icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
414
540
  }
415
-
416
- lsq_timer.end("icm_encode");
417
541
  }
418
542
 
419
- void LocalSearchQuantizer::icm_encode_partial(
420
- size_t index,
421
- const float* x,
543
+ void LocalSearchQuantizer::icm_encode_impl(
422
544
  int32_t* codes,
423
- size_t n,
545
+ const float* x,
424
546
  const float* binaries,
547
+ std::mt19937& gen,
548
+ size_t n,
425
549
  size_t ils_iters,
426
- std::mt19937& gen) const {
427
- std::vector<float> unaries(n * M * K); // [n, M, K]
550
+ bool verbose) const {
551
+ std::vector<float> unaries(n * M * K); // [M, n, K]
428
552
  compute_unary_terms(x, unaries.data(), n);
429
553
 
430
554
  std::vector<int32_t> best_codes;
@@ -438,9 +562,7 @@ void LocalSearchQuantizer::icm_encode_partial(
438
562
  // add perturbation to codes
439
563
  perturb_codes(codes, n, gen);
440
564
 
441
- for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
- icm_encode_step(unaries.data(), binaries, codes, n);
443
- }
565
+ icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
444
566
 
445
567
  std::vector<float> icm_objs(n, 0.0f);
446
568
  evaluate(codes, x, n, icm_objs.data());
@@ -463,7 +585,7 @@ void LocalSearchQuantizer::icm_encode_partial(
463
585
 
464
586
  memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
587
 
466
- if (verbose && index == 0) {
588
+ if (verbose) {
467
589
  printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
590
  iter1,
469
591
  mean_obj,
@@ -474,61 +596,67 @@ void LocalSearchQuantizer::icm_encode_partial(
474
596
  }
475
597
 
476
598
  void LocalSearchQuantizer::icm_encode_step(
599
+ int32_t* codes,
477
600
  const float* unaries,
478
601
  const float* binaries,
479
- int32_t* codes,
480
- size_t n) const {
481
- // condition on the m-th subcode
482
- for (size_t m = 0; m < M; m++) {
483
- std::vector<float> objs(n * K);
484
- #pragma omp parallel for
485
- for (int64_t i = 0; i < n; i++) {
486
- auto u = unaries + i * (M * K) + m * K;
487
- memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
- }
602
+ size_t n,
603
+ size_t n_iters) const {
604
+ FAISS_THROW_IF_NOT(M != 0 && K != 0);
605
+ FAISS_THROW_IF_NOT(binaries != nullptr);
489
606
 
490
- // compute objective function by adding unary
491
- // and binary terms together
492
- for (size_t other_m = 0; other_m < M; other_m++) {
493
- if (other_m == m) {
494
- continue;
607
+ for (size_t iter = 0; iter < n_iters; iter++) {
608
+ // condition on the m-th subcode
609
+ for (size_t m = 0; m < M; m++) {
610
+ std::vector<float> objs(n * K);
611
+ #pragma omp parallel for
612
+ for (int64_t i = 0; i < n; i++) {
613
+ auto u = unaries + m * n * K + i * K;
614
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
495
615
  }
496
616
 
617
+ // compute objective function by adding unary
618
+ // and binary terms together
619
+ for (size_t other_m = 0; other_m < M; other_m++) {
620
+ if (other_m == m) {
621
+ continue;
622
+ }
623
+
497
624
  #pragma omp parallel for
498
- for (int64_t i = 0; i < n; i++) {
499
- for (int32_t code = 0; code < K; code++) {
500
- int32_t code2 = codes[i * M + other_m];
501
- size_t binary_idx =
502
- m * M * K * K + other_m * K * K + code * K + code2;
503
- // binaries[m, other_m, code, code2]
504
- objs[i * K + code] += binaries[binary_idx];
625
+ for (int64_t i = 0; i < n; i++) {
626
+ for (int32_t code = 0; code < K; code++) {
627
+ int32_t code2 = codes[i * M + other_m];
628
+ size_t binary_idx = m * M * K * K + other_m * K * K +
629
+ code * K + code2;
630
+ // binaries[m, other_m, code, code2]
631
+ objs[i * K + code] += binaries[binary_idx];
632
+ }
505
633
  }
506
634
  }
507
- }
508
635
 
509
- // find the optimal value of the m-th subcode
636
+ // find the optimal value of the m-th subcode
510
637
  #pragma omp parallel for
511
- for (int64_t i = 0; i < n; i++) {
512
- float best_obj = HUGE_VALF;
513
- int32_t best_code = 0;
514
- for (size_t code = 0; code < K; code++) {
515
- float obj = objs[i * K + code];
516
- if (obj < best_obj) {
517
- best_obj = obj;
518
- best_code = code;
638
+ for (int64_t i = 0; i < n; i++) {
639
+ float best_obj = HUGE_VALF;
640
+ int32_t best_code = 0;
641
+ for (size_t code = 0; code < K; code++) {
642
+ float obj = objs[i * K + code];
643
+ if (obj < best_obj) {
644
+ best_obj = obj;
645
+ best_code = code;
646
+ }
519
647
  }
648
+ codes[i * M + m] = best_code;
520
649
  }
521
- codes[i * M + m] = best_code;
522
- }
523
650
 
524
- } // loop M
651
+ } // loop M
652
+ }
525
653
  }
526
654
 
527
655
  void LocalSearchQuantizer::perturb_codes(
528
656
  int32_t* codes,
529
657
  size_t n,
530
658
  std::mt19937& gen) const {
531
- lsq_timer.start("perturb_codes");
659
+ LSQTimerScope scope(&lsq_timer, "perturb_codes");
532
660
 
533
661
  std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
662
  std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
@@ -539,12 +667,10 @@ void LocalSearchQuantizer::perturb_codes(
539
667
  codes[i * M + m] = k_distrib(gen);
540
668
  }
541
669
  }
542
-
543
- lsq_timer.end("perturb_codes");
544
670
  }
545
671
 
546
672
  void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
- lsq_timer.start("compute_binary_terms");
673
+ LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
548
674
 
549
675
  #pragma omp parallel for
550
676
  for (int64_t m12 = 0; m12 < M * M; m12++) {
@@ -562,52 +688,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
562
688
  }
563
689
  }
564
690
  }
565
-
566
- lsq_timer.end("compute_binary_terms");
567
691
  }
568
692
 
569
693
  void LocalSearchQuantizer::compute_unary_terms(
570
694
  const float* x,
571
- float* unaries,
695
+ float* unaries, // [M, n, K]
572
696
  size_t n) const {
573
- lsq_timer.start("compute_unary_terms");
697
+ LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
574
698
 
575
- // compute x * codebooks^T
699
+ // compute x * codebook^T for each codebook
576
700
  //
577
701
  // NOTE: LAPACK use column major order
578
702
  // out = alpha * op(A) * op(B) + beta * C
579
- FINTEGER nrows_A = M * K;
580
- FINTEGER ncols_A = d;
581
-
582
- FINTEGER nrows_B = d;
583
- FINTEGER ncols_B = n;
584
-
585
- float alpha = -2.0f;
586
- float beta = 0.0f;
587
- sgemm_("Transposed",
588
- "Not Transposed",
589
- &nrows_A, // nrows of op(A)
590
- &ncols_B, // ncols of op(B)
591
- &ncols_A, // ncols of op(A)
592
- &alpha,
593
- codebooks.data(),
594
- &ncols_A, // nrows of A
595
- x,
596
- &nrows_B, // nrows of B
597
- &beta,
598
- unaries,
599
- &nrows_A); // nrows of output
703
+
704
+ for (size_t m = 0; m < M; m++) {
705
+ FINTEGER nrows_A = K;
706
+ FINTEGER ncols_A = d;
707
+
708
+ FINTEGER nrows_B = d;
709
+ FINTEGER ncols_B = n;
710
+
711
+ float alpha = -2.0f;
712
+ float beta = 0.0f;
713
+ sgemm_("Transposed",
714
+ "Not Transposed",
715
+ &nrows_A, // nrows of op(A)
716
+ &ncols_B, // ncols of op(B)
717
+ &ncols_A, // ncols of op(A)
718
+ &alpha,
719
+ codebooks.data() + m * K * d,
720
+ &ncols_A, // nrows of A
721
+ x,
722
+ &nrows_B, // nrows of B
723
+ &beta,
724
+ unaries + m * n * K,
725
+ &nrows_A); // nrows of output
726
+ }
600
727
 
601
728
  std::vector<float> norms(M * K);
602
729
  fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
730
 
604
731
  #pragma omp parallel for
605
732
  for (int64_t i = 0; i < n; i++) {
606
- float* u = unaries + i * (M * K);
607
- fvec_add(M * K, u, norms.data(), u);
733
+ for (size_t m = 0; m < M; m++) {
734
+ float* u = unaries + m * n * K + i * K;
735
+ fvec_add(K, u, norms.data() + m * K, u);
736
+ }
608
737
  }
609
-
610
- lsq_timer.end("compute_unary_terms");
611
738
  }
612
739
 
613
740
  float LocalSearchQuantizer::evaluate(
@@ -615,7 +742,7 @@ float LocalSearchQuantizer::evaluate(
615
742
  const float* x,
616
743
  size_t n,
617
744
  float* objs) const {
618
- lsq_timer.start("evaluate");
745
+ LSQTimerScope scope(&lsq_timer, "evaluate");
619
746
 
620
747
  // decode
621
748
  std::vector<float> decoded_x(n * d, 0.0f);
@@ -631,7 +758,7 @@ float LocalSearchQuantizer::evaluate(
631
758
  fvec_add(d, decoded_i, c, decoded_i);
632
759
  }
633
760
 
634
- float err = fvec_L2sqr(x + i * d, decoded_i, d);
761
+ float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
635
762
  obj += err;
636
763
 
637
764
  if (objs) {
@@ -639,34 +766,68 @@ float LocalSearchQuantizer::evaluate(
639
766
  }
640
767
  }
641
768
 
642
- lsq_timer.end("evaluate");
643
-
644
769
  obj = obj / n;
645
770
  return obj;
646
771
  }
647
772
 
648
- double LSQTimer::get(const std::string& name) {
649
- return duration[name];
773
+ namespace lsq {
774
+
775
+ IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
776
+ : verbose(false), lsq(lsq) {}
777
+
778
+ void IcmEncoder::set_binary_term() {
779
+ auto M = lsq->M;
780
+ auto K = lsq->K;
781
+ binaries.resize(M * M * K * K);
782
+ lsq->compute_binary_terms(binaries.data());
650
783
  }
651
784
 
652
- void LSQTimer::start(const std::string& name) {
653
- FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
- started[name] = true;
655
- t0[name] = getmillisecs();
785
+ void IcmEncoder::encode(
786
+ int32_t* codes,
787
+ const float* x,
788
+ std::mt19937& gen,
789
+ size_t n,
790
+ size_t ils_iters) const {
791
+ lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
656
792
  }
657
793
 
658
- void LSQTimer::end(const std::string& name) {
659
- FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
- double t1 = getmillisecs();
661
- double sec = (t1 - t0[name]) / 1000;
662
- duration[name] += sec;
663
- started[name] = false;
794
+ double LSQTimer::get(const std::string& name) {
795
+ if (t.count(name) == 0) {
796
+ return 0.0;
797
+ } else {
798
+ return t[name];
799
+ }
800
+ }
801
+
802
+ void LSQTimer::add(const std::string& name, double delta) {
803
+ if (t.count(name) == 0) {
804
+ t[name] = delta;
805
+ } else {
806
+ t[name] += delta;
807
+ }
664
808
  }
665
809
 
666
810
  void LSQTimer::reset() {
667
- duration.clear();
668
- t0.clear();
669
- started.clear();
811
+ t.clear();
812
+ }
813
+
814
+ LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
815
+ : timer(timer), name(name), finished(false) {
816
+ t0 = getmillisecs();
670
817
  }
671
818
 
819
+ void LSQTimerScope::finish() {
820
+ if (!finished) {
821
+ auto delta = getmillisecs() - t0;
822
+ timer->add(name, delta);
823
+ finished = true;
824
+ }
825
+ }
826
+
827
+ LSQTimerScope::~LSQTimerScope() {
828
+ finish();
829
+ }
830
+
831
+ } // namespace lsq
832
+
672
833
  } // namespace faiss