faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,301 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <algorithm>
9
+ #include <cstdint>
10
+ #include <cstring>
11
+ #include <functional>
12
+ #include <numeric>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/Index.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/kmeans1d.h>
20
+
21
+ namespace faiss {
22
+
23
+ using idx_t = Index::idx_t;
24
+ using LookUpFunc = std::function<float(idx_t, idx_t)>;
25
+
26
+ void reduce(
27
+ const std::vector<idx_t>& rows,
28
+ const std::vector<idx_t>& input_cols,
29
+ const LookUpFunc& lookup,
30
+ std::vector<idx_t>& output_cols) {
31
+ for (idx_t col : input_cols) {
32
+ while (!output_cols.empty()) {
33
+ idx_t row = rows[output_cols.size() - 1];
34
+ float a = lookup(row, col);
35
+ float b = lookup(row, output_cols.back());
36
+ if (a >= b) { // defeated
37
+ break;
38
+ }
39
+ output_cols.pop_back();
40
+ }
41
+ if (output_cols.size() < rows.size()) {
42
+ output_cols.push_back(col);
43
+ }
44
+ }
45
+ }
46
+
47
+ void interpolate(
48
+ const std::vector<idx_t>& rows,
49
+ const std::vector<idx_t>& cols,
50
+ const LookUpFunc& lookup,
51
+ idx_t* argmins) {
52
+ std::unordered_map<idx_t, idx_t> idx_to_col;
53
+ for (idx_t idx = 0; idx < cols.size(); ++idx) {
54
+ idx_to_col[cols[idx]] = idx;
55
+ }
56
+
57
+ idx_t start = 0;
58
+ for (idx_t r = 0; r < rows.size(); r += 2) {
59
+ idx_t row = rows[r];
60
+ idx_t end = cols.size() - 1;
61
+ if (r < rows.size() - 1) {
62
+ idx_t idx = argmins[rows[r + 1]];
63
+ end = idx_to_col[idx];
64
+ }
65
+ idx_t argmin = cols[start];
66
+ float min = lookup(row, argmin);
67
+ for (idx_t c = start + 1; c <= end; c++) {
68
+ float value = lookup(row, cols[c]);
69
+ if (value < min) {
70
+ argmin = cols[c];
71
+ min = value;
72
+ }
73
+ }
74
+ argmins[row] = argmin;
75
+ start = end;
76
+ }
77
+ }
78
+
79
+ /** SMAWK algo. Find the row minima of a monotone matrix.
80
+ *
81
+ * References:
82
+ * 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf
83
+ * 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405
84
+ * 3. https://github.com/dstein64/kmeans1d
85
+ */
86
+ void smawk_impl(
87
+ const std::vector<idx_t>& rows,
88
+ const std::vector<idx_t>& input_cols,
89
+ const LookUpFunc& lookup,
90
+ idx_t* argmins) {
91
+ if (rows.size() == 0) {
92
+ return;
93
+ }
94
+
95
+ /**********************************
96
+ * REDUCE
97
+ **********************************/
98
+ auto ptr = &input_cols;
99
+ std::vector<idx_t> survived_cols; // survived columns
100
+ if (rows.size() < input_cols.size()) {
101
+ reduce(rows, input_cols, lookup, survived_cols);
102
+ ptr = &survived_cols;
103
+ }
104
+ auto& cols = *ptr; // avoid memory copy
105
+
106
+ /**********************************
107
+ * INTERPOLATE
108
+ **********************************/
109
+
110
+ // call recursively on odd-indexed rows
111
+ std::vector<idx_t> odd_rows;
112
+ for (idx_t i = 1; i < rows.size(); i += 2) {
113
+ odd_rows.push_back(rows[i]);
114
+ }
115
+ smawk_impl(odd_rows, cols, lookup, argmins);
116
+
117
+ // interpolate the even-indexed rows
118
+ interpolate(rows, cols, lookup, argmins);
119
+ }
120
+
121
+ void smawk(
122
+ const idx_t nrows,
123
+ const idx_t ncols,
124
+ const LookUpFunc& lookup,
125
+ idx_t* argmins) {
126
+ std::vector<idx_t> rows(nrows);
127
+ std::vector<idx_t> cols(ncols);
128
+ std::iota(std::begin(rows), std::end(rows), 0);
129
+ std::iota(std::begin(cols), std::end(cols), 0);
130
+
131
+ smawk_impl(rows, cols, lookup, argmins);
132
+ }
133
+
134
+ void smawk(
135
+ const idx_t nrows,
136
+ const idx_t ncols,
137
+ const float* x,
138
+ idx_t* argmins) {
139
+ auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; };
140
+ smawk(nrows, ncols, lookup, argmins);
141
+ }
142
+
143
+ namespace {
144
+
145
+ class CostCalculator {
146
+ // The reuslt would be inaccurate if we use float
147
+ std::vector<double> cumsum;
148
+ std::vector<double> cumsum2;
149
+
150
+ public:
151
+ CostCalculator(const std::vector<float>& vec, idx_t n) {
152
+ cumsum.push_back(0.0);
153
+ cumsum2.push_back(0.0);
154
+ for (idx_t i = 0; i < n; ++i) {
155
+ float x = vec[i];
156
+ cumsum.push_back(x + cumsum[i]);
157
+ cumsum2.push_back(x * x + cumsum2[i]);
158
+ }
159
+ }
160
+
161
+ float operator()(idx_t i, idx_t j) {
162
+ if (j < i) {
163
+ return 0.0f;
164
+ }
165
+ auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
166
+ auto result = cumsum2[j + 1] - cumsum2[i];
167
+ result += (j - i + 1) * (mu * mu);
168
+ result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
169
+ return float(result);
170
+ }
171
+ };
172
+
173
+ template <class T>
174
+ class Matrix {
175
+ std::vector<T> data;
176
+ idx_t nrows;
177
+ idx_t ncols;
178
+
179
+ public:
180
+ Matrix(idx_t nrows, idx_t ncols) {
181
+ this->nrows = nrows;
182
+ this->ncols = ncols;
183
+ data.resize(nrows * ncols);
184
+ }
185
+
186
+ inline T& at(idx_t i, idx_t j) {
187
+ return data[i * ncols + j];
188
+ }
189
+ };
190
+
191
+ } // anonymous namespace
192
+
193
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
194
+ FAISS_THROW_IF_NOT(n >= nclusters);
195
+
196
+ // corner case
197
+ if (n == nclusters) {
198
+ memcpy(centroids, x, n * sizeof(*x));
199
+ return 0.0f;
200
+ }
201
+
202
+ /***************************************************
203
+ * sort in ascending order, O(NlogN) in time
204
+ ***************************************************/
205
+ std::vector<float> arr(x, x + n);
206
+ std::sort(arr.begin(), arr.end());
207
+
208
+ /***************************************************
209
+ dynamic programming algorithm
210
+
211
+ Reference: https://arxiv.org/abs/1701.07204
212
+ -------------------------------
213
+
214
+ Assume x is already sorted in ascending order.
215
+
216
+ N: number of points
217
+ K: number of clusters
218
+
219
+ CC(i, j): the cost of grouping xi,...,xj into one cluster
220
+ D[k][m]: the cost of optimally clustering x1,...,xm into k clusters
221
+ T[k][m]: the start index of the k-th cluster
222
+
223
+ The DP process is as follow:
224
+ D[k][m] = min_i D[k − 1][i − 1] + CC(i, m)
225
+ T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m)
226
+
227
+ This could be solved in O(KN^2) time and O(KN) space.
228
+
229
+ To further reduce the time complexity, we use SMAWK algo to
230
+ solve the argmin problem as follow:
231
+
232
+ For each k:
233
+ C[m][i] = D[k − 1][i − 1] + CC(i, m)
234
+
235
+ Here C is a n x n totally monotone matrix.
236
+ We could find the row minima by SMAWK in O(N) time.
237
+
238
+ Now the time complexity is reduced from O(kN^2) to O(KN).
239
+ ****************************************************/
240
+
241
+ CostCalculator CC(arr, n);
242
+ Matrix<float> D(nclusters, n);
243
+ Matrix<idx_t> T(nclusters, n);
244
+
245
+ for (idx_t m = 0; m < n; m++) {
246
+ D.at(0, m) = CC(0, m);
247
+ T.at(0, m) = 0;
248
+ }
249
+
250
+ std::vector<idx_t> indices(nclusters, 0);
251
+
252
+ for (idx_t k = 1; k < nclusters; ++k) {
253
+ // we define C here
254
+ auto C = [&D, &CC, &k](idx_t m, idx_t i) {
255
+ if (i == 0) {
256
+ return CC(i, m);
257
+ }
258
+ idx_t col = std::min(m, i - 1);
259
+ return D.at(k - 1, col) + CC(i, m);
260
+ };
261
+
262
+ std::vector<idx_t> argmins(n); // argmin of each row
263
+ smawk(n, n, C, argmins.data());
264
+ for (idx_t m = 0; m < argmins.size(); m++) {
265
+ idx_t idx = argmins[m];
266
+ D.at(k, m) = C(m, idx);
267
+ T.at(k, m) = idx;
268
+ }
269
+ }
270
+
271
+ /***************************************************
272
+ compute centroids by backtracking
273
+
274
+ T[K - 1][T[K][N] - 1] T[K][N] N
275
+ --------------|------------------------|-----------|
276
+ | cluster K - 1 | cluster K |
277
+
278
+ ****************************************************/
279
+
280
+ // for imbalance factor
281
+ double tot = 0.0, uf = 0.0;
282
+
283
+ idx_t end = n;
284
+ for (idx_t k = nclusters - 1; k >= 0; k--) {
285
+ idx_t start = T.at(k, end - 1);
286
+ float sum = std::accumulate(&arr[start], &arr[end], 0.0f);
287
+ idx_t size = end - start;
288
+ FAISS_THROW_IF_NOT_FMT(
289
+ size > 0, "Cluster %d: size %d", int(k), int(size));
290
+ centroids[k] = sum / size;
291
+ end = start;
292
+
293
+ tot += size;
294
+ uf += size * double(size);
295
+ }
296
+
297
+ uf = uf * nclusters / (tot * tot);
298
+ return uf;
299
+ }
300
+
301
+ } // namespace faiss
@@ -0,0 +1,48 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <functional>
12
+
13
+ namespace faiss {
14
+
15
+ /** SMAWK algorithm. Find the row minima of a monotone matrix.
16
+ *
17
+ * Expose this for testing.
18
+ *
19
+ * @param nrows number of rows
20
+ * @param ncols number of columns
21
+ * @param x input matrix, size (nrows, ncols)
22
+ * @param argmins argmin of each row
23
+ */
24
+ void smawk(
25
+ const Index::idx_t nrows,
26
+ const Index::idx_t ncols,
27
+ const float* x,
28
+ Index::idx_t* argmins);
29
+
30
+ /** Exact 1D K-Means by dynamic programming
31
+ *
32
+ * From "Fast Exact k-Means, k-Medians and Bregman Divergence Clustering in 1D"
33
+ * Allan Grønlund, Kasper Green Larsen, Alexander Mathiasen, Jesper Sindahl
34
+ * Nielsen, Stefan Schneider, Mingzhou Song, ArXiV'17
35
+ *
36
+ * Section 2.2
37
+ *
38
+ * https://arxiv.org/abs/1701.07204
39
+ *
40
+ * @param x input 1D array
41
+ * @param n input array length
42
+ * @param nclusters number of clusters
43
+ * @param centroids output centroids, size nclusters
44
+ * @return imbalancce factor
45
+ */
46
+ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids);
47
+
48
+ } // namespace faiss