faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,672 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LocalSearchQuantizer.h>
12
+
13
+ #include <cstddef>
14
+ #include <cstdio>
15
+ #include <cstring>
16
+ #include <memory>
17
+ #include <random>
18
+
19
+ #include <algorithm>
20
+
21
+ #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/hamming.h> // BitstringWriter
23
+ #include <faiss/utils/utils.h>
24
+
25
+ extern "C" {
26
+ // LU decomoposition of a general matrix
27
+ void sgetrf_(
28
+ FINTEGER* m,
29
+ FINTEGER* n,
30
+ float* a,
31
+ FINTEGER* lda,
32
+ FINTEGER* ipiv,
33
+ FINTEGER* info);
34
+
35
+ // generate inverse of a matrix given its LU decomposition
36
+ void sgetri_(
37
+ FINTEGER* n,
38
+ float* a,
39
+ FINTEGER* lda,
40
+ FINTEGER* ipiv,
41
+ float* work,
42
+ FINTEGER* lwork,
43
+ FINTEGER* info);
44
+
45
+ // solves a system of linear equations
46
+ void sgetrs_(
47
+ const char* trans,
48
+ FINTEGER* n,
49
+ FINTEGER* nrhs,
50
+ float* A,
51
+ FINTEGER* lda,
52
+ FINTEGER* ipiv,
53
+ float* b,
54
+ FINTEGER* ldb,
55
+ FINTEGER* info);
56
+
57
+ // general matrix multiplication
58
+ int sgemm_(
59
+ const char* transa,
60
+ const char* transb,
61
+ FINTEGER* m,
62
+ FINTEGER* n,
63
+ FINTEGER* k,
64
+ const float* alpha,
65
+ const float* a,
66
+ FINTEGER* lda,
67
+ const float* b,
68
+ FINTEGER* ldb,
69
+ float* beta,
70
+ float* c,
71
+ FINTEGER* ldc);
72
+ }
73
+
74
+ namespace {
75
+
76
+ // c and a and b can overlap
77
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
78
+ for (size_t i = 0; i < d; i++) {
79
+ c[i] = a[i] + b[i];
80
+ }
81
+ }
82
+
83
+ void fmat_inverse(float* a, int n) {
84
+ int info;
85
+ int lwork = n * n;
86
+ std::vector<int> ipiv(n);
87
+ std::vector<float> workspace(lwork);
88
+
89
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
+ FAISS_THROW_IF_NOT(info == 0);
91
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
+ FAISS_THROW_IF_NOT(info == 0);
93
+ }
94
+
95
+ void random_int32(
96
+ std::vector<int32_t>& x,
97
+ int32_t min,
98
+ int32_t max,
99
+ std::mt19937& gen) {
100
+ std::uniform_int_distribution<int32_t> distrib(min, max);
101
+ for (size_t i = 0; i < x.size(); i++) {
102
+ x[i] = distrib(gen);
103
+ }
104
+ }
105
+
106
+ } // anonymous namespace
107
+
108
+ namespace faiss {
109
+
110
+ LSQTimer lsq_timer;
111
+
112
+ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
+ FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
+
115
+ this->d = d;
116
+ this->M = M;
117
+ this->nbits = std::vector<size_t>(M, nbits);
118
+
119
+ // set derived values
120
+ set_derived_values();
121
+
122
+ is_trained = false;
123
+ verbose = false;
124
+
125
+ K = (1 << nbits);
126
+
127
+ train_iters = 25;
128
+ train_ils_iters = 8;
129
+ icm_iters = 4;
130
+
131
+ encode_ils_iters = 16;
132
+
133
+ p = 0.5f;
134
+ lambd = 1e-2f;
135
+
136
+ chunk_size = 10000;
137
+ nperts = 4;
138
+
139
+ random_seed = 0x12345;
140
+ std::srand(random_seed);
141
+ }
142
+
143
+ void LocalSearchQuantizer::train(size_t n, const float* x) {
144
+ FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
+ FAISS_THROW_IF_NOT(nperts <= M);
146
+
147
+ lsq_timer.reset();
148
+ if (verbose) {
149
+ lsq_timer.start("train");
150
+ printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
+ M,
152
+ n,
153
+ d);
154
+ }
155
+
156
+ // allocate memory for codebooks, size [M, K, d]
157
+ codebooks.resize(M * K * d);
158
+
159
+ // randomly intialize codes
160
+ std::mt19937 gen(random_seed);
161
+ std::vector<int32_t> codes(n * M); // [n, M]
162
+ random_int32(codes, 0, K - 1, gen);
163
+
164
+ // compute standard derivations of each dimension
165
+ std::vector<float> stddev(d, 0);
166
+
167
+ #pragma omp parallel for
168
+ for (int64_t i = 0; i < d; i++) {
169
+ float mean = 0;
170
+ for (size_t j = 0; j < n; j++) {
171
+ mean += x[j * d + i];
172
+ }
173
+ mean = mean / n;
174
+
175
+ float sum = 0;
176
+ for (size_t j = 0; j < n; j++) {
177
+ float xi = x[j * d + i] - mean;
178
+ sum += xi * xi;
179
+ }
180
+ stddev[i] = sqrtf(sum / n);
181
+ }
182
+
183
+ if (verbose) {
184
+ float obj = evaluate(codes.data(), x, n);
185
+ printf("Before training: obj = %lf\n", obj);
186
+ }
187
+
188
+ for (size_t i = 0; i < train_iters; i++) {
189
+ // 1. update codebooks given x and codes
190
+ // 2. add perturbation to codebooks (SR-D)
191
+ // 3. refine codes given x and codebooks using icm
192
+
193
+ // update codebooks
194
+ update_codebooks(x, codes.data(), n);
195
+
196
+ if (verbose) {
197
+ float obj = evaluate(codes.data(), x, n);
198
+ printf("iter %zd:\n", i);
199
+ printf("\tafter updating codebooks: obj = %lf\n", obj);
200
+ }
201
+
202
+ // SR-D: perturb codebooks
203
+ float T = pow((1.0f - (i + 1.0f) / train_iters), p);
204
+ perturb_codebooks(T, stddev, gen);
205
+
206
+ if (verbose) {
207
+ float obj = evaluate(codes.data(), x, n);
208
+ printf("\tafter perturbing codebooks: obj = %lf\n", obj);
209
+ }
210
+
211
+ // refine codes
212
+ icm_encode(x, codes.data(), n, train_ils_iters, gen);
213
+
214
+ if (verbose) {
215
+ float obj = evaluate(codes.data(), x, n);
216
+ printf("\tafter updating codes: obj = %lf\n", obj);
217
+ }
218
+ }
219
+
220
+ if (verbose) {
221
+ lsq_timer.end("train");
222
+ float obj = evaluate(codes.data(), x, n);
223
+ printf("After training: obj = %lf\n", obj);
224
+
225
+ printf("Time statistic:\n");
226
+ for (const auto& it : lsq_timer.duration) {
227
+ printf("\t%s time: %lf s\n", it.first.data(), it.second);
228
+ }
229
+ }
230
+
231
+ is_trained = true;
232
+ }
233
+
234
+ void LocalSearchQuantizer::perturb_codebooks(
235
+ float T,
236
+ const std::vector<float>& stddev,
237
+ std::mt19937& gen) {
238
+ lsq_timer.start("perturb_codebooks");
239
+
240
+ std::vector<std::normal_distribution<float>> distribs;
241
+ for (size_t i = 0; i < d; i++) {
242
+ distribs.emplace_back(0.0f, stddev[i]);
243
+ }
244
+
245
+ for (size_t m = 0; m < M; m++) {
246
+ for (size_t k = 0; k < K; k++) {
247
+ for (size_t i = 0; i < d; i++) {
248
+ codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M;
249
+ }
250
+ }
251
+ }
252
+
253
+ lsq_timer.end("perturb_codebooks");
254
+ }
255
+
256
+ void LocalSearchQuantizer::compute_codes(
257
+ const float* x,
258
+ uint8_t* codes_out,
259
+ size_t n) const {
260
+ FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
261
+ if (verbose) {
262
+ lsq_timer.reset();
263
+ printf("Encoding %zd vectors...\n", n);
264
+ lsq_timer.start("encode");
265
+ }
266
+
267
+ std::vector<int32_t> codes(n * M);
268
+ std::mt19937 gen(random_seed);
269
+ random_int32(codes, 0, K - 1, gen);
270
+
271
+ icm_encode(x, codes.data(), n, encode_ils_iters, gen);
272
+ pack_codes(n, codes.data(), codes_out);
273
+
274
+ if (verbose) {
275
+ lsq_timer.end("encode");
276
+ double t = lsq_timer.get("encode");
277
+ printf("Time to encode %zd vectors: %lf s\n", n, t);
278
+ }
279
+ }
280
+
281
+ /** update codebooks given x and codes
282
+ *
283
+ * Let B denote the sparse matrix of codes, size [n, M * K].
284
+ * Let C denote the codebooks, size [M * K, d].
285
+ * Let X denote the training vectors, size [n, d]
286
+ *
287
+ * objective function:
288
+ * L = (X - BC)^2
289
+ *
290
+ * To minimize L, we have:
291
+ * C = (B'B)^(-1)B'X
292
+ * where ' denote transposed
293
+ *
294
+ * Add a regularization term to make B'B inversible:
295
+ * C = (B'B + lambd * I)^(-1)B'X
296
+ */
297
+ void LocalSearchQuantizer::update_codebooks(
298
+ const float* x,
299
+ const int32_t* codes,
300
+ size_t n) {
301
+ lsq_timer.start("update_codebooks");
302
+
303
+ // allocate memory
304
+ // bb = B'B, bx = BX
305
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
307
+
308
+ // compute B'B
309
+ for (size_t i = 0; i < n; i++) {
310
+ for (size_t m = 0; m < M; m++) {
311
+ int32_t code1 = codes[i * M + m];
312
+ int32_t idx1 = m * K + code1;
313
+ bb[idx1 * M * K + idx1] += 1;
314
+
315
+ for (size_t m2 = m + 1; m2 < M; m2++) {
316
+ int32_t code2 = codes[i * M + m2];
317
+ int32_t idx2 = m2 * K + code2;
318
+ bb[idx1 * M * K + idx2] += 1;
319
+ bb[idx2 * M * K + idx1] += 1;
320
+ }
321
+ }
322
+ }
323
+
324
+ // add a regularization term to B'B
325
+ for (int64_t i = 0; i < M * K; i++) {
326
+ bb[i * (M * K) + i] += lambd;
327
+ }
328
+
329
+ // compute (B'B)^(-1)
330
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
331
+
332
+ // compute BX
333
+ for (size_t i = 0; i < n; i++) {
334
+ for (size_t m = 0; m < M; m++) {
335
+ int32_t code = codes[i * M + m];
336
+ float* data = bx.data() + (m * K + code) * d;
337
+ fvec_add(d, data, x + i * d, data);
338
+ }
339
+ }
340
+
341
+ // compute C = (B'B)^(-1) @ BX
342
+ //
343
+ // NOTE: LAPACK use column major order
344
+ // out = alpha * op(A) * op(B) + beta * C
345
+ FINTEGER nrows_A = d;
346
+ FINTEGER ncols_A = M * K;
347
+
348
+ FINTEGER nrows_B = M * K;
349
+ FINTEGER ncols_B = M * K;
350
+
351
+ float alpha = 1.0f;
352
+ float beta = 0.0f;
353
+ sgemm_("Not Transposed",
354
+ "Not Transposed",
355
+ &nrows_A, // nrows of op(A)
356
+ &ncols_B, // ncols of op(B)
357
+ &ncols_A, // ncols of op(A)
358
+ &alpha,
359
+ bx.data(),
360
+ &nrows_A, // nrows of A
361
+ bb.data(),
362
+ &nrows_B, // nrows of B
363
+ &beta,
364
+ codebooks.data(),
365
+ &nrows_A); // nrows of output
366
+
367
+ lsq_timer.end("update_codebooks");
368
+ }
369
+
370
+ /** encode using iterative conditional mode
371
+ *
372
+ * iterative conditional mode:
373
+ * For every subcode ci (i = 1, ..., M) of a vector, we fix the other
374
+ * subcodes cj (j != i) and then find the optimal value of ci such
375
+ * that minimizing the objective function.
376
+
377
+ * objective function:
378
+ * L = (X - \sum cj)^2, j = 1, ..., M
379
+ * L = X^2 - 2X * \sum cj + (\sum cj)^2
380
+ *
381
+ * X^2 is negligable since it is the same for all possible value
382
+ * k of the m-th subcode.
383
+ *
384
+ * 2X * \sum cj is the unary term
385
+ * (\sum cj)^2 is the binary term
386
+ * These two terms can be precomputed and store in a look up table.
387
+ */
388
+ void LocalSearchQuantizer::icm_encode(
389
+ const float* x,
390
+ int32_t* codes,
391
+ size_t n,
392
+ size_t ils_iters,
393
+ std::mt19937& gen) const {
394
+ lsq_timer.start("icm_encode");
395
+
396
+ std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
+ compute_binary_terms(binaries.data());
398
+
399
+ const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
+ for (size_t i = 0; i < n_chunks; i++) {
401
+ size_t ni = std::min(chunk_size, n - i * chunk_size);
402
+
403
+ if (verbose) {
404
+ printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n);
405
+ fflush(stdout);
406
+ if (i == n_chunks - 1 || i == 0) {
407
+ printf("\n");
408
+ }
409
+ }
410
+
411
+ const float* xi = x + i * chunk_size * d;
412
+ int32_t* codesi = codes + i * chunk_size * M;
413
+ icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
414
+ }
415
+
416
+ lsq_timer.end("icm_encode");
417
+ }
418
+
419
+ void LocalSearchQuantizer::icm_encode_partial(
420
+ size_t index,
421
+ const float* x,
422
+ int32_t* codes,
423
+ size_t n,
424
+ const float* binaries,
425
+ size_t ils_iters,
426
+ std::mt19937& gen) const {
427
+ std::vector<float> unaries(n * M * K); // [n, M, K]
428
+ compute_unary_terms(x, unaries.data(), n);
429
+
430
+ std::vector<int32_t> best_codes;
431
+ best_codes.assign(codes, codes + n * M);
432
+
433
+ std::vector<float> best_objs(n, 0.0f);
434
+ evaluate(codes, x, n, best_objs.data());
435
+
436
+ FAISS_THROW_IF_NOT(nperts <= M);
437
+ for (size_t iter1 = 0; iter1 < ils_iters; iter1++) {
438
+ // add perturbation to codes
439
+ perturb_codes(codes, n, gen);
440
+
441
+ for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
+ icm_encode_step(unaries.data(), binaries, codes, n);
443
+ }
444
+
445
+ std::vector<float> icm_objs(n, 0.0f);
446
+ evaluate(codes, x, n, icm_objs.data());
447
+ size_t n_betters = 0;
448
+ float mean_obj = 0.0f;
449
+
450
+ // select the best code for every vector xi
451
+ #pragma omp parallel for reduction(+ : n_betters, mean_obj)
452
+ for (int64_t i = 0; i < n; i++) {
453
+ if (icm_objs[i] < best_objs[i]) {
454
+ best_objs[i] = icm_objs[i];
455
+ memcpy(best_codes.data() + i * M,
456
+ codes + i * M,
457
+ sizeof(int32_t) * M);
458
+ n_betters += 1;
459
+ }
460
+ mean_obj += best_objs[i];
461
+ }
462
+ mean_obj /= n;
463
+
464
+ memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
+
466
+ if (verbose && index == 0) {
467
+ printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
+ iter1,
469
+ mean_obj,
470
+ n_betters,
471
+ n);
472
+ }
473
+ } // loop ils_iters
474
+ }
475
+
476
+ void LocalSearchQuantizer::icm_encode_step(
477
+ const float* unaries,
478
+ const float* binaries,
479
+ int32_t* codes,
480
+ size_t n) const {
481
+ // condition on the m-th subcode
482
+ for (size_t m = 0; m < M; m++) {
483
+ std::vector<float> objs(n * K);
484
+ #pragma omp parallel for
485
+ for (int64_t i = 0; i < n; i++) {
486
+ auto u = unaries + i * (M * K) + m * K;
487
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
+ }
489
+
490
+ // compute objective function by adding unary
491
+ // and binary terms together
492
+ for (size_t other_m = 0; other_m < M; other_m++) {
493
+ if (other_m == m) {
494
+ continue;
495
+ }
496
+
497
+ #pragma omp parallel for
498
+ for (int64_t i = 0; i < n; i++) {
499
+ for (int32_t code = 0; code < K; code++) {
500
+ int32_t code2 = codes[i * M + other_m];
501
+ size_t binary_idx =
502
+ m * M * K * K + other_m * K * K + code * K + code2;
503
+ // binaries[m, other_m, code, code2]
504
+ objs[i * K + code] += binaries[binary_idx];
505
+ }
506
+ }
507
+ }
508
+
509
+ // find the optimal value of the m-th subcode
510
+ #pragma omp parallel for
511
+ for (int64_t i = 0; i < n; i++) {
512
+ float best_obj = HUGE_VALF;
513
+ int32_t best_code = 0;
514
+ for (size_t code = 0; code < K; code++) {
515
+ float obj = objs[i * K + code];
516
+ if (obj < best_obj) {
517
+ best_obj = obj;
518
+ best_code = code;
519
+ }
520
+ }
521
+ codes[i * M + m] = best_code;
522
+ }
523
+
524
+ } // loop M
525
+ }
526
+
527
+ void LocalSearchQuantizer::perturb_codes(
528
+ int32_t* codes,
529
+ size_t n,
530
+ std::mt19937& gen) const {
531
+ lsq_timer.start("perturb_codes");
532
+
533
+ std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
+ std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
535
+
536
+ for (size_t i = 0; i < n; i++) {
537
+ for (size_t j = 0; j < nperts; j++) {
538
+ size_t m = m_distrib(gen);
539
+ codes[i * M + m] = k_distrib(gen);
540
+ }
541
+ }
542
+
543
+ lsq_timer.end("perturb_codes");
544
+ }
545
+
546
+ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
+ lsq_timer.start("compute_binary_terms");
548
+
549
+ #pragma omp parallel for
550
+ for (int64_t m12 = 0; m12 < M * M; m12++) {
551
+ size_t m1 = m12 / M;
552
+ size_t m2 = m12 % M;
553
+
554
+ for (size_t code1 = 0; code1 < K; code1++) {
555
+ for (size_t code2 = 0; code2 < K; code2++) {
556
+ const float* c1 = codebooks.data() + m1 * K * d + code1 * d;
557
+ const float* c2 = codebooks.data() + m2 * K * d + code2 * d;
558
+ float ip = fvec_inner_product(c1, c2, d);
559
+ // binaries[m1, m2, code1, code2] = ip * 2
560
+ binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] =
561
+ ip * 2;
562
+ }
563
+ }
564
+ }
565
+
566
+ lsq_timer.end("compute_binary_terms");
567
+ }
568
+
569
+ void LocalSearchQuantizer::compute_unary_terms(
570
+ const float* x,
571
+ float* unaries,
572
+ size_t n) const {
573
+ lsq_timer.start("compute_unary_terms");
574
+
575
+ // compute x * codebooks^T
576
+ //
577
+ // NOTE: LAPACK use column major order
578
+ // out = alpha * op(A) * op(B) + beta * C
579
+ FINTEGER nrows_A = M * K;
580
+ FINTEGER ncols_A = d;
581
+
582
+ FINTEGER nrows_B = d;
583
+ FINTEGER ncols_B = n;
584
+
585
+ float alpha = -2.0f;
586
+ float beta = 0.0f;
587
+ sgemm_("Transposed",
588
+ "Not Transposed",
589
+ &nrows_A, // nrows of op(A)
590
+ &ncols_B, // ncols of op(B)
591
+ &ncols_A, // ncols of op(A)
592
+ &alpha,
593
+ codebooks.data(),
594
+ &ncols_A, // nrows of A
595
+ x,
596
+ &nrows_B, // nrows of B
597
+ &beta,
598
+ unaries,
599
+ &nrows_A); // nrows of output
600
+
601
+ std::vector<float> norms(M * K);
602
+ fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
+
604
+ #pragma omp parallel for
605
+ for (int64_t i = 0; i < n; i++) {
606
+ float* u = unaries + i * (M * K);
607
+ fvec_add(M * K, u, norms.data(), u);
608
+ }
609
+
610
+ lsq_timer.end("compute_unary_terms");
611
+ }
612
+
613
+ float LocalSearchQuantizer::evaluate(
614
+ const int32_t* codes,
615
+ const float* x,
616
+ size_t n,
617
+ float* objs) const {
618
+ lsq_timer.start("evaluate");
619
+
620
+ // decode
621
+ std::vector<float> decoded_x(n * d, 0.0f);
622
+ float obj = 0.0f;
623
+
624
+ #pragma omp parallel for reduction(+ : obj)
625
+ for (int64_t i = 0; i < n; i++) {
626
+ const auto code = codes + i * M;
627
+ const auto decoded_i = decoded_x.data() + i * d;
628
+ for (size_t m = 0; m < M; m++) {
629
+ // c = codebooks[m, code[m]]
630
+ const auto c = codebooks.data() + m * K * d + code[m] * d;
631
+ fvec_add(d, decoded_i, c, decoded_i);
632
+ }
633
+
634
+ float err = fvec_L2sqr(x + i * d, decoded_i, d);
635
+ obj += err;
636
+
637
+ if (objs) {
638
+ objs[i] = err;
639
+ }
640
+ }
641
+
642
+ lsq_timer.end("evaluate");
643
+
644
+ obj = obj / n;
645
+ return obj;
646
+ }
647
+
648
+ double LSQTimer::get(const std::string& name) {
649
+ return duration[name];
650
+ }
651
+
652
+ void LSQTimer::start(const std::string& name) {
653
+ FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
+ started[name] = true;
655
+ t0[name] = getmillisecs();
656
+ }
657
+
658
+ void LSQTimer::end(const std::string& name) {
659
+ FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
+ double t1 = getmillisecs();
661
+ double sec = (t1 - t0[name]) / 1000;
662
+ duration[name] += sec;
663
+ started[name] = false;
664
+ }
665
+
666
+ void LSQTimer::reset() {
667
+ duration.clear();
668
+ t0.clear();
669
+ started.clear();
670
+ }
671
+
672
+ } // namespace faiss