faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -0,0 +1,303 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <algorithm>
9
+ #include <cstdint>
10
+ #include <cstring>
11
+ #include <functional>
12
+ #include <numeric>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/Index.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/kmeans1d.h>
20
+
21
+ namespace faiss {
22
+
23
+ using idx_t = Index::idx_t;
24
+ using LookUpFunc = std::function<float(idx_t, idx_t)>;
25
+
26
+ void reduce(
27
+ const std::vector<idx_t>& rows,
28
+ const std::vector<idx_t>& input_cols,
29
+ const LookUpFunc& lookup,
30
+ std::vector<idx_t>& output_cols) {
31
+ for (idx_t col : input_cols) {
32
+ while (!output_cols.empty()) {
33
+ idx_t row = rows[output_cols.size() - 1];
34
+ float a = lookup(row, col);
35
+ float b = lookup(row, output_cols.back());
36
+ if (a >= b) { // defeated
37
+ break;
38
+ }
39
+ output_cols.pop_back();
40
+ }
41
+ if (output_cols.size() < rows.size()) {
42
+ output_cols.push_back(col);
43
+ }
44
+ }
45
+ }
46
+
47
+ void interpolate(
48
+ const std::vector<idx_t>& rows,
49
+ const std::vector<idx_t>& cols,
50
+ const LookUpFunc& lookup,
51
+ idx_t* argmins) {
52
+ std::unordered_map<idx_t, idx_t> idx_to_col;
53
+ for (idx_t idx = 0; idx < cols.size(); ++idx) {
54
+ idx_to_col[cols[idx]] = idx;
55
+ }
56
+
57
+ idx_t start = 0;
58
+ for (idx_t r = 0; r < rows.size(); r += 2) {
59
+ idx_t row = rows[r];
60
+ idx_t end = cols.size() - 1;
61
+ if (r < rows.size() - 1) {
62
+ idx_t idx = argmins[rows[r + 1]];
63
+ end = idx_to_col[idx];
64
+ }
65
+ idx_t argmin = cols[start];
66
+ float min = lookup(row, argmin);
67
+ for (idx_t c = start + 1; c <= end; c++) {
68
+ float value = lookup(row, cols[c]);
69
+ if (value < min) {
70
+ argmin = cols[c];
71
+ min = value;
72
+ }
73
+ }
74
+ argmins[row] = argmin;
75
+ start = end;
76
+ }
77
+ }
78
+
79
+ /** SMAWK algo. Find the row minima of a monotone matrix.
80
+ *
81
+ * References:
82
+ * 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf
83
+ * 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405
84
+ * 3. https://github.com/dstein64/kmeans1d
85
+ */
86
+ void smawk_impl(
87
+ const std::vector<idx_t>& rows,
88
+ const std::vector<idx_t>& input_cols,
89
+ const LookUpFunc& lookup,
90
+ idx_t* argmins) {
91
+ if (rows.size() == 0) {
92
+ return;
93
+ }
94
+
95
+ /**********************************
96
+ * REDUCE
97
+ **********************************/
98
+ auto ptr = &input_cols;
99
+ std::vector<idx_t> survived_cols; // survived columns
100
+ if (rows.size() < input_cols.size()) {
101
+ reduce(rows, input_cols, lookup, survived_cols);
102
+ ptr = &survived_cols;
103
+ }
104
+ auto& cols = *ptr; // avoid memory copy
105
+
106
+ /**********************************
107
+ * INTERPOLATE
108
+ **********************************/
109
+
110
+ // call recursively on odd-indexed rows
111
+ std::vector<idx_t> odd_rows;
112
+ for (idx_t i = 1; i < rows.size(); i += 2) {
113
+ odd_rows.push_back(rows[i]);
114
+ }
115
+ smawk_impl(odd_rows, cols, lookup, argmins);
116
+
117
+ // interpolate the even-indexed rows
118
+ interpolate(rows, cols, lookup, argmins);
119
+ }
120
+
121
+ void smawk(
122
+ const idx_t nrows,
123
+ const idx_t ncols,
124
+ const LookUpFunc& lookup,
125
+ idx_t* argmins) {
126
+ std::vector<idx_t> rows(nrows);
127
+ std::vector<idx_t> cols(ncols);
128
+ std::iota(std::begin(rows), std::end(rows), 0);
129
+ std::iota(std::begin(cols), std::end(cols), 0);
130
+
131
+ smawk_impl(rows, cols, lookup, argmins);
132
+ }
133
+
134
+ void smawk(
135
+ const idx_t nrows,
136
+ const idx_t ncols,
137
+ const float* x,
138
+ idx_t* argmins) {
139
+ auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; };
140
+ smawk(nrows, ncols, lookup, argmins);
141
+ }
142
+
143
+ namespace {
144
+
145
+ class CostCalculator {
146
+ // The result would be inaccurate if we use float
147
+ std::vector<double> cumsum;
148
+ std::vector<double> cumsum2;
149
+
150
+ public:
151
+ CostCalculator(const std::vector<float>& vec, idx_t n) {
152
+ cumsum.push_back(0.0);
153
+ cumsum2.push_back(0.0);
154
+ for (idx_t i = 0; i < n; ++i) {
155
+ float x = vec[i];
156
+ cumsum.push_back(x + cumsum[i]);
157
+ cumsum2.push_back(x * x + cumsum2[i]);
158
+ }
159
+ }
160
+
161
+ float operator()(idx_t i, idx_t j) {
162
+ if (j < i) {
163
+ return 0.0f;
164
+ }
165
+ auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
166
+ auto result = cumsum2[j + 1] - cumsum2[i];
167
+ result += (j - i + 1) * (mu * mu);
168
+ result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
169
+ return float(result);
170
+ }
171
+ };
172
+
173
+ template <class T>
174
+ class Matrix {
175
+ std::vector<T> data;
176
+ idx_t nrows;
177
+ idx_t ncols;
178
+
179
+ public:
180
+ Matrix(idx_t nrows, idx_t ncols) {
181
+ this->nrows = nrows;
182
+ this->ncols = ncols;
183
+ data.resize(nrows * ncols);
184
+ }
185
+
186
+ inline T& at(idx_t i, idx_t j) {
187
+ return data[i * ncols + j];
188
+ }
189
+ };
190
+
191
+ } // anonymous namespace
192
+
193
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
194
+ FAISS_THROW_IF_NOT(n >= nclusters);
195
+
196
+ // corner case
197
+ if (n == nclusters) {
198
+ memcpy(centroids, x, n * sizeof(*x));
199
+ return 0.0f;
200
+ }
201
+
202
+ /***************************************************
203
+ * sort in ascending order, O(NlogN) in time
204
+ ***************************************************/
205
+ std::vector<float> arr(x, x + n);
206
+ std::sort(arr.begin(), arr.end());
207
+
208
+ /***************************************************
209
+ dynamic programming algorithm
210
+
211
+ Reference: https://arxiv.org/abs/1701.07204
212
+ -------------------------------
213
+
214
+ Assume x is already sorted in ascending order.
215
+
216
+ N: number of points
217
+ K: number of clusters
218
+
219
+ CC(i, j): the cost of grouping xi,...,xj into one cluster
220
+ D[k][m]: the cost of optimally clustering x1,...,xm into k clusters
221
+ T[k][m]: the start index of the k-th cluster
222
+
223
+ The DP process is as follow:
224
+ D[k][m] = min_i D[k − 1][i − 1] + CC(i, m)
225
+ T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m)
226
+
227
+ This could be solved in O(KN^2) time and O(KN) space.
228
+
229
+ To further reduce the time complexity, we use SMAWK algo to
230
+ solve the argmin problem as follow:
231
+
232
+ For each k:
233
+ C[m][i] = D[k − 1][i − 1] + CC(i, m)
234
+
235
+ Here C is a n x n totally monotone matrix.
236
+ We could find the row minima by SMAWK in O(N) time.
237
+
238
+ Now the time complexity is reduced from O(kN^2) to O(KN).
239
+ ****************************************************/
240
+
241
+ CostCalculator CC(arr, n);
242
+ Matrix<float> D(nclusters, n);
243
+ Matrix<idx_t> T(nclusters, n);
244
+
245
+ for (idx_t m = 0; m < n; m++) {
246
+ D.at(0, m) = CC(0, m);
247
+ T.at(0, m) = 0;
248
+ }
249
+
250
+ std::vector<idx_t> indices(nclusters, 0);
251
+
252
+ for (idx_t k = 1; k < nclusters; ++k) {
253
+ // we define C here
254
+ auto C = [&D, &CC, &k](idx_t m, idx_t i) {
255
+ if (i == 0) {
256
+ return CC(i, m);
257
+ }
258
+ idx_t col = std::min(m, i - 1);
259
+ return D.at(k - 1, col) + CC(i, m);
260
+ };
261
+
262
+ std::vector<idx_t> argmins(n); // argmin of each row
263
+ smawk(n, n, C, argmins.data());
264
+ for (idx_t m = 0; m < argmins.size(); m++) {
265
+ idx_t idx = argmins[m];
266
+ D.at(k, m) = C(m, idx);
267
+ T.at(k, m) = idx;
268
+ }
269
+ }
270
+
271
+ /***************************************************
272
+ compute centroids by backtracking
273
+
274
+ T[K - 1][T[K][N] - 1] T[K][N] N
275
+ --------------|------------------------|-----------|
276
+ | cluster K - 1 | cluster K |
277
+
278
+ ****************************************************/
279
+
280
+ // for imbalance factor
281
+ double tot = 0.0;
282
+ double uf = 0.0;
283
+
284
+ idx_t end = n;
285
+ for (idx_t k = nclusters - 1; k >= 0; k--) {
286
+ const idx_t start = T.at(k, end - 1);
287
+ const float sum =
288
+ std::accumulate(arr.data() + start, arr.data() + end, 0.0f);
289
+ const idx_t size = end - start;
290
+ FAISS_THROW_IF_NOT_FMT(
291
+ size > 0, "Cluster %d: size %d", int(k), int(size));
292
+ centroids[k] = sum / size;
293
+ end = start;
294
+
295
+ tot += size;
296
+ uf += size * double(size);
297
+ }
298
+
299
+ uf = uf * nclusters / (tot * tot);
300
+ return uf;
301
+ }
302
+
303
+ } // namespace faiss
@@ -0,0 +1,48 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <functional>
12
+
13
+ namespace faiss {
14
+
15
+ /** SMAWK algorithm. Find the row minima of a monotone matrix.
16
+ *
17
+ * Expose this for testing.
18
+ *
19
+ * @param nrows number of rows
20
+ * @param ncols number of columns
21
+ * @param x input matrix, size (nrows, ncols)
22
+ * @param argmins argmin of each row
23
+ */
24
+ void smawk(
25
+ const Index::idx_t nrows,
26
+ const Index::idx_t ncols,
27
+ const float* x,
28
+ Index::idx_t* argmins);
29
+
30
+ /** Exact 1D K-Means by dynamic programming
31
+ *
32
+ * From "Fast Exact k-Means, k-Medians and Bregman Divergence Clustering in 1D"
33
+ * Allan Grønlund, Kasper Green Larsen, Alexander Mathiasen, Jesper Sindahl
34
+ * Nielsen, Stefan Schneider, Mingzhou Song, ArXiV'17
35
+ *
36
+ * Section 2.2
37
+ *
38
+ * https://arxiv.org/abs/1701.07204
39
+ *
40
+ * @param x input 1D array
41
+ * @param n input array length
42
+ * @param nclusters number of clusters
43
+ * @param centroids output centroids, size nclusters
44
+ * @return imbalance factor
45
+ */
46
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids);
47
+
48
+ } // namespace faiss
@@ -122,30 +122,70 @@ void pq4_pack_codes_range(
122
122
  }
