faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -0,0 +1,303 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <algorithm>
9
+ #include <cstdint>
10
+ #include <cstring>
11
+ #include <functional>
12
+ #include <numeric>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/Index.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/kmeans1d.h>
20
+
21
+ namespace faiss {
22
+
23
+ using idx_t = Index::idx_t;
24
+ using LookUpFunc = std::function<float(idx_t, idx_t)>;
25
+
26
+ void reduce(
27
+ const std::vector<idx_t>& rows,
28
+ const std::vector<idx_t>& input_cols,
29
+ const LookUpFunc& lookup,
30
+ std::vector<idx_t>& output_cols) {
31
+ for (idx_t col : input_cols) {
32
+ while (!output_cols.empty()) {
33
+ idx_t row = rows[output_cols.size() - 1];
34
+ float a = lookup(row, col);
35
+ float b = lookup(row, output_cols.back());
36
+ if (a >= b) { // defeated
37
+ break;
38
+ }
39
+ output_cols.pop_back();
40
+ }
41
+ if (output_cols.size() < rows.size()) {
42
+ output_cols.push_back(col);
43
+ }
44
+ }
45
+ }
46
+
47
+ void interpolate(
48
+ const std::vector<idx_t>& rows,
49
+ const std::vector<idx_t>& cols,
50
+ const LookUpFunc& lookup,
51
+ idx_t* argmins) {
52
+ std::unordered_map<idx_t, idx_t> idx_to_col;
53
+ for (idx_t idx = 0; idx < cols.size(); ++idx) {
54
+ idx_to_col[cols[idx]] = idx;
55
+ }
56
+
57
+ idx_t start = 0;
58
+ for (idx_t r = 0; r < rows.size(); r += 2) {
59
+ idx_t row = rows[r];
60
+ idx_t end = cols.size() - 1;
61
+ if (r < rows.size() - 1) {
62
+ idx_t idx = argmins[rows[r + 1]];
63
+ end = idx_to_col[idx];
64
+ }
65
+ idx_t argmin = cols[start];
66
+ float min = lookup(row, argmin);
67
+ for (idx_t c = start + 1; c <= end; c++) {
68
+ float value = lookup(row, cols[c]);
69
+ if (value < min) {
70
+ argmin = cols[c];
71
+ min = value;
72
+ }
73
+ }
74
+ argmins[row] = argmin;
75
+ start = end;
76
+ }
77
+ }
78
+
79
+ /** SMAWK algo. Find the row minima of a monotone matrix.
80
+ *
81
+ * References:
82
+ * 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf
83
+ * 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405
84
+ * 3. https://github.com/dstein64/kmeans1d
85
+ */
86
+ void smawk_impl(
87
+ const std::vector<idx_t>& rows,
88
+ const std::vector<idx_t>& input_cols,
89
+ const LookUpFunc& lookup,
90
+ idx_t* argmins) {
91
+ if (rows.size() == 0) {
92
+ return;
93
+ }
94
+
95
+ /**********************************
96
+ * REDUCE
97
+ **********************************/
98
+ auto ptr = &input_cols;
99
+ std::vector<idx_t> survived_cols; // survived columns
100
+ if (rows.size() < input_cols.size()) {
101
+ reduce(rows, input_cols, lookup, survived_cols);
102
+ ptr = &survived_cols;
103
+ }
104
+ auto& cols = *ptr; // avoid memory copy
105
+
106
+ /**********************************
107
+ * INTERPOLATE
108
+ **********************************/
109
+
110
+ // call recursively on odd-indexed rows
111
+ std::vector<idx_t> odd_rows;
112
+ for (idx_t i = 1; i < rows.size(); i += 2) {
113
+ odd_rows.push_back(rows[i]);
114
+ }
115
+ smawk_impl(odd_rows, cols, lookup, argmins);
116
+
117
+ // interpolate the even-indexed rows
118
+ interpolate(rows, cols, lookup, argmins);
119
+ }
120
+
121
+ void smawk(
122
+ const idx_t nrows,
123
+ const idx_t ncols,
124
+ const LookUpFunc& lookup,
125
+ idx_t* argmins) {
126
+ std::vector<idx_t> rows(nrows);
127
+ std::vector<idx_t> cols(ncols);
128
+ std::iota(std::begin(rows), std::end(rows), 0);
129
+ std::iota(std::begin(cols), std::end(cols), 0);
130
+
131
+ smawk_impl(rows, cols, lookup, argmins);
132
+ }
133
+
134
+ void smawk(
135
+ const idx_t nrows,
136
+ const idx_t ncols,
137
+ const float* x,
138
+ idx_t* argmins) {
139
+ auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; };
140
+ smawk(nrows, ncols, lookup, argmins);
141
+ }
142
+
143
+ namespace {
144
+
145
+ class CostCalculator {
146
+ // The reuslt would be inaccurate if we use float
147
+ std::vector<double> cumsum;
148
+ std::vector<double> cumsum2;
149
+
150
+ public:
151
+ CostCalculator(const std::vector<float>& vec, idx_t n) {
152
+ cumsum.push_back(0.0);
153
+ cumsum2.push_back(0.0);
154
+ for (idx_t i = 0; i < n; ++i) {
155
+ float x = vec[i];
156
+ cumsum.push_back(x + cumsum[i]);
157
+ cumsum2.push_back(x * x + cumsum2[i]);
158
+ }
159
+ }
160
+
161
+ float operator()(idx_t i, idx_t j) {
162
+ if (j < i) {
163
+ return 0.0f;
164
+ }
165
+ auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
166
+ auto result = cumsum2[j + 1] - cumsum2[i];
167
+ result += (j - i + 1) * (mu * mu);
168
+ result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
169
+ return float(result);
170
+ }
171
+ };
172
+
173
+ template <class T>
174
+ class Matrix {
175
+ std::vector<T> data;
176
+ idx_t nrows;
177
+ idx_t ncols;
178
+
179
+ public:
180
+ Matrix(idx_t nrows, idx_t ncols) {
181
+ this->nrows = nrows;
182
+ this->ncols = ncols;
183
+ data.resize(nrows * ncols);
184
+ }
185
+
186
+ inline T& at(idx_t i, idx_t j) {
187
+ return data[i * ncols + j];
188
+ }
189
+ };
190
+
191
+ } // anonymous namespace
192
+
193
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
194
+ FAISS_THROW_IF_NOT(n >= nclusters);
195
+
196
+ // corner case
197
+ if (n == nclusters) {
198
+ memcpy(centroids, x, n * sizeof(*x));
199
+ return 0.0f;
200
+ }
201
+
202
+ /***************************************************
203
+ * sort in ascending order, O(NlogN) in time
204
+ ***************************************************/
205
+ std::vector<float> arr(x, x + n);
206
+ std::sort(arr.begin(), arr.end());
207
+
208
+ /***************************************************
209
+ dynamic programming algorithm
210
+
211
+ Reference: https://arxiv.org/abs/1701.07204
212
+ -------------------------------
213
+
214
+ Assume x is already sorted in ascending order.
215
+
216
+ N: number of points
217
+ K: number of clusters
218
+
219
+ CC(i, j): the cost of grouping xi,...,xj into one cluster
220
+ D[k][m]: the cost of optimally clustering x1,...,xm into k clusters
221
+ T[k][m]: the start index of the k-th cluster
222
+
223
+ The DP process is as follow:
224
+ D[k][m] = min_i D[k − 1][i − 1] + CC(i, m)
225
+ T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m)
226
+
227
+ This could be solved in O(KN^2) time and O(KN) space.
228
+
229
+ To further reduce the time complexity, we use SMAWK algo to
230
+ solve the argmin problem as follow:
231
+
232
+ For each k:
233
+ C[m][i] = D[k − 1][i − 1] + CC(i, m)
234
+
235
+ Here C is a n x n totally monotone matrix.
236
+ We could find the row minima by SMAWK in O(N) time.
237
+
238
+ Now the time complexity is reduced from O(kN^2) to O(KN).
239
+ ****************************************************/
240
+
241
+ CostCalculator CC(arr, n);
242
+ Matrix<float> D(nclusters, n);
243
+ Matrix<idx_t> T(nclusters, n);
244
+
245
+ for (idx_t m = 0; m < n; m++) {
246
+ D.at(0, m) = CC(0, m);
247
+ T.at(0, m) = 0;
248
+ }
249
+
250
+ std::vector<idx_t> indices(nclusters, 0);
251
+
252
+ for (idx_t k = 1; k < nclusters; ++k) {
253
+ // we define C here
254
+ auto C = [&D, &CC, &k](idx_t m, idx_t i) {
255
+ if (i == 0) {
256
+ return CC(i, m);
257
+ }
258
+ idx_t col = std::min(m, i - 1);
259
+ return D.at(k - 1, col) + CC(i, m);
260
+ };
261
+
262
+ std::vector<idx_t> argmins(n); // argmin of each row
263
+ smawk(n, n, C, argmins.data());
264
+ for (idx_t m = 0; m < argmins.size(); m++) {
265
+ idx_t idx = argmins[m];
266
+ D.at(k, m) = C(m, idx);
267
+ T.at(k, m) = idx;
268
+ }
269
+ }
270
+
271
+ /***************************************************
272
+ compute centroids by backtracking
273
+
274
+ T[K - 1][T[K][N] - 1] T[K][N] N
275
+ --------------|------------------------|-----------|
276
+ | cluster K - 1 | cluster K |
277
+
278
+ ****************************************************/
279
+
280
+ // for imbalance factor
281
+ double tot = 0.0;
282
+ double uf = 0.0;
283
+
284
+ idx_t end = n;
285
+ for (idx_t k = nclusters - 1; k >= 0; k--) {
286
+ const idx_t start = T.at(k, end - 1);
287
+ const float sum =
288
+ std::accumulate(arr.data() + start, arr.data() + end, 0.0f);
289
+ const idx_t size = end - start;
290
+ FAISS_THROW_IF_NOT_FMT(
291
+ size > 0, "Cluster %d: size %d", int(k), int(size));
292
+ centroids[k] = sum / size;
293
+ end = start;
294
+
295
+ tot += size;
296
+ uf += size * double(size);
297
+ }
298
+
299
+ uf = uf * nclusters / (tot * tot);
300
+ return uf;
301
+ }
302
+
303
+ } // namespace faiss
@@ -0,0 +1,48 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <functional>
12
+
13
+ namespace faiss {
14
+
15
+ /** SMAWK algorithm. Find the row minima of a monotone matrix.
16
+ *
17
+ * Expose this for testing.
18
+ *
19
+ * @param nrows number of rows
20
+ * @param ncols number of columns
21
+ * @param x input matrix, size (nrows, ncols)
22
+ * @param argmins argmin of each row
23
+ */
24
+ void smawk(
25
+ const Index::idx_t nrows,
26
+ const Index::idx_t ncols,
27
+ const float* x,
28
+ Index::idx_t* argmins);
29
+
30
+ /** Exact 1D K-Means by dynamic programming
31
+ *
32
+ * From "Fast Exact k-Means, k-Medians and Bregman Divergence Clustering in 1D"
33
+ * Allan Grønlund, Kasper Green Larsen, Alexander Mathiasen, Jesper Sindahl
34
+ * Nielsen, Stefan Schneider, Mingzhou Song, ArXiV'17
35
+ *
36
+ * Section 2.2
37
+ *
38
+ * https://arxiv.org/abs/1701.07204
39
+ *
40
+ * @param x input 1D array
41
+ * @param n input array length
42
+ * @param nclusters number of clusters
43
+ * @param centroids output centroids, size nclusters
44
+ * @return imbalance factor
45
+ */
46
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids);
47
+
48
+ } // namespace faiss
@@ -122,30 +122,70 @@ void pq4_pack_codes_range(
122
122
  }
