faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp
@@ -0,0 +1,855 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/impl/LocalSearchQuantizer.h>
+
+#include <cstddef>
+#include <cstdio>
+#include <cstring>
+#include <memory>
+#include <random>
+
+#include <algorithm>
+
+#include <faiss/Clustering.h>
+#include <faiss/impl/AuxIndexStructures.h>
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/utils/distances.h>
+#include <faiss/utils/hamming.h> // BitstringWriter
+#include <faiss/utils/utils.h>
+
+extern "C" {
+// LU decomposition of a general matrix
+void sgetrf_(
+        FINTEGER* m,
+        FINTEGER* n,
+        float* a,
+        FINTEGER* lda,
+        FINTEGER* ipiv,
+        FINTEGER* info);
+
+// generate inverse of a matrix given its LU decomposition
+void sgetri_(
+        FINTEGER* n,
+        float* a,
+        FINTEGER* lda,
+        FINTEGER* ipiv,
+        float* work,
+        FINTEGER* lwork,
+        FINTEGER* info);
+
+// general matrix multiplication
+int sgemm_(
+        const char* transa,
+        const char* transb,
+        FINTEGER* m,
+        FINTEGER* n,
+        FINTEGER* k,
+        const float* alpha,
+        const float* a,
+        FINTEGER* lda,
+        const float* b,
+        FINTEGER* ldb,
+        float* beta,
+        float* c,
+        FINTEGER* ldc);
+
+// LU decomposition of a general matrix
+void dgetrf_(
+        FINTEGER* m,
+        FINTEGER* n,
+        double* a,
+        FINTEGER* lda,
+        FINTEGER* ipiv,
+        FINTEGER* info);
+
+// generate inverse of a matrix given its LU decomposition
+void dgetri_(
+        FINTEGER* n,
+        double* a,
+        FINTEGER* lda,
+        FINTEGER* ipiv,
+        double* work,
+        FINTEGER* lwork,
+        FINTEGER* info);
+
+// general matrix multiplication
+int dgemm_(
+        const char* transa,
+        const char* transb,
+        FINTEGER* m,
+        FINTEGER* n,
+        FINTEGER* k,
+        const double* alpha,
+        const double* a,
+        FINTEGER* lda,
+        const double* b,
+        FINTEGER* ldb,
+        double* beta,
+        double* c,
+        FINTEGER* ldc);
+}
+
+namespace {
+
+void fmat_inverse(float* a, int n) {
+    int info;
+    int lwork = n * n;
+    std::vector<int> ipiv(n);
+    std::vector<float> workspace(lwork);
+
+    sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
+    FAISS_THROW_IF_NOT(info == 0);
+    sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
+    FAISS_THROW_IF_NOT(info == 0);
+}
+
+// c and a and b can overlap
+void dfvec_add(size_t d, const double* a, const float* b, double* c) {
+    for (size_t i = 0; i < d; i++) {
+        c[i] = a[i] + b[i];
+    }
+}
+
+void dmat_inverse(double* a, int n) {
+    int info;
+    int lwork = n * n;
+    std::vector<int> ipiv(n);
+    std::vector<double> workspace(lwork);
+
+    dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
+    FAISS_THROW_IF_NOT(info == 0);
+    dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
+    FAISS_THROW_IF_NOT(info == 0);
+}
+
+void random_int32(
+        std::vector<int32_t>& x,
+        int32_t min,
+        int32_t max,
+        std::mt19937& gen) {
+    std::uniform_int_distribution<int32_t> distrib(min, max);
+    for (size_t i = 0; i < x.size(); i++) {
+        x[i] = distrib(gen);
+    }
+}
+
+} // anonymous namespace
+
+namespace faiss {
+
+lsq::LSQTimer lsq_timer;
+using lsq::LSQTimerScope;
+
+LocalSearchQuantizer::LocalSearchQuantizer(
+        size_t d,
+        size_t M,
+        size_t nbits,
+        Search_type_t search_type)
+        : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
+    is_trained = false;
+    verbose = false;
+
+    K = (1 << nbits);
+
+    train_iters = 25;
+    train_ils_iters = 8;
+    icm_iters = 4;
+
+    encode_ils_iters = 16;
+
+    p = 0.5f;
+    lambd = 1e-2f;
+
+    chunk_size = 10000;
+    nperts = 4;
+
+    random_seed = 0x12345;
+    std::srand(random_seed);
+
+    icm_encoder_factory = nullptr;
+}
+
+LocalSearchQuantizer::~LocalSearchQuantizer() {
+    delete icm_encoder_factory;
+}
+
+LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
+
+void LocalSearchQuantizer::train(size_t n, const float* x) {
+    FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
+    FAISS_THROW_IF_NOT(nperts <= M);
+
+    lsq_timer.reset();
+    LSQTimerScope scope(&lsq_timer, "train");
+    if (verbose) {
+        printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
+               M,
+               n,
+               d);
+    }
+
+    // allocate memory for codebooks, size [M, K, d]
+    codebooks.resize(M * K * d);
+
+    // randomly initialize codes
+    std::mt19937 gen(random_seed);
+    std::vector<int32_t> codes(n * M); // [n, M]
+    random_int32(codes, 0, K - 1, gen);
+
+    // compute the standard deviation of each dimension
+    std::vector<float> stddev(d, 0);
+
+#pragma omp parallel for
+    for (int64_t i = 0; i < d; i++) {
+        float mean = 0;
+        for (size_t j = 0; j < n; j++) {
+            mean += x[j * d + i];
+        }
+        mean = mean / n;
+
+        float sum = 0;
+        for (size_t j = 0; j < n; j++) {
+            float xi = x[j * d + i] - mean;
+            sum += xi * xi;
+        }
+        stddev[i] = sqrtf(sum / n);
+    }
+
+    if (verbose) {
+        float obj = evaluate(codes.data(), x, n);
+        printf("Before training: obj = %lf\n", obj);
+    }
+
+    for (size_t i = 0; i < train_iters; i++) {
+        // 1. update codebooks given x and codes
+        // 2. add perturbation to codebooks (SR-D)
+        // 3. refine codes given x and codebooks using icm
+
+        // update codebooks
+        update_codebooks(x, codes.data(), n);
+
+        if (verbose) {
+            float obj = evaluate(codes.data(), x, n);
+            printf("iter %zd:\n", i);
+            printf("\tafter updating codebooks: obj = %lf\n", obj);
+        }
+
+        // SR-D: perturb codebooks
+        float T = pow((1.0f - (i + 1.0f) / train_iters), p);
+        perturb_codebooks(T, stddev, gen);
+
+        if (verbose) {
+            float obj = evaluate(codes.data(), x, n);
+            printf("\tafter perturbing codebooks: obj = %lf\n", obj);
+        }
+
+        // refine codes
+        icm_encode(codes.data(), x, n, train_ils_iters, gen);
+
+        if (verbose) {
+            float obj = evaluate(codes.data(), x, n);
+            printf("\tafter updating codes: obj = %lf\n", obj);
+        }
+    }
+
+    is_trained = true;
+    {
+        std::vector<float> x_recons(n * d);
+        std::vector<float> norms(n);
+        decode_unpacked(codes.data(), x_recons.data(), n);
+        fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
+
+        norm_min = HUGE_VALF;
+        norm_max = -HUGE_VALF;
+        for (idx_t i = 0; i < n; i++) {
+            if (norms[i] < norm_min) {
+                norm_min = norms[i];
+            }
+            if (norms[i] > norm_max) {
+                norm_max = norms[i];
+            }
+        }
+
+        if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
+            size_t k = (1 << 8);
+            if (search_type == ST_norm_cqint4) {
+                k = (1 << 4);
+            }
+            Clustering1D clus(k);
+            clus.train_exact(n, norms.data());
+            qnorm.add(clus.k, clus.centroids.data());
+        }
+    }
+
+    if (verbose) {
+        float obj = evaluate(codes.data(), x, n);
+        scope.finish();
+        printf("After training: obj = %lf\n", obj);
+
+        printf("Time statistic:\n");
+        for (const auto& it : lsq_timer.t) {
+            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
+        }
+    }
+}
+
+void LocalSearchQuantizer::perturb_codebooks(
+        float T,
+        const std::vector<float>& stddev,
+        std::mt19937& gen) {
+    LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
+
+    std::vector<std::normal_distribution<float>> distribs;
+    for (size_t i = 0; i < d; i++) {
+        distribs.emplace_back(0.0f, stddev[i]);
+    }
+
+    for (size_t m = 0; m < M; m++) {
+        for (size_t k = 0; k < K; k++) {
+            for (size_t i = 0; i < d; i++) {
+                codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M;
+            }
+        }
+    }
+}
+
+void LocalSearchQuantizer::compute_codes(
+        const float* x,
+        uint8_t* codes_out,
+        size_t n) const {
+    FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
+
+    lsq_timer.reset();
+    LSQTimerScope scope(&lsq_timer, "encode");
+    if (verbose) {
+        printf("Encoding %zd vectors...\n", n);
+    }
+
+    std::vector<int32_t> codes(n * M);
+    std::mt19937 gen(random_seed);
+    random_int32(codes, 0, K - 1, gen);
+
+    icm_encode(codes.data(), x, n, encode_ils_iters, gen);
+    pack_codes(n, codes.data(), codes_out);
+
+    if (verbose) {
+        scope.finish();
+        printf("Time statistic:\n");
+        for (const auto& it : lsq_timer.t) {
+            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
+        }
+    }
+}
+
+/** update codebooks given x and codes
+ *
+ * Let B denote the sparse matrix of codes, size [n, M * K].
+ * Let C denote the codebooks, size [M * K, d].
+ * Let X denote the training vectors, size [n, d]
+ *
+ * objective function:
+ *     L = (X - BC)^2
+ *
+ * To minimize L, we have:
+ *     C = (B'B)^(-1)B'X
+ * where ' denotes the transpose
+ *
+ * Add a regularization term to make B'B invertible:
+ *     C = (B'B + lambd * I)^(-1)B'X
+ */
+void LocalSearchQuantizer::update_codebooks(
+        const float* x,
+        const int32_t* codes,
+        size_t n) {
+    LSQTimerScope scope(&lsq_timer, "update_codebooks");
+
+    if (!update_codebooks_with_double) {
+        // allocate memory
+        // bb = B'B, bx = BX
+        std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
+        std::vector<float> bx(M * K * d, 0.0f);     // [M * K, d]
+
+        // compute B'B
+        for (size_t i = 0; i < n; i++) {
+            for (size_t m = 0; m < M; m++) {
+                int32_t code1 = codes[i * M + m];
+                int32_t idx1 = m * K + code1;
+                bb[idx1 * M * K + idx1] += 1;
+
+                for (size_t m2 = m + 1; m2 < M; m2++) {
+                    int32_t code2 = codes[i * M + m2];
+                    int32_t idx2 = m2 * K + code2;
+                    bb[idx1 * M * K + idx2] += 1;
+                    bb[idx2 * M * K + idx1] += 1;
+                }
+            }
+        }
+
+        // add a regularization term to B'B
+        for (int64_t i = 0; i < M * K; i++) {
+            bb[i * (M * K) + i] += lambd;
+        }
+
+        // compute (B'B)^(-1)
+        fmat_inverse(bb.data(), M * K); // [M*K, M*K]
+
+        // compute BX
+        for (size_t i = 0; i < n; i++) {
+            for (size_t m = 0; m < M; m++) {
+                int32_t code = codes[i * M + m];
+                float* data = bx.data() + (m * K + code) * d;
+                fvec_add(d, data, x + i * d, data);
+            }
+        }
+
+        // compute C = (B'B)^(-1) @ BX
+        //
+        // NOTE: LAPACK uses column-major order
+        // out = alpha * op(A) * op(B) + beta * C
+        FINTEGER nrows_A = d;
+        FINTEGER ncols_A = M * K;
+
+        FINTEGER nrows_B = M * K;
+        FINTEGER ncols_B = M * K;
+
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        sgemm_("Not Transposed",
+               "Not Transposed",
+               &nrows_A, // nrows of op(A)
+               &ncols_B, // ncols of op(B)
+               &ncols_A, // ncols of op(A)
+               &alpha,
+               bx.data(),
+               &nrows_A, // nrows of A
+               bb.data(),
+               &nrows_B, // nrows of B
+               &beta,
+               codebooks.data(),
+               &nrows_A); // nrows of output
+
+    } else {
+        // allocate memory
+        // bb = B'B, bx = BX
+        std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
+        std::vector<double> bx(M * K * d, 0.0f);     // [M * K, d]
+
+        // compute B'B
+        for (size_t i = 0; i < n; i++) {
+            for (size_t m = 0; m < M; m++) {
+                int32_t code1 = codes[i * M + m];
+                int32_t idx1 = m * K + code1;
+                bb[idx1 * M * K + idx1] += 1;
+
+                for (size_t m2 = m + 1; m2 < M; m2++) {
+                    int32_t code2 = codes[i * M + m2];
+                    int32_t idx2 = m2 * K + code2;
+                    bb[idx1 * M * K + idx2] += 1;
+                    bb[idx2 * M * K + idx1] += 1;
+                }
+            }
+        }
+
+        // add a regularization term to B'B
+        for (int64_t i = 0; i < M * K; i++) {
+            bb[i * (M * K) + i] += lambd;
+        }
+
+        // compute (B'B)^(-1)
+        dmat_inverse(bb.data(), M * K); // [M*K, M*K]
+
+        // compute BX
+        for (size_t i = 0; i < n; i++) {
+            for (size_t m = 0; m < M; m++) {
+                int32_t code = codes[i * M + m];
+                double* data = bx.data() + (m * K + code) * d;
+                dfvec_add(d, data, x + i * d, data);
+            }
+        }
+
+        // compute C = (B'B)^(-1) @ BX
+        //
+        // NOTE: LAPACK uses column-major order
+        // out = alpha * op(A) * op(B) + beta * C
+        FINTEGER nrows_A = d;
+        FINTEGER ncols_A = M * K;
+
+        FINTEGER nrows_B = M * K;
+        FINTEGER ncols_B = M * K;
+
+        std::vector<double> d_codebooks(M * K * d);
+
+        double alpha = 1.0f;
+        double beta = 0.0f;
+        dgemm_("Not Transposed",
+               "Not Transposed",
+               &nrows_A, // nrows of op(A)
+               &ncols_B, // ncols of op(B)
+               &ncols_A, // ncols of op(A)
+               &alpha,
+               bx.data(),
+               &nrows_A, // nrows of A
+               bb.data(),
+               &nrows_B, // nrows of B
+               &beta,
+               d_codebooks.data(),
+               &nrows_A); // nrows of output
+
+        for (size_t i = 0; i < M * K * d; i++) {
+            codebooks[i] = (float)d_codebooks[i];
+        }
+    }
+}
+
+/** encode using iterative conditional mode
+ *
+ * iterative conditional mode:
+ *   For every subcode ci (i = 1, ..., M) of a vector, we fix the other
+ *   subcodes cj (j != i) and then find the value of ci that minimizes
+ *   the objective function.
+ *
+ * objective function:
+ *   L = (X - \sum cj)^2, j = 1, ..., M
+ *   L = X^2 - 2X * \sum cj + (\sum cj)^2
+ *
+ * X^2 is negligible since it is the same for every possible value
+ * k of the m-th subcode.
+ *
+ * 2X * \sum cj is the unary term
+ * (\sum cj)^2 is the binary term
+ * These two terms can be precomputed and stored in a lookup table.
+ */
+void LocalSearchQuantizer::icm_encode(
+        int32_t* codes,
+        const float* x,
+        size_t n,
+        size_t ils_iters,
+        std::mt19937& gen) const {
+    LSQTimerScope scope(&lsq_timer, "icm_encode");
+
+    auto factory = icm_encoder_factory;
+    std::unique_ptr<lsq::IcmEncoder> icm_encoder;
+    if (factory == nullptr) {
+        icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
+    } else {
+        icm_encoder.reset(factory->get(this));
+    }
+
+    // precompute binary terms for all chunks
+    icm_encoder->set_binary_term();
+
+    const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
+    for (size_t i = 0; i < n_chunks; i++) {
+        size_t ni = std::min(chunk_size, n - i * chunk_size);
+
+        if (verbose) {
+            printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n);
+            fflush(stdout);
+            if (i == n_chunks - 1 || i == 0) {
+                printf("\n");
+            }
+        }
+
+        const float* xi = x + i * chunk_size * d;
+        int32_t* codesi = codes + i * chunk_size * M;
+        icm_encoder->verbose = (verbose && i == 0);
+        icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
+    }
+}
+
+void LocalSearchQuantizer::icm_encode_impl(
+        int32_t* codes,
+        const float* x,
+        const float* binaries,
+        std::mt19937& gen,
+        size_t n,
+        size_t ils_iters,
+        bool verbose) const {
+    std::vector<float> unaries(n * M * K); // [M, n, K]
+    compute_unary_terms(x, unaries.data(), n);
+
+    std::vector<int32_t> best_codes;
+    best_codes.assign(codes, codes + n * M);
+
+    std::vector<float> best_objs(n, 0.0f);
+    evaluate(codes, x, n, best_objs.data());
+
+    FAISS_THROW_IF_NOT(nperts <= M);
+    for (size_t iter1 = 0; iter1 < ils_iters; iter1++) {
+        // add perturbation to codes
+        perturb_codes(codes, n, gen);
+
+        icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
+
+        std::vector<float> icm_objs(n, 0.0f);
+        evaluate(codes, x, n, icm_objs.data());
+        size_t n_betters = 0;
+        float mean_obj = 0.0f;
+
+        // select the best code for every vector xi
+#pragma omp parallel for reduction(+ : n_betters, mean_obj)
+        for (int64_t i = 0; i < n; i++) {
+            if (icm_objs[i] < best_objs[i]) {
+                best_objs[i] = icm_objs[i];
+                memcpy(best_codes.data() + i * M,
+                       codes + i * M,
+                       sizeof(int32_t) * M);
+                n_betters += 1;
+            }
+            mean_obj += best_objs[i];
+        }
+        mean_obj /= n;
+
+        memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
+
+        if (verbose) {
+            printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
+                   iter1,
+                   mean_obj,
+                   n_betters,
+                   n);
+        }
+    } // loop ils_iters
+}
+
+void LocalSearchQuantizer::icm_encode_step(
+        int32_t* codes,
+        const float* unaries,
+        const float* binaries,
+        size_t n,
+        size_t n_iters) const {
+    FAISS_THROW_IF_NOT(M != 0 && K != 0);
+    FAISS_THROW_IF_NOT(binaries != nullptr);
+
+    for (size_t iter = 0; iter < n_iters; iter++) {
+        // condition on the m-th subcode
+        for (size_t m = 0; m < M; m++) {
+            std::vector<float> objs(n * K);
+#pragma omp parallel for
+            for (int64_t i = 0; i < n; i++) {
+                auto u = unaries + m * n * K + i * K;
+                memcpy(objs.data() + i * K, u, sizeof(float) * K);
+            }
+
+            // compute objective function by adding unary
+            // and binary terms together
+            for (size_t other_m = 0; other_m < M; other_m++) {
+                if (other_m == m) {
+                    continue;
+                }
+
+#pragma omp parallel for
+                for (int64_t i = 0; i < n; i++) {
+                    for (int32_t code = 0; code < K; code++) {
+                        int32_t code2 = codes[i * M + other_m];
+                        size_t binary_idx = m * M * K * K + other_m * K * K +
+                                code * K + code2;
+                        // binaries[m, other_m, code, code2]
+                        objs[i * K + code] += binaries[binary_idx];
+                    }
+                }
+            }
+
+            // find the optimal value of the m-th subcode
+#pragma omp parallel for
+            for (int64_t i = 0; i < n; i++) {
+                float best_obj = HUGE_VALF;
+                int32_t best_code = 0;
+                for (size_t code = 0; code < K; code++) {
+                    float obj = objs[i * K + code];
+                    if (obj < best_obj) {
+                        best_obj = obj;
+                        best_code = code;
+                    }
+                }
+                codes[i * M + m] = best_code;
+            }
+
+        } // loop M
+    }
+}
+
+void LocalSearchQuantizer::perturb_codes(
+        int32_t* codes,
+        size_t n,
+        std::mt19937& gen) const {
+    LSQTimerScope scope(&lsq_timer, "perturb_codes");
+
+    std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
+    std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
+
+    for (size_t i = 0; i < n; i++) {
+        for (size_t j = 0; j < nperts; j++) {
+            size_t m = m_distrib(gen);
+            codes[i * M + m] = k_distrib(gen);
+        }
+    }
+}
+
+void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
+    LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
+
+#pragma omp parallel for
+    for (int64_t m12 = 0; m12 < M * M; m12++) {
+        size_t m1 = m12 / M;
+        size_t m2 = m12 % M;
+
+        for (size_t code1 = 0; code1 < K; code1++) {
+            for (size_t code2 = 0; code2 < K; code2++) {
+                const float* c1 = codebooks.data() + m1 * K * d + code1 * d;
+                const float* c2 = codebooks.data() + m2 * K * d + code2 * d;
+                float ip = fvec_inner_product(c1, c2, d);
+                // binaries[m1, m2, code1, code2] = ip * 2
+                binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] =
+                        ip * 2;
+            }
+        }
+    }
+}
+
+void LocalSearchQuantizer::compute_unary_terms(
+        const float* x,
+        float* unaries, // [M, n, K]
+        size_t n) const {
+    LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
+
+    // compute x * codebook^T for each codebook
+    //
+    // NOTE: LAPACK uses column-major order
+    // out = alpha * op(A) * op(B) + beta * C
+
+    for (size_t m = 0; m < M; m++) {
+        FINTEGER nrows_A = K;
+        FINTEGER ncols_A = d;
+
+        FINTEGER nrows_B = d;
+        FINTEGER ncols_B = n;
+
+        float alpha = -2.0f;
+        float beta = 0.0f;
+        sgemm_("Transposed",
+               "Not Transposed",
+               &nrows_A, // nrows of op(A)
+               &ncols_B, // ncols of op(B)
+               &ncols_A, // ncols of op(A)
+               &alpha,
+               codebooks.data() + m * K * d,
+               &ncols_A, // nrows of A
+               x,
+               &nrows_B, // nrows of B
+               &beta,
+               unaries + m * n * K,
+               &nrows_A); // nrows of output
+    }
+
+    std::vector<float> norms(M * K);
+    fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
+
+#pragma omp parallel for
+    for (int64_t i = 0; i < n; i++) {
+        for (size_t m = 0; m < M; m++) {
+            float* u = unaries + m * n * K + i * K;
+            fvec_add(K, u, norms.data() + m * K, u);
+        }
+    }
+}
+
+float LocalSearchQuantizer::evaluate(
+        const int32_t* codes,
+        const float* x,
+        size_t n,
+        float* objs) const {
+    LSQTimerScope scope(&lsq_timer, "evaluate");
+
+    // decode
+    std::vector<float> decoded_x(n * d, 0.0f);
+    float obj = 0.0f;
+
+#pragma omp parallel for reduction(+ : obj)
+    for (int64_t i = 0; i < n; i++) {
+        const auto code = codes + i * M;
+        const auto decoded_i = decoded_x.data() + i * d;
+        for (size_t m = 0; m < M; m++) {
+            // c = codebooks[m, code[m]]
+            const auto c = codebooks.data() + m * K * d + code[m] * d;
+            fvec_add(d, decoded_i, c, decoded_i);
+        }
+
+        float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
+        obj += err;
+
+        if (objs) {
+            objs[i] = err;
+        }
+    }
+
+    obj = obj / n;
+    return obj;
+}
+
+namespace lsq {
+
+IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
+        : verbose(false), lsq(lsq) {}
+
+void IcmEncoder::set_binary_term() {
+    auto M = lsq->M;
+    auto K = lsq->K;
+    binaries.resize(M * M * K * K);
+    lsq->compute_binary_terms(binaries.data());
+}
+
+void IcmEncoder::encode(
+        int32_t* codes,
+        const float* x,
+        std::mt19937& gen,
+        size_t n,
+        size_t ils_iters) const {
+    lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
+}
+
+double LSQTimer::get(const std::string& name) {
+    if (t.count(name) == 0) {
+        return 0.0;
+    } else {
+        return t[name];
+    }
+}
+
+void LSQTimer::add(const std::string& name, double delta) {
+    if (t.count(name) == 0) {
+        t[name] = delta;
+    } else {
+        t[name] += delta;
+    }
+}
+
+void LSQTimer::reset() {
+    t.clear();
+}
+
+LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
+        : timer(timer), name(name), finished(false) {
+    t0 = getmillisecs();
+}
+
+void LSQTimerScope::finish() {
+    if (!finished) {
+        auto delta = getmillisecs() - t0;
+        timer->add(name, delta);
+        finished = true;
+    }
+}
+
+LSQTimerScope::~LSQTimerScope() {
+    finish();
+}
+
+} // namespace lsq
+
+} // namespace faiss
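
The hunk above defines the full training and encoding pipeline of the new LocalSearchQuantizer. The snippet below is not part of the diff; it is a minimal, hypothetical usage sketch written against the API shown above (constructor, train, compute_codes). It assumes the AdditiveQuantizer base class exposes a code_size member giving the number of bytes per packed code, as in upstream faiss, and it must be compiled against the vendored faiss sources plus a BLAS/LAPACK library, since update_codebooks and compute_unary_terms call sgemm_/sgetrf_ directly.

    // Hypothetical example (not part of the released package): train an LSQ
    // quantizer on random vectors and encode them into packed codes.
    #include <faiss/impl/LocalSearchQuantizer.h>

    #include <cstddef>
    #include <cstdint>
    #include <random>
    #include <vector>

    int main() {
        size_t d = 32;    // vector dimensionality
        size_t M = 4;     // number of subcodes per vector
        size_t nbits = 8; // bits per subcode, i.e. K = 256 centroids per codebook
        size_t n = 10000; // number of training vectors

        // synthetic Gaussian training data
        std::mt19937 gen(123);
        std::normal_distribution<float> distrib;
        std::vector<float> x(n * d);
        for (auto& v : x) {
            v = distrib(gen);
        }

        faiss::LocalSearchQuantizer lsq(d, M, nbits);
        lsq.verbose = true;

        // runs the loop shown above: update_codebooks, SR-D perturbation, ICM
        lsq.train(n, x.data());

        // code_size (bytes per vector) is assumed to come from the
        // AdditiveQuantizer base class, which is not shown in this hunk
        std::vector<uint8_t> codes(n * lsq.code_size);
        lsq.compute_codes(x.data(), codes.data(), n);
        return 0;
    }

These same train/compute_codes entry points are presumably what the IndexAdditiveQuantizer and IndexIVFAdditiveQuantizer wrappers added elsewhere in this release delegate to.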