123
123
  }
124
124
 
125
+ namespace {
126
+
127
+ // get the specific address of the vector inside a block
128
+ // shift is used to determine if the value is saved in bits 0..3 (false) or
129
+ // bits 4..7 (true)
130
+ uint8_t get_vector_specific_address(
131
+ size_t bbs,
132
+ size_t vector_id,
133
+ size_t sq,
134
+ bool& shift) {
135
+ // get the vector_id inside the block
136
+ vector_id = vector_id % bbs;
137
+ shift = vector_id > 15;
138
+ vector_id = vector_id & 15;
139
+
140
+ // get the address of the vector in sq
141
+ size_t address;
142
+ if (vector_id < 8) {
143
+ address = vector_id << 1;
144
+ } else {
145
+ address = ((vector_id - 8) << 1) + 1;
146
+ }
147
+ if (sq & 1) {
148
+ address += 16;
149
+ }
150
+ return (sq >> 1) * bbs + address;
151
+ }
152
+
153
+ } // anonymous namespace
154
+
125
155
  uint8_t pq4_get_packed_element(
126
156
  const uint8_t* data,
127
157
  size_t bbs,
128
158
  size_t nsq,
129
- size_t i,
159
+ size_t vector_id,
130
160
  size_t sq) {
131
161
  // move to correct bbs-sized block
132
- data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
133
- sq = sq & 1;
134
- i = i % bbs;
135
-
136
- // another step
137
- data += (i / 32) * 32;
138
- i = i % 32;
139
-
140
- if (sq == 1) {
141
- data += 16;
162
+ // number of blocks * block size
163
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
164
+ bool shift;
165
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
166
+ if (shift) {
167
+ return data[address] >> 4;
168
+ } else {
169
+ return data[address] & 15;
142
170
  }
143
- const uint8_t iperm0[16] = {
144
- 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
145
- if (i < 16) {
146
- return data[iperm0[i]] & 15;
171
+ }
172
+
173
+ void pq4_set_packed_element(
174
+ uint8_t* data,
175
+ uint8_t code,
176
+ size_t bbs,
177
+ size_t nsq,
178
+ size_t vector_id,
179
+ size_t sq) {
180
+ // move to correct bbs-sized block
181
+ // number of blocks * block size
182
+ data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
183
+ bool shift;
184
+ size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
185
+ if (shift) {
186
+ data[address] = (code << 4) | (data[address] & 15);
147
187
  } else {
148
- return data[iperm0[i - 16]] >> 4;
188
+ data[address] = code | (data[address] & ~15);
149
189
  }
150
190
  }
151
191
 
@@ -26,7 +26,7 @@ namespace faiss {
26
26
  * The unused bytes are set to 0.
27
27
  *
28
28
  * @param codes input codes, size (ntotal, ceil(M / 2))
29
- * @param nototal number of input codes
29
+ * @param ntotal number of input codes
30
30
  * @param nb output number of codes (ntotal rounded up to a multiple of
31
31
  * bbs)
32
32
 * @param M2 number of sub-quantizers (=M rounded up to a multiple of 2)
@@ -61,14 +61,27 @@ void pq4_pack_codes_range(
61
61
 
62
62
  /** get a single element from a packed codes table
63
63
  *
64
- * @param i vector id
64
+ * @param vector_id vector id
65
65
  * @param sq subquantizer (< nsq)
66
66
  */
67
67
  uint8_t pq4_get_packed_element(
68
68
  const uint8_t* data,
69
69
  size_t bbs,
70
70
  size_t nsq,
71
- size_t i,
71
+ size_t vector_id,
72
+ size_t sq);
73
+
74
+ /** set a single element "code" into a packed codes table
75
+ *
76
+ * @param vector_id vector id
77
+ * @param sq subquantizer (< nsq)
78
+ */
79
+ void pq4_set_packed_element(
80
+ uint8_t* data,
81
+ uint8_t code,
82
+ size_t bbs,
83
+ size_t nsq,
84
+ size_t vector_id,
72
85
  size_t sq);
73
86
 
74
87
  /** Pack Look-up table for consumption by the kernel.
@@ -88,8 +101,9 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
88
101
 * @param nsq number of sub-quantizers (multiple of 2)
89
102
  * @param codes packed codes array
90
103
  * @param LUT packed look-up table
104
+ * @param scaler scaler to scale the encoded norm
91
105
  */
92
- template <class ResultHandler>
106
+ template <class ResultHandler, class Scaler>
93
107
  void pq4_accumulate_loop(
94
108
  int nq,
95
109
  size_t nb,
@@ -97,7 +111,8 @@ void pq4_accumulate_loop(
97
111
  int nsq,
98
112
  const uint8_t* codes,
99
113
  const uint8_t* LUT,
100
- ResultHandler& res);
114
+ ResultHandler& res,
115
+ const Scaler& scaler);
101
116
 
102
117
  /* qbs versions, supported only for bbs=32.
103
118
  *
@@ -141,20 +156,22 @@ int pq4_pack_LUT_qbs_q_map(
141
156
 
142
157
  /** Run accumulation loop.
143
158
  *
144
- * @param qbs 4-bit encded number of queries
159
+ * @param qbs 4-bit encoded number of queries
145
160
  * @param nb number of database codes (mutliple of bbs)
146
161
  * @param nsq number of sub-quantizers
147
162
  * @param codes encoded database vectors (packed)
148
163
  * @param LUT look-up table (packed)
149
164
 * @param res call-back for the results
165
+ * @param scaler scaler to scale the encoded norm
150
166
  */
151
- template <class ResultHandler>
167
+ template <class ResultHandler, class Scaler>
152
168
  void pq4_accumulate_loop_qbs(
153
169
  int qbs,
154
170
  size_t nb,
155
171
  int nsq,
156
172
  const uint8_t* codes,
157
173
  const uint8_t* LUT,
158
- ResultHandler& res);
174
+ ResultHandler& res,
175
+ const Scaler& scaler);
159
176
 
160
177
  } // namespace faiss
@@ -8,6 +8,7 @@
8
8
  #include <faiss/impl/pq4_fast_scan.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/LookupTableScaler.h>
11
12
  #include <faiss/impl/simd_result_handlers.h>
12
13
 
13
14
  namespace faiss {
@@ -26,12 +27,13 @@ namespace {
26
27
  * writes results in a ResultHandler
27
28
  */
28
29
 
29
- template <int NQ, int BB, class ResultHandler>
30
+ template <int NQ, int BB, class ResultHandler, class Scaler>
30
31
  void kernel_accumulate_block(
31
32
  int nsq,
32
33
  const uint8_t* codes,
33
34
  const uint8_t* LUT,
34
- ResultHandler& res) {
35
+ ResultHandler& res,
36
+ const Scaler& scaler) {
35
37
  // distance accumulators
36
38
  simd16uint16 accu[NQ][BB][4];
37
39
 
@@ -44,7 +46,7 @@ void kernel_accumulate_block(
44
46
  }
45
47
  }
46
48
 
47
- for (int sq = 0; sq < nsq; sq += 2) {
49
+ for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
48
50
  simd32uint8 lut_cache[NQ];
49
51
  for (int q = 0; q < NQ; q++) {
50
52
  lut_cache[q] = simd32uint8(LUT);
@@ -72,6 +74,35 @@ void kernel_accumulate_block(
72
74
  }
73
75
  }
74
76
 
77
+ for (int sq = 0; sq < scaler.nscale; sq += 2) {
78
+ simd32uint8 lut_cache[NQ];
79
+ for (int q = 0; q < NQ; q++) {
80
+ lut_cache[q] = simd32uint8(LUT);
81
+ LUT += 32;
82
+ }
83
+
84
+ for (int b = 0; b < BB; b++) {
85
+ simd32uint8 c = simd32uint8(codes);
86
+ codes += 32;
87
+ simd32uint8 mask(15);
88
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
89
+ simd32uint8 clo = c & mask;
90
+
91
+ for (int q = 0; q < NQ; q++) {
92
+ simd32uint8 lut = lut_cache[q];
93
+
94
+ simd32uint8 res0 = scaler.lookup(lut, clo);
95
+ accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
96
+ accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
97
+
98
+ simd32uint8 res1 = scaler.lookup(lut, chi);
99
+ accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
100
+ accu[q][b][3] +=
101
+ scaler.scale_hi(res1); // handle vectors 24..31
102
+ }
103
+ }
104
+ }
105
+
75
106
  for (int q = 0; q < NQ; q++) {
76
107
  for (int b = 0; b < BB; b++) {
77
108
  accu[q][b][0] -= accu[q][b][1] << 8;
@@ -85,17 +116,18 @@ void kernel_accumulate_block(
85
116
  }
86
117
  }
87
118
 
88
- template <int NQ, int BB, class ResultHandler>
119
+ template <int NQ, int BB, class ResultHandler, class Scaler>
89
120
  void accumulate_fixed_blocks(
90
121
  size_t nb,
91
122
  int nsq,
92
123
  const uint8_t* codes,
93
124
  const uint8_t* LUT,
94
- ResultHandler& res) {
125
+ ResultHandler& res,
126
+ const Scaler& scaler) {
95
127
  constexpr int bbs = 32 * BB;
96
128
  for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
97
129
  FixedStorageHandler<NQ, 2 * BB> res2;
98
- kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
130
+ kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
99
131
  res.set_block_origin(0, j0);
100
132
  res2.to_other_handler(res);
101
133
  codes += bbs * nsq / 2;
@@ -104,7 +136,7 @@ void accumulate_fixed_blocks(
104
136
 
105
137
  } // anonymous namespace
106
138
 
107
- template <class ResultHandler>
139
+ template <class ResultHandler, class Scaler>
108
140
  void pq4_accumulate_loop(
109
141
  int nq,
110
142
  size_t nb,
@@ -112,15 +144,16 @@ void pq4_accumulate_loop(
112
144
  int nsq,
113
145
  const uint8_t* codes,
114
146
  const uint8_t* LUT,
115
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const Scaler& scaler) {
116
149
  FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
117
150
  FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
118
151
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
119
152
  FAISS_THROW_IF_NOT(nb % bbs == 0);
120
153
 
121
- #define DISPATCH(NQ, BB) \
122
- case NQ * 1000 + BB: \
123
- accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
154
+ #define DISPATCH(NQ, BB) \
155
+ case NQ * 1000 + BB: \
156
+ accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
124
157
  break
125
158
 
126
159
  switch (nq * 1000 + bbs / 32) {
@@ -141,20 +174,28 @@ void pq4_accumulate_loop(
141
174
 
142
175
  // explicit template instantiations
143
176
 
144
- #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map) \
145
- template void pq4_accumulate_loop<TH<C, with_id_map>>( \
146
- int, \
147
- size_t, \
148
- int, \
149
- int, \
150
- const uint8_t*, \
151
- const uint8_t*, \
152
- TH<C, with_id_map>&);
153
-
154
- #define INSTANTIATE_3(C, with_id_map) \
155
- INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
156
- INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map) \
157
- INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)
177
+ #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S) \
178
+ template void pq4_accumulate_loop<TH<C, with_id_map>, S>( \
179
+ int, \
180
+ size_t, \
181
+ int, \
182
+ int, \
183
+ const uint8_t*, \
184
+ const uint8_t*, \
185
+ TH<C, with_id_map>&, \
186
+ const S&);
187
+
188
+ using DS = DummyScaler;
189
+ using NS = NormTableScaler;
190
+
191
+ #define INSTANTIATE_3(C, with_id_map) \
192
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \
193
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS) \
194
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS) \
195
+ \
196
+ INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \
197
+ INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS) \
198
+ INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS)
158
199
 
159
200
  using Csi = CMax<uint16_t, int>;
160
201
  INSTANTIATE_3(Csi, false);