123
123
  }
124
124
 
125
+ namespace {
126
+
127
// Get the byte offset, relative to the start of a bbs-sized block, of the
// packed 4-bit code of vector `vector_id` for sub-quantizer `sq`.
// Within a block, bytes are grouped per pair of sub-quantizers (bbs bytes
// each), then per group of 32 vectors (32 bytes each: bytes 0..15 for the
// even sq, 16..31 for the odd sq). Inside a 32-vector group, vector v is
// stored in bits 0..3 (shift = false) if v < 16, else in bits 4..7
// (shift = true).
// Returns size_t: the offset exceeds 255 as soon as nsq > 16, so a uint8_t
// return value would silently wrap.
size_t get_vector_specific_address(
        size_t bbs,
        size_t vector_id,
        size_t sq,
        bool& shift) {
    // position of the vector inside its bbs-sized block
    vector_id = vector_id % bbs;
    // offset of the 32-vector group holding it (non-zero only when bbs > 32)
    const size_t group_offset = (vector_id / 32) * 32;
    // position inside that group
    vector_id = vector_id % 32;
    shift = vector_id > 15;
    vector_id = vector_id & 15;

    // inverse of the packing interleave: 0..7 -> even bytes, 8..15 -> odd
    size_t address;
    if (vector_id < 8) {
        address = vector_id << 1;
    } else {
        address = ((vector_id - 8) << 1) + 1;
    }
    if (sq & 1) {
        address += 16;
    }
    return (sq >> 1) * bbs + group_offset + address;
}
152
+
153
+ } // anonymous namespace
154
+
125
155
  uint8_t pq4_get_packed_element(
126
156
  const uint8_t* data,
127
157
  size_t bbs,
128
158
  size_t nsq,
129
- size_t i,
159
+ size_t vector_id,
130
160
  size_t sq) {
131
161
  // move to correct bbs-sized block
132
- data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
133
- sq = sq & 1;
134
- i = i % bbs;
135
-
136
- // another step
137
- data += (i / 32) * 32;
138
- i = i % 32;
139
-
140
- if (sq == 1) {
141
- data += 16;
162
+ // number of blocks * block size
163
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
164
+ bool shift;
165
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
166
+ if (shift) {
167
+ return data[address] >> 4;
168
+ } else {
169
+ return data[address] & 15;
142
170
  }
143
- const uint8_t iperm0[16] = {
144
- 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
145
- if (i < 16) {
146
- return data[iperm0[i]] & 15;
171
+ }
172
+
173
+ void pq4_set_packed_element(
174
+ uint8_t* data,
175
+ uint8_t code,
176
+ size_t bbs,
177
+ size_t nsq,
178
+ size_t vector_id,
179
+ size_t sq) {
180
+ // move to correct bbs-sized block
181
+ // number of blocks * block size
182
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
183
+ bool shift;
184
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
185
+ if (shift) {
186
+ data[address] = (code << 4) | (data[address] & 15);
147
187
  } else {
148
- return data[iperm0[i - 16]] >> 4;
188
+ data[address] = code | (data[address] & ~15);
149
189
  }
150
190
  }
151
191
 
@@ -26,7 +26,7 @@ namespace faiss {
26
26
  * The unused bytes are set to 0.
27
27
  *
28
28
  * @param codes input codes, size (ntotal, ceil(M / 2))
29
- * @param nototal number of input codes
29
+ * @param ntotal number of input codes
30
30
  * @param nb output number of codes (ntotal rounded up to a multiple of
31
31
  * bbs)
32
32
  * @param M2 number of sub-quantizers (=M rounded up to a multiple of 2)
@@ -61,14 +61,27 @@ void pq4_pack_codes_range(
61
61
 
62
62
  /** get a single element from a packed codes table
63
63
  *
64
- * @param i vector id
64
+ * @param vector_id vector id
65
65
  * @param sq subquantizer (< nsq)
66
66
  */
67
67
  uint8_t pq4_get_packed_element(
68
68
  const uint8_t* data,
69
69
  size_t bbs,
70
70
  size_t nsq,
71
- size_t i,
71
+ size_t vector_id,
72
+ size_t sq);
73
+
74
+ /** set a single element "code" into a packed codes table
75
+ *
76
+ * @param vector_id vector id
77
+ * @param sq subquantizer (< nsq)
78
+ */
79
+ void pq4_set_packed_element(
80
+ uint8_t* data,
81
+ uint8_t code,
82
+ size_t bbs,
83
+ size_t nsq,
84
+ size_t vector_id,
72
85
  size_t sq);
73
86
 
74
87
  /** Pack Look-up table for consumption by the kernel.
@@ -88,8 +101,9 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
88
101
  * @param nsq number of sub-quantizers (multiple of 2)
89
102
  * @param codes packed codes array
90
103
  * @param LUT packed look-up table
104
+ * @param scaler scaler to scale the encoded norm
91
105
  */
92
- template <class ResultHandler>
106
+ template <class ResultHandler, class Scaler>
93
107
  void pq4_accumulate_loop(
94
108
  int nq,
95
109
  size_t nb,
@@ -97,7 +111,8 @@ void pq4_accumulate_loop(
97
111
  int nsq,
98
112
  const uint8_t* codes,
99
113
  const uint8_t* LUT,
100
- ResultHandler& res);
114
+ ResultHandler& res,
115
+ const Scaler& scaler);
101
116
 
102
117
  /* qbs versions, supported only for bbs=32.
103
118
  *
@@ -141,20 +156,22 @@ int pq4_pack_LUT_qbs_q_map(
141
156
 
142
157
  /** Run accumulation loop.
143
158
  *
144
- * @param qbs 4-bit encded number of queries
159
+ * @param qbs 4-bit encoded number of queries
145
160
  * @param nb number of database codes (multiple of bbs)
146
161
  * @param nsq number of sub-quantizers
147
162
  * @param codes encoded database vectors (packed)
148
163
  * @param LUT look-up table (packed)
149
164
  * @param res call-back for the results
165
+ * @param scaler scaler to scale the encoded norm
150
166
  */
151
- template <class ResultHandler>
167
+ template <class ResultHandler, class Scaler>
152
168
  void pq4_accumulate_loop_qbs(
153
169
  int qbs,
154
170
  size_t nb,
155
171
  int nsq,
156
172
  const uint8_t* codes,
157
173
  const uint8_t* LUT,
158
- ResultHandler& res);
174
+ ResultHandler& res,
175
+ const Scaler& scaler);
159
176
 
160
177
  } // namespace faiss
@@ -8,6 +8,7 @@
8
8
  #include <faiss/impl/pq4_fast_scan.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LookupTableScaler.h>
11
12
  #include <faiss/impl/simd_result_handlers.h>
12
13
 
13
14
  namespace faiss {
@@ -26,12 +27,13 @@ namespace {
26
27
  * writes results in a ResultHandler
27
28
  */
28
29
 
29
- template <int NQ, int BB, class ResultHandler>
30
+ template <int NQ, int BB, class ResultHandler, class Scaler>
30
31
  void kernel_accumulate_block(
31
32
  int nsq,
32
33
  const uint8_t* codes,
33
34
  const uint8_t* LUT,
34
- ResultHandler& res) {
35
+ ResultHandler& res,
36
+ const Scaler& scaler) {
35
37
  // distance accumulators
36
38
  simd16uint16 accu[NQ][BB][4];
37
39
 
@@ -44,7 +46,7 @@ void kernel_accumulate_block(
44
46
  }
45
47
  }
46
48
 
47
- for (int sq = 0; sq < nsq; sq += 2) {
49
+ for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
48
50
  simd32uint8 lut_cache[NQ];
49
51
  for (int q = 0; q < NQ; q++) {
50
52
  lut_cache[q] = simd32uint8(LUT);
@@ -72,6 +74,35 @@ void kernel_accumulate_block(
72
74
  }
73
75
  }
74
76
 
77
+ for (int sq = 0; sq < scaler.nscale; sq += 2) {
78
+ simd32uint8 lut_cache[NQ];
79
+ for (int q = 0; q < NQ; q++) {
80
+ lut_cache[q] = simd32uint8(LUT);
81
+ LUT += 32;
82
+ }
83
+
84
+ for (int b = 0; b < BB; b++) {
85
+ simd32uint8 c = simd32uint8(codes);
86
+ codes += 32;
87
+ simd32uint8 mask(15);
88
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
89
+ simd32uint8 clo = c & mask;
90
+
91
+ for (int q = 0; q < NQ; q++) {
92
+ simd32uint8 lut = lut_cache[q];
93
+
94
+ simd32uint8 res0 = scaler.lookup(lut, clo);
95
+ accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
96
+ accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
97
+
98
+ simd32uint8 res1 = scaler.lookup(lut, chi);
99
+ accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
100
+ accu[q][b][3] +=
101
+ scaler.scale_hi(res1); // handle vectors 24..31
102
+ }
103
+ }
104
+ }
105
+
75
106
  for (int q = 0; q < NQ; q++) {
76
107
  for (int b = 0; b < BB; b++) {
77
108
  accu[q][b][0] -= accu[q][b][1] << 8;
@@ -85,17 +116,18 @@ void kernel_accumulate_block(
85
116
  }
86
117
  }
87
118
 
88
- template <int NQ, int BB, class ResultHandler>
119
+ template <int NQ, int BB, class ResultHandler, class Scaler>
89
120
  void accumulate_fixed_blocks(
90
121
  size_t nb,
91
122
  int nsq,
92
123
  const uint8_t* codes,
93
124
  const uint8_t* LUT,
94
- ResultHandler& res) {
125
+ ResultHandler& res,
126
+ const Scaler& scaler) {
95
127
  constexpr int bbs = 32 * BB;
96
128
  for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
97
129
  FixedStorageHandler<NQ, 2 * BB> res2;
98
- kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
130
+ kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
99
131
  res.set_block_origin(0, j0);
100
132
  res2.to_other_handler(res);
101
133
  codes += bbs * nsq / 2;
@@ -104,7 +136,7 @@ void accumulate_fixed_blocks(
104
136
 
105
137
  } // anonymous namespace
106
138
 
107
- template <class ResultHandler>
139
+ template <class ResultHandler, class Scaler>
108
140
  void pq4_accumulate_loop(
109
141
  int nq,
110
142
  size_t nb,
@@ -112,15 +144,16 @@ void pq4_accumulate_loop(
112
144
  int nsq,
113
145
  const uint8_t* codes,
114
146
  const uint8_t* LUT,
115
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const Scaler& scaler) {
116
149
  FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
117
150
  FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
118
151
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
119
152
  FAISS_THROW_IF_NOT(nb % bbs == 0);
120
153
 
121
- #define DISPATCH(NQ, BB) \
122
- case NQ * 1000 + BB: \
123
- accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
154
+ #define DISPATCH(NQ, BB) \
155
+ case NQ * 1000 + BB: \
156
+ accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
124
157
  break
125
158
 
126
159
  switch (nq * 1000 + bbs / 32) {
@@ -141,20 +174,28 @@ void pq4_accumulate_loop(
141
174
 
142
175
  // explicit template instantiations
143
176
 
144
- #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map) \
145
- template void pq4_accumulate_loop<TH<C, with_id_map>>( \
146
- int, \
147
- size_t, \
148
- int, \
149
- int, \
150
- const uint8_t*, \
151
- const uint8_t*, \
152
- TH<C, with_id_map>&);
153
-
154
- #define INSTANTIATE_3(C, with_id_map) \
155
- INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
156
- INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map) \
157
- INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)
177
+ #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S) \
178
+ template void pq4_accumulate_loop<TH<C, with_id_map>, S>( \
179
+ int, \
180
+ size_t, \
181
+ int, \
182
+ int, \
183
+ const uint8_t*, \
184
+ const uint8_t*, \
185
+ TH<C, with_id_map>&, \
186
+ const S&);
187
+
188
+ using DS = DummyScaler;
189
+ using NS = NormTableScaler;
190
+
191
+ #define INSTANTIATE_3(C, with_id_map) \
192
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \
193
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS) \
194
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS) \
195
+ \
196
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \
197
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS) \
198
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS)
158
199
 
159
200
  using Csi = CMax<uint16_t, int>;
160
201
  INSTANTIATE_3(Csi, false);