faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,672 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LocalSearchQuantizer.h>
12
+
13
+ #include <cstddef>
14
+ #include <cstdio>
15
+ #include <cstring>
16
+ #include <memory>
17
+ #include <random>
18
+
19
+ #include <algorithm>
20
+
21
+ #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/hamming.h> // BitstringWriter
23
+ #include <faiss/utils/utils.h>
24
+
25
+ extern "C" {
26
+ // LU decomoposition of a general matrix
27
+ void sgetrf_(
28
+ FINTEGER* m,
29
+ FINTEGER* n,
30
+ float* a,
31
+ FINTEGER* lda,
32
+ FINTEGER* ipiv,
33
+ FINTEGER* info);
34
+
35
+ // generate inverse of a matrix given its LU decomposition
36
+ void sgetri_(
37
+ FINTEGER* n,
38
+ float* a,
39
+ FINTEGER* lda,
40
+ FINTEGER* ipiv,
41
+ float* work,
42
+ FINTEGER* lwork,
43
+ FINTEGER* info);
44
+
45
+ // solves a system of linear equations
46
+ void sgetrs_(
47
+ const char* trans,
48
+ FINTEGER* n,
49
+ FINTEGER* nrhs,
50
+ float* A,
51
+ FINTEGER* lda,
52
+ FINTEGER* ipiv,
53
+ float* b,
54
+ FINTEGER* ldb,
55
+ FINTEGER* info);
56
+
57
+ // general matrix multiplication
58
+ int sgemm_(
59
+ const char* transa,
60
+ const char* transb,
61
+ FINTEGER* m,
62
+ FINTEGER* n,
63
+ FINTEGER* k,
64
+ const float* alpha,
65
+ const float* a,
66
+ FINTEGER* lda,
67
+ const float* b,
68
+ FINTEGER* ldb,
69
+ float* beta,
70
+ float* c,
71
+ FINTEGER* ldc);
72
+ }
73
+
74
+ namespace {
75
+
76
+ // c and a and b can overlap
77
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
78
+ for (size_t i = 0; i < d; i++) {
79
+ c[i] = a[i] + b[i];
80
+ }
81
+ }
82
+
83
+ void fmat_inverse(float* a, int n) {
84
+ int info;
85
+ int lwork = n * n;
86
+ std::vector<int> ipiv(n);
87
+ std::vector<float> workspace(lwork);
88
+
89
+ sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
90
+ FAISS_THROW_IF_NOT(info == 0);
91
+ sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
92
+ FAISS_THROW_IF_NOT(info == 0);
93
+ }
94
+
95
+ void random_int32(
96
+ std::vector<int32_t>& x,
97
+ int32_t min,
98
+ int32_t max,
99
+ std::mt19937& gen) {
100
+ std::uniform_int_distribution<int32_t> distrib(min, max);
101
+ for (size_t i = 0; i < x.size(); i++) {
102
+ x[i] = distrib(gen);
103
+ }
104
+ }
105
+
106
+ } // anonymous namespace
107
+
108
+ namespace faiss {
109
+
110
+ LSQTimer lsq_timer;
111
+
112
+ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
113
+ FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
114
+
115
+ this->d = d;
116
+ this->M = M;
117
+ this->nbits = std::vector<size_t>(M, nbits);
118
+
119
+ // set derived values
120
+ set_derived_values();
121
+
122
+ is_trained = false;
123
+ verbose = false;
124
+
125
+ K = (1 << nbits);
126
+
127
+ train_iters = 25;
128
+ train_ils_iters = 8;
129
+ icm_iters = 4;
130
+
131
+ encode_ils_iters = 16;
132
+
133
+ p = 0.5f;
134
+ lambd = 1e-2f;
135
+
136
+ chunk_size = 10000;
137
+ nperts = 4;
138
+
139
+ random_seed = 0x12345;
140
+ std::srand(random_seed);
141
+ }
142
+
143
+ void LocalSearchQuantizer::train(size_t n, const float* x) {
144
+ FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
145
+ FAISS_THROW_IF_NOT(nperts <= M);
146
+
147
+ lsq_timer.reset();
148
+ if (verbose) {
149
+ lsq_timer.start("train");
150
+ printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
151
+ M,
152
+ n,
153
+ d);
154
+ }
155
+
156
+ // allocate memory for codebooks, size [M, K, d]
157
+ codebooks.resize(M * K * d);
158
+
159
+ // randomly intialize codes
160
+ std::mt19937 gen(random_seed);
161
+ std::vector<int32_t> codes(n * M); // [n, M]
162
+ random_int32(codes, 0, K - 1, gen);
163
+
164
+ // compute standard derivations of each dimension
165
+ std::vector<float> stddev(d, 0);
166
+
167
+ #pragma omp parallel for
168
+ for (int64_t i = 0; i < d; i++) {
169
+ float mean = 0;
170
+ for (size_t j = 0; j < n; j++) {
171
+ mean += x[j * d + i];
172
+ }
173
+ mean = mean / n;
174
+
175
+ float sum = 0;
176
+ for (size_t j = 0; j < n; j++) {
177
+ float xi = x[j * d + i] - mean;
178
+ sum += xi * xi;
179
+ }
180
+ stddev[i] = sqrtf(sum / n);
181
+ }
182
+
183
+ if (verbose) {
184
+ float obj = evaluate(codes.data(), x, n);
185
+ printf("Before training: obj = %lf\n", obj);
186
+ }
187
+
188
+ for (size_t i = 0; i < train_iters; i++) {
189
+ // 1. update codebooks given x and codes
190
+ // 2. add perturbation to codebooks (SR-D)
191
+ // 3. refine codes given x and codebooks using icm
192
+
193
+ // update codebooks
194
+ update_codebooks(x, codes.data(), n);
195
+
196
+ if (verbose) {
197
+ float obj = evaluate(codes.data(), x, n);
198
+ printf("iter %zd:\n", i);
199
+ printf("\tafter updating codebooks: obj = %lf\n", obj);
200
+ }
201
+
202
+ // SR-D: perturb codebooks
203
+ float T = pow((1.0f - (i + 1.0f) / train_iters), p);
204
+ perturb_codebooks(T, stddev, gen);
205
+
206
+ if (verbose) {
207
+ float obj = evaluate(codes.data(), x, n);
208
+ printf("\tafter perturbing codebooks: obj = %lf\n", obj);
209
+ }
210
+
211
+ // refine codes
212
+ icm_encode(x, codes.data(), n, train_ils_iters, gen);
213
+
214
+ if (verbose) {
215
+ float obj = evaluate(codes.data(), x, n);
216
+ printf("\tafter updating codes: obj = %lf\n", obj);
217
+ }
218
+ }
219
+
220
+ if (verbose) {
221
+ lsq_timer.end("train");
222
+ float obj = evaluate(codes.data(), x, n);
223
+ printf("After training: obj = %lf\n", obj);
224
+
225
+ printf("Time statistic:\n");
226
+ for (const auto& it : lsq_timer.duration) {
227
+ printf("\t%s time: %lf s\n", it.first.data(), it.second);
228
+ }
229
+ }
230
+
231
+ is_trained = true;
232
+ }
233
+
234
+ void LocalSearchQuantizer::perturb_codebooks(
235
+ float T,
236
+ const std::vector<float>& stddev,
237
+ std::mt19937& gen) {
238
+ lsq_timer.start("perturb_codebooks");
239
+
240
+ std::vector<std::normal_distribution<float>> distribs;
241
+ for (size_t i = 0; i < d; i++) {
242
+ distribs.emplace_back(0.0f, stddev[i]);
243
+ }
244
+
245
+ for (size_t m = 0; m < M; m++) {
246
+ for (size_t k = 0; k < K; k++) {
247
+ for (size_t i = 0; i < d; i++) {
248
+ codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M;
249
+ }
250
+ }
251
+ }
252
+
253
+ lsq_timer.end("perturb_codebooks");
254
+ }
255
+
256
+ void LocalSearchQuantizer::compute_codes(
257
+ const float* x,
258
+ uint8_t* codes_out,
259
+ size_t n) const {
260
+ FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
261
+ if (verbose) {
262
+ lsq_timer.reset();
263
+ printf("Encoding %zd vectors...\n", n);
264
+ lsq_timer.start("encode");
265
+ }
266
+
267
+ std::vector<int32_t> codes(n * M);
268
+ std::mt19937 gen(random_seed);
269
+ random_int32(codes, 0, K - 1, gen);
270
+
271
+ icm_encode(x, codes.data(), n, encode_ils_iters, gen);
272
+ pack_codes(n, codes.data(), codes_out);
273
+
274
+ if (verbose) {
275
+ lsq_timer.end("encode");
276
+ double t = lsq_timer.get("encode");
277
+ printf("Time to encode %zd vectors: %lf s\n", n, t);
278
+ }
279
+ }
280
+
281
+ /** update codebooks given x and codes
282
+ *
283
+ * Let B denote the sparse matrix of codes, size [n, M * K].
284
+ * Let C denote the codebooks, size [M * K, d].
285
+ * Let X denote the training vectors, size [n, d]
286
+ *
287
+ * objective function:
288
+ * L = (X - BC)^2
289
+ *
290
+ * To minimize L, we have:
291
+ * C = (B'B)^(-1)B'X
292
+ * where ' denote transposed
293
+ *
294
+ * Add a regularization term to make B'B inversible:
295
+ * C = (B'B + lambd * I)^(-1)B'X
296
+ */
297
+ void LocalSearchQuantizer::update_codebooks(
298
+ const float* x,
299
+ const int32_t* codes,
300
+ size_t n) {
301
+ lsq_timer.start("update_codebooks");
302
+
303
+ // allocate memory
304
+ // bb = B'B, bx = BX
305
+ std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
306
+ std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
307
+
308
+ // compute B'B
309
+ for (size_t i = 0; i < n; i++) {
310
+ for (size_t m = 0; m < M; m++) {
311
+ int32_t code1 = codes[i * M + m];
312
+ int32_t idx1 = m * K + code1;
313
+ bb[idx1 * M * K + idx1] += 1;
314
+
315
+ for (size_t m2 = m + 1; m2 < M; m2++) {
316
+ int32_t code2 = codes[i * M + m2];
317
+ int32_t idx2 = m2 * K + code2;
318
+ bb[idx1 * M * K + idx2] += 1;
319
+ bb[idx2 * M * K + idx1] += 1;
320
+ }
321
+ }
322
+ }
323
+
324
+ // add a regularization term to B'B
325
+ for (int64_t i = 0; i < M * K; i++) {
326
+ bb[i * (M * K) + i] += lambd;
327
+ }
328
+
329
+ // compute (B'B)^(-1)
330
+ fmat_inverse(bb.data(), M * K); // [M*K, M*K]
331
+
332
+ // compute BX
333
+ for (size_t i = 0; i < n; i++) {
334
+ for (size_t m = 0; m < M; m++) {
335
+ int32_t code = codes[i * M + m];
336
+ float* data = bx.data() + (m * K + code) * d;
337
+ fvec_add(d, data, x + i * d, data);
338
+ }
339
+ }
340
+
341
+ // compute C = (B'B)^(-1) @ BX
342
+ //
343
+ // NOTE: LAPACK use column major order
344
+ // out = alpha * op(A) * op(B) + beta * C
345
+ FINTEGER nrows_A = d;
346
+ FINTEGER ncols_A = M * K;
347
+
348
+ FINTEGER nrows_B = M * K;
349
+ FINTEGER ncols_B = M * K;
350
+
351
+ float alpha = 1.0f;
352
+ float beta = 0.0f;
353
+ sgemm_("Not Transposed",
354
+ "Not Transposed",
355
+ &nrows_A, // nrows of op(A)
356
+ &ncols_B, // ncols of op(B)
357
+ &ncols_A, // ncols of op(A)
358
+ &alpha,
359
+ bx.data(),
360
+ &nrows_A, // nrows of A
361
+ bb.data(),
362
+ &nrows_B, // nrows of B
363
+ &beta,
364
+ codebooks.data(),
365
+ &nrows_A); // nrows of output
366
+
367
+ lsq_timer.end("update_codebooks");
368
+ }
369
+
370
+ /** encode using iterative conditional mode
371
+ *
372
+ * iterative conditional mode:
373
+ * For every subcode ci (i = 1, ..., M) of a vector, we fix the other
374
+ * subcodes cj (j != i) and then find the optimal value of ci such
375
+ * that minimizing the objective function.
376
+
377
+ * objective function:
378
+ * L = (X - \sum cj)^2, j = 1, ..., M
379
+ * L = X^2 - 2X * \sum cj + (\sum cj)^2
380
+ *
381
+ * X^2 is negligable since it is the same for all possible value
382
+ * k of the m-th subcode.
383
+ *
384
+ * 2X * \sum cj is the unary term
385
+ * (\sum cj)^2 is the binary term
386
+ * These two terms can be precomputed and store in a look up table.
387
+ */
388
+ void LocalSearchQuantizer::icm_encode(
389
+ const float* x,
390
+ int32_t* codes,
391
+ size_t n,
392
+ size_t ils_iters,
393
+ std::mt19937& gen) const {
394
+ lsq_timer.start("icm_encode");
395
+
396
+ std::vector<float> binaries(M * M * K * K); // [M, M, K, K]
397
+ compute_binary_terms(binaries.data());
398
+
399
+ const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
400
+ for (size_t i = 0; i < n_chunks; i++) {
401
+ size_t ni = std::min(chunk_size, n - i * chunk_size);
402
+
403
+ if (verbose) {
404
+ printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n);
405
+ fflush(stdout);
406
+ if (i == n_chunks - 1 || i == 0) {
407
+ printf("\n");
408
+ }
409
+ }
410
+
411
+ const float* xi = x + i * chunk_size * d;
412
+ int32_t* codesi = codes + i * chunk_size * M;
413
+ icm_encode_partial(i, xi, codesi, ni, binaries.data(), ils_iters, gen);
414
+ }
415
+
416
+ lsq_timer.end("icm_encode");
417
+ }
418
+
419
+ void LocalSearchQuantizer::icm_encode_partial(
420
+ size_t index,
421
+ const float* x,
422
+ int32_t* codes,
423
+ size_t n,
424
+ const float* binaries,
425
+ size_t ils_iters,
426
+ std::mt19937& gen) const {
427
+ std::vector<float> unaries(n * M * K); // [n, M, K]
428
+ compute_unary_terms(x, unaries.data(), n);
429
+
430
+ std::vector<int32_t> best_codes;
431
+ best_codes.assign(codes, codes + n * M);
432
+
433
+ std::vector<float> best_objs(n, 0.0f);
434
+ evaluate(codes, x, n, best_objs.data());
435
+
436
+ FAISS_THROW_IF_NOT(nperts <= M);
437
+ for (size_t iter1 = 0; iter1 < ils_iters; iter1++) {
438
+ // add perturbation to codes
439
+ perturb_codes(codes, n, gen);
440
+
441
+ for (size_t iter2 = 0; iter2 < icm_iters; iter2++) {
442
+ icm_encode_step(unaries.data(), binaries, codes, n);
443
+ }
444
+
445
+ std::vector<float> icm_objs(n, 0.0f);
446
+ evaluate(codes, x, n, icm_objs.data());
447
+ size_t n_betters = 0;
448
+ float mean_obj = 0.0f;
449
+
450
+ // select the best code for every vector xi
451
+ #pragma omp parallel for reduction(+ : n_betters, mean_obj)
452
+ for (int64_t i = 0; i < n; i++) {
453
+ if (icm_objs[i] < best_objs[i]) {
454
+ best_objs[i] = icm_objs[i];
455
+ memcpy(best_codes.data() + i * M,
456
+ codes + i * M,
457
+ sizeof(int32_t) * M);
458
+ n_betters += 1;
459
+ }
460
+ mean_obj += best_objs[i];
461
+ }
462
+ mean_obj /= n;
463
+
464
+ memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
465
+
466
+ if (verbose && index == 0) {
467
+ printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
468
+ iter1,
469
+ mean_obj,
470
+ n_betters,
471
+ n);
472
+ }
473
+ } // loop ils_iters
474
+ }
475
+
476
+ void LocalSearchQuantizer::icm_encode_step(
477
+ const float* unaries,
478
+ const float* binaries,
479
+ int32_t* codes,
480
+ size_t n) const {
481
+ // condition on the m-th subcode
482
+ for (size_t m = 0; m < M; m++) {
483
+ std::vector<float> objs(n * K);
484
+ #pragma omp parallel for
485
+ for (int64_t i = 0; i < n; i++) {
486
+ auto u = unaries + i * (M * K) + m * K;
487
+ memcpy(objs.data() + i * K, u, sizeof(float) * K);
488
+ }
489
+
490
+ // compute objective function by adding unary
491
+ // and binary terms together
492
+ for (size_t other_m = 0; other_m < M; other_m++) {
493
+ if (other_m == m) {
494
+ continue;
495
+ }
496
+
497
+ #pragma omp parallel for
498
+ for (int64_t i = 0; i < n; i++) {
499
+ for (int32_t code = 0; code < K; code++) {
500
+ int32_t code2 = codes[i * M + other_m];
501
+ size_t binary_idx =
502
+ m * M * K * K + other_m * K * K + code * K + code2;
503
+ // binaries[m, other_m, code, code2]
504
+ objs[i * K + code] += binaries[binary_idx];
505
+ }
506
+ }
507
+ }
508
+
509
+ // find the optimal value of the m-th subcode
510
+ #pragma omp parallel for
511
+ for (int64_t i = 0; i < n; i++) {
512
+ float best_obj = HUGE_VALF;
513
+ int32_t best_code = 0;
514
+ for (size_t code = 0; code < K; code++) {
515
+ float obj = objs[i * K + code];
516
+ if (obj < best_obj) {
517
+ best_obj = obj;
518
+ best_code = code;
519
+ }
520
+ }
521
+ codes[i * M + m] = best_code;
522
+ }
523
+
524
+ } // loop M
525
+ }
526
+
527
+ void LocalSearchQuantizer::perturb_codes(
528
+ int32_t* codes,
529
+ size_t n,
530
+ std::mt19937& gen) const {
531
+ lsq_timer.start("perturb_codes");
532
+
533
+ std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
534
+ std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
535
+
536
+ for (size_t i = 0; i < n; i++) {
537
+ for (size_t j = 0; j < nperts; j++) {
538
+ size_t m = m_distrib(gen);
539
+ codes[i * M + m] = k_distrib(gen);
540
+ }
541
+ }
542
+
543
+ lsq_timer.end("perturb_codes");
544
+ }
545
+
546
+ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
547
+ lsq_timer.start("compute_binary_terms");
548
+
549
+ #pragma omp parallel for
550
+ for (int64_t m12 = 0; m12 < M * M; m12++) {
551
+ size_t m1 = m12 / M;
552
+ size_t m2 = m12 % M;
553
+
554
+ for (size_t code1 = 0; code1 < K; code1++) {
555
+ for (size_t code2 = 0; code2 < K; code2++) {
556
+ const float* c1 = codebooks.data() + m1 * K * d + code1 * d;
557
+ const float* c2 = codebooks.data() + m2 * K * d + code2 * d;
558
+ float ip = fvec_inner_product(c1, c2, d);
559
+ // binaries[m1, m2, code1, code2] = ip * 2
560
+ binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] =
561
+ ip * 2;
562
+ }
563
+ }
564
+ }
565
+
566
+ lsq_timer.end("compute_binary_terms");
567
+ }
568
+
569
+ void LocalSearchQuantizer::compute_unary_terms(
570
+ const float* x,
571
+ float* unaries,
572
+ size_t n) const {
573
+ lsq_timer.start("compute_unary_terms");
574
+
575
+ // compute x * codebooks^T
576
+ //
577
+ // NOTE: LAPACK use column major order
578
+ // out = alpha * op(A) * op(B) + beta * C
579
+ FINTEGER nrows_A = M * K;
580
+ FINTEGER ncols_A = d;
581
+
582
+ FINTEGER nrows_B = d;
583
+ FINTEGER ncols_B = n;
584
+
585
+ float alpha = -2.0f;
586
+ float beta = 0.0f;
587
+ sgemm_("Transposed",
588
+ "Not Transposed",
589
+ &nrows_A, // nrows of op(A)
590
+ &ncols_B, // ncols of op(B)
591
+ &ncols_A, // ncols of op(A)
592
+ &alpha,
593
+ codebooks.data(),
594
+ &ncols_A, // nrows of A
595
+ x,
596
+ &nrows_B, // nrows of B
597
+ &beta,
598
+ unaries,
599
+ &nrows_A); // nrows of output
600
+
601
+ std::vector<float> norms(M * K);
602
+ fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
603
+
604
+ #pragma omp parallel for
605
+ for (int64_t i = 0; i < n; i++) {
606
+ float* u = unaries + i * (M * K);
607
+ fvec_add(M * K, u, norms.data(), u);
608
+ }
609
+
610
+ lsq_timer.end("compute_unary_terms");
611
+ }
612
+
613
+ float LocalSearchQuantizer::evaluate(
614
+ const int32_t* codes,
615
+ const float* x,
616
+ size_t n,
617
+ float* objs) const {
618
+ lsq_timer.start("evaluate");
619
+
620
+ // decode
621
+ std::vector<float> decoded_x(n * d, 0.0f);
622
+ float obj = 0.0f;
623
+
624
+ #pragma omp parallel for reduction(+ : obj)
625
+ for (int64_t i = 0; i < n; i++) {
626
+ const auto code = codes + i * M;
627
+ const auto decoded_i = decoded_x.data() + i * d;
628
+ for (size_t m = 0; m < M; m++) {
629
+ // c = codebooks[m, code[m]]
630
+ const auto c = codebooks.data() + m * K * d + code[m] * d;
631
+ fvec_add(d, decoded_i, c, decoded_i);
632
+ }
633
+
634
+ float err = fvec_L2sqr(x + i * d, decoded_i, d);
635
+ obj += err;
636
+
637
+ if (objs) {
638
+ objs[i] = err;
639
+ }
640
+ }
641
+
642
+ lsq_timer.end("evaluate");
643
+
644
+ obj = obj / n;
645
+ return obj;
646
+ }
647
+
648
+ double LSQTimer::get(const std::string& name) {
649
+ return duration[name];
650
+ }
651
+
652
+ void LSQTimer::start(const std::string& name) {
653
+ FAISS_THROW_IF_NOT_MSG(!started[name], " timer is already running");
654
+ started[name] = true;
655
+ t0[name] = getmillisecs();
656
+ }
657
+
658
+ void LSQTimer::end(const std::string& name) {
659
+ FAISS_THROW_IF_NOT_MSG(started[name], " timer is not running");
660
+ double t1 = getmillisecs();
661
+ double sec = (t1 - t0[name]) / 1000;
662
+ duration[name] += sec;
663
+ started[name] = false;
664
+ }
665
+
666
+ void LSQTimer::reset() {
667
+ duration.clear();
668
+ t0.clear();
669
+ started.clear();
670
+ }
671
+
672
+ } // namespace faiss