faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,672 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LocalSearchQuantizer.h>
12
+
13
+ #include <cstddef>
14
+ #include <cstdio>
15
+ #include <cstring>
16
+ #include <memory>
17
+ #include <random>
18
+
19
+ #include <algorithm>
20
+
21
+ #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/hamming.h> // BitstringWriter
23
+ #include <faiss/utils/utils.h>
24
+
25
+ extern "C" {
26
+ // LU decomoposition of a general matrix
27
+ void sgetrf_(
28
+ FINTEGER* m,
29
+ FINTEGER* n,
30
+ float* a,
31
+ FINTEGER* lda,
32
+ FINTEGER* ipiv,
33
+ FINTEGER* info);
34
+
35
+ // generate inverse of a matrix given its LU decomposition
36
+ void sgetri_(
37
+ FINTEGER* n,
38
+ float* a,
39
+ FINTEGER* lda,
40
+ FINTEGER* ipiv,
41
+ float* work,
42
+ FINTEGER* lwork,
43
+ FINTEGER* info);
44
+
45
+ // solves a system of linear equations
46
+ void sgetrs_(
47
+ const char* trans,
48
+ FINTEGER* n,
49
+ FINTEGER* nrhs,
50
+ float* A,
51
+ FINTEGER* lda,
52
+ FINTEGER* ipiv,
53
+ float* b,
54
+ FINTEGER* ldb,
55
+ FINTEGER* info);
56
+
57
+ // general matrix multiplication
58
+ int sgemm_(
59
+ const char* transa,
60
+ const char* transb,
61
+ FINTEGER* m,
62
+ FINTEGER* n,
63
+ FINTEGER* k,
64
+ const float* alpha,
65
+ const float* a,
66
+ FINTEGER* lda,
67
+ const float* b,
68
+ FINTEGER* ldb,
69
+ float* beta,
70
+ float* c,
71
+ FINTEGER* ldc);
72
+ }
73
+
74
+ namespace {
75
+
76
+ // c and a and b can overlap
77
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
78
+ for (size_t i = 0; i < d; i++) {
79
+ c[i] = a[i] + b[i];
80
+ }
81
+ }
82
+
83
+ void fmat_inverse(float* a, int n) {
84
+ int info;
85
+ int lwork = n * n;
86
+ std::vector<int> ipiv(n);
87
+ std::vector<float> workspace(lwork);
88
+
89
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
+ FAISS_THROW_IF_NOT(info == 0);
91
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
+ FAISS_THROW_IF_NOT(info == 0);
93
+ }
94
+
95
+ void random_int32(
96
+ std::vector<int32_t>& x,
97
+ int32_t min,
98
+ int32_t max,
99
+ std::mt19937& gen) {
100
+ std::uniform_int_distribution<int32_t> distrib(min, max);
101
+ for (size_t i = 0; i < x.size(); i++) {
102
+ x[i] = distrib(gen);
103
+ }
104
+ }
105
+
106
+ } // anonymous namespace
107
+
108
+ namespace faiss {
109
+
110
+ LSQTimer lsq_timer;
111
+
112
+ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
+ FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
+
115
+ this->d = d;
116
+ this->M = M;
117
+ this->nbits = std::vector<size_t>(M, nbits);
118
+
119
+ // set derived values
120
+ set_derived_values();
121
+
122
+ is_trained = false;
123
+ verbose = false;
124
+
125
+ K = (1 << nbits);
126
+
127
+ train_iters = 25;
128
+ train_ils_iters = 8;
129
+ icm_iters = 4;
130
+
131
+ encode_ils_iters = 16;
132
+
133
+ p = 0.5f;
134
+ lambd = 1e-2f;
135
+
136
+ chunk_size = 10000;
137
+ nperts = 4;
138
+
139
+ random_seed = 0x12345;
140
+ std::srand(random_seed);
141
+ }
142
+
143
+ void LocalSearchQuantizer::train(size_t n, const float* x) {
144
+ FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
+ FAISS_THROW_IF_NOT(nperts <= M);
146
+
147
+ lsq_timer.reset();
148
+ if (verbose) {
149
+ lsq_timer.start("train");
150
+ printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
+ M,
152
+ n,
153
+ d);
154
+ }
155
+
156
+ // allocate memory for codebooks, size [M, K, d]
157
+ codebooks.resize(M * K * d);
158
+
159
+ // randomly intialize codes
160
+ std::mt19937 gen(random_seed);
161
+ std::vector<int32_t> codes(n * M); // [n, M]
162
+ random_int32(codes, 0, K - 1, gen);
163
+
164
+ // compute standard derivations of each dimension
165
+ std::vector<float> stddev(d, 0);
166
+
167
+ #pragma omp parallel for
168
+ for (int64_t i = 0; i < d; i++) {
169
+ float mean = 0;
170
+ for (size_t j = 0; j < n; j++) {
171
+ mean += x[j * d + i];
172
+ }
173
+ mean = mean / n;
174
+
175
+ float sum = 0;
176
+ for (size_t j = 0; j < n; j++) {
177
+ float xi = x[j * d + i] - mean;
178
+ sum += xi * xi;
179
+ }
180
+ stddev[i] = sqrtf(sum / n);
181
+ }
182
+
183
+ if (verbose) {
184
+ float obj = evaluate(codes.data(), x, n);
185
+ printf("Before training: obj = %lf\n", obj);
186
+ }
187
+
188
+ for (size_t i = 0; i < train_iters; i++) {
189
+ // 1. update codebooks given x and codes
190
+ // 2. add perturbation to codebooks (SR-D)
191
+ // 3. refine codes given x and codebooks using icm
192
+
193
+ // update codebooks
194
+ update_codebooks(x, codes.data(), n);
195
+
196
+ if (verbose) {
197
+ float obj = evaluate(codes.data(), x, n);
198
+ printf("iter %zd:\n", i);
199
+ printf("\tafter updating codebooks: obj = %lf\n", obj);
200
+ }
201
+
202
+ // SR-D: perturb codebooks
203
+ float T = pow((1.0f - (i + 1.0f) / train_iters), p);
204
+ perturb_codebooks(T, stddev, gen);
205
+
206
+ if (verbose) {
207
+ float obj = evaluate(codes.data(), x, n);
208
+ printf("\tafter perturbing codebooks: obj = %lf\n", obj);
209
+ }
210
+
211
+ // refine codes
212
+ icm_encode(x, codes.data(), n, train_ils_iters, gen);
213
+
214
+ if (verbose) {
215
+ float obj = evaluate(codes.data(), x, n);
216
+ printf("\tafter updating codes: obj = %lf\n", obj);
217
+ }
218
+ }
219
+
220
+ if (verbose) {
221
+ lsq_timer.end("train");
222
+ float obj = evaluate(codes.data(), x, n);
223
+ printf("After training: obj = %lf\n", obj);
224
+
225
+ printf("Time statistic:\n");
226
+ for (const auto& it : lsq_timer.duration) {
227
+ printf("\t%s time: %lf s\n", it.first.data(), it.second);
228
+ }
229
+ }
230
+
231
+ is_trained = true;
232
+ }
233
+
234
+ void LocalSearchQuantizer::perturb_codebooks(
235
+ float T,
236
+ const std::vector<float>& stddev,
237
+ std::mt19937& gen) {
238
+ lsq_timer.start("perturb_codebooks");
239
+
240
+ std::vector<std::normal_distribution<float>> distribs;
241
+ for (size_t i = 0; i < d; i++) {
242
+ distribs.emplace_back(0.0f, stddev[i]);
243
+ }
244
+
245
+ for (size_t m = 0; m < M; m++) {
246
+ for (size_t k = 0; k < K; k++) {
247
+ for (size_t i = 0; i < d; i++) {
248
+ codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M;
249
+ }
250
+ }
251
+ }
252
+
253
+ lsq_timer.end("perturb_codebooks");
254
+ }
255
+
256
+ void LocalSearchQuantizer::compute_codes(
257
+ const float* x,
258
+ uint8_t* codes_out,
259
+ size_t n) const {
260
+ FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
261
+ if (verbose) {
262
+ lsq_timer.reset();
263
+ printf("Encoding %zd vectors...\n", n);
264
+ lsq_timer.start("encode");
265
+ }
266
+
267
+ std::vector<int32_t> codes(n * M);
268
+ std::mt19937 gen(random_seed);
269
+ random_int32(codes, 0, K - 1, gen);
270
+
271
+ icm_encode(x, codes.data(), n, encode_ils_iters, gen);
272
+ pack_codes(n, codes.data(), codes_out);
273
+
274
+ if (verbose) {
275
+ lsq_timer.end("encode");
276
+ double t = lsq_timer.get("encode");
277
+ printf("Time to encode %zd vectors: %lf s\n", n, t);
278
+ }
279
+ }
280
+
281
+ /** update codebooks given x and codes
282
+ *
283
+ * Let B denote the sparse matrix of codes, size [n, M * K].
284
+ * Let C denote the codebooks, size [M * K, d].
285
+ * Let X denote the training vectors, size [n, d]
286
+ *
287
+ * objective function:
288
+ * L = (X - BC)^2
289
+ *
290
+ * To minimize L, we have:
291
+ * C = (B'B)^(-1)B'X
292
+ * where ' denote transposed
293
+ *
294
+ * Add a regularization term to make B'B inversible:
295
+ * C = (B'B + lambd * I)^(-1)B'X
296
+ */
297
+ void LocalSearchQuantizer::update_codebooks(
298
+ const float* x,
299
+ const int32_t* codes,
300
+ size_t n) {
301
+ lsq_timer.start("update_codebooks");
302
+
303
+ // allocate memory
304
+ // bb = B'B, bx = BX
305
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
307
+
308
+ // compute B'B
309
+ for (size_t i = 0; i < n; i++) {
310
+ for (size_t m = 0; m < M; m++) {
311
+ int32_t code1 = codes[i * M + m];
312
+ int32_t idx1 = m * K + code1;
313
+ bb[idx1 * M * K + idx1] += 1;
314
+
315
+ for (size_t m2 = m + 1; m2 < M; m2++) {
316
+ int32_t code2 = codes[i * M + m2];
317
+ int32_t idx2 = m2 * K + code2;
318
+ bb[idx1 * M * K + idx2] += 1;
319
+ bb[idx2 * M * K + idx1] += 1;
320
+ }
321
+ }
322
+ }
323
+
324
+ // add a regularization term to B'B
325
+ for (int64_t i = 0; i < M * K; i++) {
326
+ bb[i * (M * K) + i] += lambd;
327
+ }
328
+
329
+ // compute (B'B)^(-1)
330
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
331
+
332
+ // compute BX
333
+ for (size_t i = 0; i < n; i++) {
334
+ for (size_t m = 0; m < M; m++) {
335
+ int32_t code = codes[i * M + m];
336
+ float* data = bx.data() + (m * K + code) * d;
337
+ fvec_add(d, data, x + i * d, data);
338
+ }
339
+ }
340
+
341
+ // compute C = (B'B)^(-1) @ BX
342
+ //
343
+ // NOTE: LAPACK use column major order
344
+ // out = alpha * op(A) * op(B) + beta * C
345
+ FINTEGER nrows_A = d;
346
+ FINTEGER ncols_A = M * K;
347
+
348
+ FINTEGER nrows_B = M * K;
349
+ FINTEGER ncols_B = M * K;
350
+
351
+ float alpha = 1.0f;
352
+ float beta = 0.0f;
353
+ sgemm_("Not Transposed",
354
+ "Not Transposed",
355
+ &nrows_A, // nrows of op(A)
356
+ &ncols_B, // ncols of op(B)
357
+ &ncols_A, // ncols of op(A)
358
+ &alpha,
359
+ bx.data(),
360
+ &nrows_A, // nrows of A
361
+ bb.data(),
362
+ &nrows_B, // nrows of B
363
+ &beta,
364
+ codebooks.data(),
365
+ &nrows_A); // nrows of output
366
+
367
+ lsq_timer.end("update_codebooks");
368
+ }
369
+
370
+ /** encode using iterative conditional mode
371
+ *
372
+ * iterative conditional mode:
373
+ * For every subcode ci (i = 1, ..., M) of a vector, we fix the other
374
+ * subcodes cj (j != i) and then find the optimal value of ci such
375
+ * that minimizing the objective function.
376
+
377
+ * objective function:
378
+ * L = (X - \sum cj)^2, j = 1, ..., M
379
+ * L = X^2 - 2X * \sum cj + (\sum cj)^2
380
+ *
381
+ * X^2 is negligable since it is the same for all possible value
382
+ * k of the m-th subcode.
383
+ *
384
+ * 2X * \sum cj is the unary term
385
+ * (\sum cj)^2 is the binary term
386
+ * These two terms can be precomputed and store in a look up table.
387
+ */
388
+ void LocalSearchQuantizer::icm_encode(
389
+ const float* x,
390
+ int32_t* codes,
391
+ size_t n,
392
+ size_t ils_iters,
393
+ std::mt19937& gen) const {
394
+ lsq_timer.start("icm_encode");
395
+
396
+ std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
+ compute_binary_terms(binaries.data());
398
+
399
+ const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
+ for (size_t i = 0; i < n_chunks; i++) {
401
+ size_t ni = std::min(chunk_size, n - i * chunk_size);
402
+
403
+ if (verbose) {
404
+ printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n);
405
+ fflush(stdout);
406
+ if (i == n_chunks - 1 || i == 0) {
407
+ printf("\n");
408
+ }
409
+ }
410
+
411
+ const float* xi = x + i * chunk_size * d;
412
+ int32_t* codesi = codes + i * chunk_size * M;
413
+ icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
414
+ }
415
+
416
+ lsq_timer.end("icm_encode");
417
+ }
418
+
419
+ void LocalSearchQuantizer::icm_encode_partial(
420
+ size_t index,
421
+ const float* x,
422
+ int32_t* codes,
423
+ size_t n,
424
+ const float* binaries,
425
+ size_t ils_iters,
426
+ std::mt19937& gen) const {
427
+ std::vector<float> unaries(n * M * K); // [n, M, K]
428
+ compute_unary_terms(x, unaries.data(), n);
429
+
430
+ std::vector<int32_t> best_codes;
431
+ best_codes.assign(codes, codes + n * M);
432
+
433
+ std::vector<float> best_objs(n, 0.0f);
434
+ evaluate(codes, x, n, best_objs.data());
435
+
436
+ FAISS_THROW_IF_NOT(nperts <= M);
437
+ for (size_t iter1 = 0; iter1 < ils_iters; iter1++) {
438
+ // add perturbation to codes
439
+ perturb_codes(codes, n, gen);
440
+
441
+ for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
+ icm_encode_step(unaries.data(), binaries, codes, n);
443
+ }
444
+
445
+ std::vector<float> icm_objs(n, 0.0f);
446
+ evaluate(codes, x, n, icm_objs.data());
447
+ size_t n_betters = 0;
448
+ float mean_obj = 0.0f;
449
+
450
+ // select the best code for every vector xi
451
+ #pragma omp parallel for reduction(+ : n_betters, mean_obj)
452
+ for (int64_t i = 0; i < n; i++) {
453
+ if (icm_objs[i] < best_objs[i]) {
454
+ best_objs[i] = icm_objs[i];
455
+ memcpy(best_codes.data() + i * M,
456
+ codes + i * M,
457
+ sizeof(int32_t) * M);
458
+ n_betters += 1;
459
+ }
460
+ mean_obj += best_objs[i];
461
+ }
462
+ mean_obj /= n;
463
+
464
+ memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
+
466
+ if (verbose && index == 0) {
467
+ printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
+ iter1,
469
+ mean_obj,
470
+ n_betters,
471
+ n);
472
+ }
473
+ } // loop ils_iters
474
+ }
475
+
476
+ void LocalSearchQuantizer::icm_encode_step(
477
+ const float* unaries,
478
+ const float* binaries,
479
+ int32_t* codes,
480
+ size_t n) const {
481
+ // condition on the m-th subcode
482
+ for (size_t m = 0; m < M; m++) {
483
+ std::vector<float> objs(n * K);
484
+ #pragma omp parallel for
485
+ for (int64_t i = 0; i < n; i++) {
486
+ auto u = unaries + i * (M * K) + m * K;
487
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
+ }
489
+
490
+ // compute objective function by adding unary
491
+ // and binary terms together
492
+ for (size_t other_m = 0; other_m < M; other_m++) {
493
+ if (other_m == m) {
494
+ continue;
495
+ }
496
+
497
+ #pragma omp parallel for
498
+ for (int64_t i = 0; i < n; i++) {
499
+ for (int32_t code = 0; code < K; code++) {
500
+ int32_t code2 = codes[i * M + other_m];
501
+ size_t binary_idx =
502
+ m * M * K * K + other_m * K * K + code * K + code2;
503
+ // binaries[m, other_m, code, code2]
504
+ objs[i * K + code] += binaries[binary_idx];
505
+ }
506
+ }
507
+ }
508
+
509
+ // find the optimal value of the m-th subcode
510
+ #pragma omp parallel for
511
+ for (int64_t i = 0; i < n; i++) {
512
+ float best_obj = HUGE_VALF;
513
+ int32_t best_code = 0;
514
+ for (size_t code = 0; code < K; code++) {
515
+ float obj = objs[i * K + code];
516
+ if (obj < best_obj) {
517
+ best_obj = obj;
518
+ best_code = code;
519
+ }
520
+ }
521
+ codes[i * M + m] = best_code;
522
+ }
523
+
524
+ } // loop M
525
+ }
526
+
527
+ void LocalSearchQuantizer::perturb_codes(
528
+ int32_t* codes,
529
+ size_t n,
530
+ std::mt19937& gen) const {
531
+ lsq_timer.start("perturb_codes");
532
+
533
+ std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
+ std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
535
+
536
+ for (size_t i = 0; i < n; i++) {
537
+ for (size_t j = 0; j < nperts; j++) {
538
+ size_t m = m_distrib(gen);
539
+ codes[i * M + m] = k_distrib(gen);
540
+ }
541
+ }
542
+
543
+ lsq_timer.end("perturb_codes");
544
+ }
545
+
546
+ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
+ lsq_timer.start("compute_binary_terms");
548
+
549
+ #pragma omp parallel for
550
+ for (int64_t m12 = 0; m12 < M * M; m12++) {
551
+ size_t m1 = m12 / M;
552
+ size_t m2 = m12 % M;
553
+
554
+ for (size_t code1 = 0; code1 < K; code1++) {
555
+ for (size_t code2 = 0; code2 < K; code2++) {
556
+ const float* c1 = codebooks.data() + m1 * K * d + code1 * d;
557
+ const float* c2 = codebooks.data() + m2 * K * d + code2 * d;
558
+ float ip = fvec_inner_product(c1, c2, d);
559
+ // binaries[m1, m2, code1, code2] = ip * 2
560
+ binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] =
561
+ ip * 2;
562
+ }
563
+ }
564
+ }
565
+
566
+ lsq_timer.end("compute_binary_terms");
567
+ }
568
+
569
+ void LocalSearchQuantizer::compute_unary_terms(
570
+ const float* x,
571
+ float* unaries,
572
+ size_t n) const {
573
+ lsq_timer.start("compute_unary_terms");
574
+
575
+ // compute x * codebooks^T
576
+ //
577
+ // NOTE: LAPACK use column major order
578
+ // out = alpha * op(A) * op(B) + beta * C
579
+ FINTEGER nrows_A = M * K;
580
+ FINTEGER ncols_A = d;
581
+
582
+ FINTEGER nrows_B = d;
583
+ FINTEGER ncols_B = n;
584
+
585
+ float alpha = -2.0f;
586
+ float beta = 0.0f;
587
+ sgemm_("Transposed",
588
+ "Not Transposed",
589
+ &nrows_A, // nrows of op(A)
590
+ &ncols_B, // ncols of op(B)
591
+ &ncols_A, // ncols of op(A)
592
+ &alpha,
593
+ codebooks.data(),
594
+ &ncols_A, // nrows of A
595
+ x,
596
+ &nrows_B, // nrows of B
597
+ &beta,
598
+ unaries,
599
+ &nrows_A); // nrows of output
600
+
601
+ std::vector<float> norms(M * K);
602
+ fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
+
604
+ #pragma omp parallel for
605
+ for (int64_t i = 0; i < n; i++) {
606
+ float* u = unaries + i * (M * K);
607
+ fvec_add(M * K, u, norms.data(), u);
608
+ }
609
+
610
+ lsq_timer.end("compute_unary_terms");
611
+ }
612
+
613
+ float LocalSearchQuantizer::evaluate(
614
+ const int32_t* codes,
615
+ const float* x,
616
+ size_t n,
617
+ float* objs) const {
618
+ lsq_timer.start("evaluate");
619
+
620
+ // decode
621
+ std::vector<float> decoded_x(n * d, 0.0f);
622
+ float obj = 0.0f;
623
+
624
+ #pragma omp parallel for reduction(+ : obj)
625
+ for (int64_t i = 0; i < n; i++) {
626
+ const auto code = codes + i * M;
627
+ const auto decoded_i = decoded_x.data() + i * d;
628
+ for (size_t m = 0; m < M; m++) {
629
+ // c = codebooks[m, code[m]]
630
+ const auto c = codebooks.data() + m * K * d + code[m] * d;
631
+ fvec_add(d, decoded_i, c, decoded_i);
632
+ }
633
+
634
+ float err = fvec_L2sqr(x + i * d, decoded_i, d);
635
+ obj += err;
636
+
637
+ if (objs) {
638
+ objs[i] = err;
639
+ }
640
+ }
641
+
642
+ lsq_timer.end("evaluate");
643
+
644
+ obj = obj / n;
645
+ return obj;
646
+ }
647
+
648
+ double LSQTimer::get(const std::string& name) {
649
+ return duration[name];
650
+ }
651
+
652
+ void LSQTimer::start(const std::string& name) {
653
+ FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
+ started[name] = true;
655
+ t0[name] = getmillisecs();
656
+ }
657
+
658
+ void LSQTimer::end(const std::string& name) {
659
+ FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
+ double t1 = getmillisecs();
661
+ double sec = (t1 - t0[name]) / 1000;
662
+ duration[name] += sec;
663
+ started[name] = false;
664
+ }
665
+
666
+ void LSQTimer::reset() {
667
+ duration.clear();
668
+ t0.clear();
669
+ started.clear();
670
+ }
671
+
672
+ } // namespace faiss