faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <algorithm>
|
|
9
|
+
#include <cstdint>
|
|
10
|
+
#include <cstring>
|
|
11
|
+
#include <functional>
|
|
12
|
+
#include <numeric>
|
|
13
|
+
#include <string>
|
|
14
|
+
#include <unordered_map>
|
|
15
|
+
#include <vector>
|
|
16
|
+
|
|
17
|
+
#include <faiss/Index.h>
|
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
|
19
|
+
#include <faiss/impl/kmeans1d.h>
|
|
20
|
+
|
|
21
|
+
namespace faiss {
|
|
22
|
+
|
|
23
|
+
using idx_t = Index::idx_t;
|
|
24
|
+
using LookUpFunc = std::function<float(idx_t, idx_t)>;
|
|
25
|
+
|
|
26
|
+
void reduce(
|
|
27
|
+
const std::vector<idx_t>& rows,
|
|
28
|
+
const std::vector<idx_t>& input_cols,
|
|
29
|
+
const LookUpFunc& lookup,
|
|
30
|
+
std::vector<idx_t>& output_cols) {
|
|
31
|
+
for (idx_t col : input_cols) {
|
|
32
|
+
while (!output_cols.empty()) {
|
|
33
|
+
idx_t row = rows[output_cols.size() - 1];
|
|
34
|
+
float a = lookup(row, col);
|
|
35
|
+
float b = lookup(row, output_cols.back());
|
|
36
|
+
if (a >= b) { // defeated
|
|
37
|
+
break;
|
|
38
|
+
}
|
|
39
|
+
output_cols.pop_back();
|
|
40
|
+
}
|
|
41
|
+
if (output_cols.size() < rows.size()) {
|
|
42
|
+
output_cols.push_back(col);
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
void interpolate(
|
|
48
|
+
const std::vector<idx_t>& rows,
|
|
49
|
+
const std::vector<idx_t>& cols,
|
|
50
|
+
const LookUpFunc& lookup,
|
|
51
|
+
idx_t* argmins) {
|
|
52
|
+
std::unordered_map<idx_t, idx_t> idx_to_col;
|
|
53
|
+
for (idx_t idx = 0; idx < cols.size(); ++idx) {
|
|
54
|
+
idx_to_col[cols[idx]] = idx;
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
idx_t start = 0;
|
|
58
|
+
for (idx_t r = 0; r < rows.size(); r += 2) {
|
|
59
|
+
idx_t row = rows[r];
|
|
60
|
+
idx_t end = cols.size() - 1;
|
|
61
|
+
if (r < rows.size() - 1) {
|
|
62
|
+
idx_t idx = argmins[rows[r + 1]];
|
|
63
|
+
end = idx_to_col[idx];
|
|
64
|
+
}
|
|
65
|
+
idx_t argmin = cols[start];
|
|
66
|
+
float min = lookup(row, argmin);
|
|
67
|
+
for (idx_t c = start + 1; c <= end; c++) {
|
|
68
|
+
float value = lookup(row, cols[c]);
|
|
69
|
+
if (value < min) {
|
|
70
|
+
argmin = cols[c];
|
|
71
|
+
min = value;
|
|
72
|
+
}
|
|
73
|
+
}
|
|
74
|
+
argmins[row] = argmin;
|
|
75
|
+
start = end;
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/** SMAWK algo. Find the row minima of a monotone matrix.
|
|
80
|
+
*
|
|
81
|
+
* References:
|
|
82
|
+
* 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf
|
|
83
|
+
* 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405
|
|
84
|
+
* 3. https://github.com/dstein64/kmeans1d
|
|
85
|
+
*/
|
|
86
|
+
void smawk_impl(
|
|
87
|
+
const std::vector<idx_t>& rows,
|
|
88
|
+
const std::vector<idx_t>& input_cols,
|
|
89
|
+
const LookUpFunc& lookup,
|
|
90
|
+
idx_t* argmins) {
|
|
91
|
+
if (rows.size() == 0) {
|
|
92
|
+
return;
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
/**********************************
|
|
96
|
+
* REDUCE
|
|
97
|
+
**********************************/
|
|
98
|
+
auto ptr = &input_cols;
|
|
99
|
+
std::vector<idx_t> survived_cols; // survived columns
|
|
100
|
+
if (rows.size() < input_cols.size()) {
|
|
101
|
+
reduce(rows, input_cols, lookup, survived_cols);
|
|
102
|
+
ptr = &survived_cols;
|
|
103
|
+
}
|
|
104
|
+
auto& cols = *ptr; // avoid memory copy
|
|
105
|
+
|
|
106
|
+
/**********************************
|
|
107
|
+
* INTERPOLATE
|
|
108
|
+
**********************************/
|
|
109
|
+
|
|
110
|
+
// call recursively on odd-indexed rows
|
|
111
|
+
std::vector<idx_t> odd_rows;
|
|
112
|
+
for (idx_t i = 1; i < rows.size(); i += 2) {
|
|
113
|
+
odd_rows.push_back(rows[i]);
|
|
114
|
+
}
|
|
115
|
+
smawk_impl(odd_rows, cols, lookup, argmins);
|
|
116
|
+
|
|
117
|
+
// interpolate the even-indexed rows
|
|
118
|
+
interpolate(rows, cols, lookup, argmins);
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
void smawk(
|
|
122
|
+
const idx_t nrows,
|
|
123
|
+
const idx_t ncols,
|
|
124
|
+
const LookUpFunc& lookup,
|
|
125
|
+
idx_t* argmins) {
|
|
126
|
+
std::vector<idx_t> rows(nrows);
|
|
127
|
+
std::vector<idx_t> cols(ncols);
|
|
128
|
+
std::iota(std::begin(rows), std::end(rows), 0);
|
|
129
|
+
std::iota(std::begin(cols), std::end(cols), 0);
|
|
130
|
+
|
|
131
|
+
smawk_impl(rows, cols, lookup, argmins);
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
void smawk(
|
|
135
|
+
const idx_t nrows,
|
|
136
|
+
const idx_t ncols,
|
|
137
|
+
const float* x,
|
|
138
|
+
idx_t* argmins) {
|
|
139
|
+
auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; };
|
|
140
|
+
smawk(nrows, ncols, lookup, argmins);
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
namespace {
|
|
144
|
+
|
|
145
|
+
class CostCalculator {
|
|
146
|
+
// The reuslt would be inaccurate if we use float
|
|
147
|
+
std::vector<double> cumsum;
|
|
148
|
+
std::vector<double> cumsum2;
|
|
149
|
+
|
|
150
|
+
public:
|
|
151
|
+
CostCalculator(const std::vector<float>& vec, idx_t n) {
|
|
152
|
+
cumsum.push_back(0.0);
|
|
153
|
+
cumsum2.push_back(0.0);
|
|
154
|
+
for (idx_t i = 0; i < n; ++i) {
|
|
155
|
+
float x = vec[i];
|
|
156
|
+
cumsum.push_back(x + cumsum[i]);
|
|
157
|
+
cumsum2.push_back(x * x + cumsum2[i]);
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
float operator()(idx_t i, idx_t j) {
|
|
162
|
+
if (j < i) {
|
|
163
|
+
return 0.0f;
|
|
164
|
+
}
|
|
165
|
+
auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
|
|
166
|
+
auto result = cumsum2[j + 1] - cumsum2[i];
|
|
167
|
+
result += (j - i + 1) * (mu * mu);
|
|
168
|
+
result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
|
|
169
|
+
return float(result);
|
|
170
|
+
}
|
|
171
|
+
};
|
|
172
|
+
|
|
173
|
+
template <class T>
|
|
174
|
+
class Matrix {
|
|
175
|
+
std::vector<T> data;
|
|
176
|
+
idx_t nrows;
|
|
177
|
+
idx_t ncols;
|
|
178
|
+
|
|
179
|
+
public:
|
|
180
|
+
Matrix(idx_t nrows, idx_t ncols) {
|
|
181
|
+
this->nrows = nrows;
|
|
182
|
+
this->ncols = ncols;
|
|
183
|
+
data.resize(nrows * ncols);
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
inline T& at(idx_t i, idx_t j) {
|
|
187
|
+
return data[i * ncols + j];
|
|
188
|
+
}
|
|
189
|
+
};
|
|
190
|
+
|
|
191
|
+
} // anonymous namespace
|
|
192
|
+
|
|
193
|
+
double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
|
|
194
|
+
FAISS_THROW_IF_NOT(n >= nclusters);
|
|
195
|
+
|
|
196
|
+
// corner case
|
|
197
|
+
if (n == nclusters) {
|
|
198
|
+
memcpy(centroids, x, n * sizeof(*x));
|
|
199
|
+
return 0.0f;
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
/***************************************************
|
|
203
|
+
* sort in ascending order, O(NlogN) in time
|
|
204
|
+
***************************************************/
|
|
205
|
+
std::vector<float> arr(x, x + n);
|
|
206
|
+
std::sort(arr.begin(), arr.end());
|
|
207
|
+
|
|
208
|
+
/***************************************************
|
|
209
|
+
dynamic programming algorithm
|
|
210
|
+
|
|
211
|
+
Reference: https://arxiv.org/abs/1701.07204
|
|
212
|
+
-------------------------------
|
|
213
|
+
|
|
214
|
+
Assume x is already sorted in ascending order.
|
|
215
|
+
|
|
216
|
+
N: number of points
|
|
217
|
+
K: number of clusters
|
|
218
|
+
|
|
219
|
+
CC(i, j): the cost of grouping xi,...,xj into one cluster
|
|
220
|
+
D[k][m]: the cost of optimally clustering x1,...,xm into k clusters
|
|
221
|
+
T[k][m]: the start index of the k-th cluster
|
|
222
|
+
|
|
223
|
+
The DP process is as follow:
|
|
224
|
+
D[k][m] = min_i D[k − 1][i − 1] + CC(i, m)
|
|
225
|
+
T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m)
|
|
226
|
+
|
|
227
|
+
This could be solved in O(KN^2) time and O(KN) space.
|
|
228
|
+
|
|
229
|
+
To further reduce the time complexity, we use SMAWK algo to
|
|
230
|
+
solve the argmin problem as follow:
|
|
231
|
+
|
|
232
|
+
For each k:
|
|
233
|
+
C[m][i] = D[k − 1][i − 1] + CC(i, m)
|
|
234
|
+
|
|
235
|
+
Here C is a n x n totally monotone matrix.
|
|
236
|
+
We could find the row minima by SMAWK in O(N) time.
|
|
237
|
+
|
|
238
|
+
Now the time complexity is reduced from O(kN^2) to O(KN).
|
|
239
|
+
****************************************************/
|
|
240
|
+
|
|
241
|
+
CostCalculator CC(arr, n);
|
|
242
|
+
Matrix<float> D(nclusters, n);
|
|
243
|
+
Matrix<idx_t> T(nclusters, n);
|
|
244
|
+
|
|
245
|
+
for (idx_t m = 0; m < n; m++) {
|
|
246
|
+
D.at(0, m) = CC(0, m);
|
|
247
|
+
T.at(0, m) = 0;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
std::vector<idx_t> indices(nclusters, 0);
|
|
251
|
+
|
|
252
|
+
for (idx_t k = 1; k < nclusters; ++k) {
|
|
253
|
+
// we define C here
|
|
254
|
+
auto C = [&D, &CC, &k](idx_t m, idx_t i) {
|
|
255
|
+
if (i == 0) {
|
|
256
|
+
return CC(i, m);
|
|
257
|
+
}
|
|
258
|
+
idx_t col = std::min(m, i - 1);
|
|
259
|
+
return D.at(k - 1, col) + CC(i, m);
|
|
260
|
+
};
|
|
261
|
+
|
|
262
|
+
std::vector<idx_t> argmins(n); // argmin of each row
|
|
263
|
+
smawk(n, n, C, argmins.data());
|
|
264
|
+
for (idx_t m = 0; m < argmins.size(); m++) {
|
|
265
|
+
idx_t idx = argmins[m];
|
|
266
|
+
D.at(k, m) = C(m, idx);
|
|
267
|
+
T.at(k, m) = idx;
|
|
268
|
+
}
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
/***************************************************
|
|
272
|
+
compute centroids by backtracking
|
|
273
|
+
|
|
274
|
+
T[K - 1][T[K][N] - 1] T[K][N] N
|
|
275
|
+
--------------|------------------------|-----------|
|
|
276
|
+
| cluster K - 1 | cluster K |
|
|
277
|
+
|
|
278
|
+
****************************************************/
|
|
279
|
+
|
|
280
|
+
// for imbalance factor
|
|
281
|
+
double tot = 0.0, uf = 0.0;
|
|
282
|
+
|
|
283
|
+
idx_t end = n;
|
|
284
|
+
for (idx_t k = nclusters - 1; k >= 0; k--) {
|
|
285
|
+
idx_t start = T.at(k, end - 1);
|
|
286
|
+
float sum = std::accumulate(&arr[start], &arr[end], 0.0f);
|
|
287
|
+
idx_t size = end - start;
|
|
288
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
289
|
+
size > 0, "Cluster %d: size %d", int(k), int(size));
|
|
290
|
+
centroids[k] = sum / size;
|
|
291
|
+
end = start;
|
|
292
|
+
|
|
293
|
+
tot += size;
|
|
294
|
+
uf += size * double(size);
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
uf = uf * nclusters / (tot * tot);
|
|
298
|
+
return uf;
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
} // namespace faiss
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <faiss/Index.h>
|
|
11
|
+
#include <functional>
|
|
12
|
+
|
|
13
|
+
namespace faiss {
|
|
14
|
+
|
|
15
|
+
/** SMAWK algorithm. Find the row minima of a monotone matrix.
|
|
16
|
+
*
|
|
17
|
+
* Expose this for testing.
|
|
18
|
+
*
|
|
19
|
+
* @param nrows number of rows
|
|
20
|
+
* @param ncols number of columns
|
|
21
|
+
* @param x input matrix, size (nrows, ncols)
|
|
22
|
+
* @param argmins argmin of each row
|
|
23
|
+
*/
|
|
24
|
+
void smawk(
|
|
25
|
+
const Index::idx_t nrows,
|
|
26
|
+
const Index::idx_t ncols,
|
|
27
|
+
const float* x,
|
|
28
|
+
Index::idx_t* argmins);
|
|
29
|
+
|
|
30
|
+
/** Exact 1D K-Means by dynamic programming
|
|
31
|
+
*
|
|
32
|
+
* From "Fast Exact k-Means, k-Medians and Bregman Divergence Clustering in 1D"
|
|
33
|
+
* Allan Grønlund, Kasper Green Larsen, Alexander Mathiasen, Jesper Sindahl
|
|
34
|
+
* Nielsen, Stefan Schneider, Mingzhou Song, ArXiV'17
|
|
35
|
+
*
|
|
36
|
+
* Section 2.2
|
|
37
|
+
*
|
|
38
|
+
* https://arxiv.org/abs/1701.07204
|
|
39
|
+
*
|
|
40
|
+
* @param x input 1D array
|
|
41
|
+
* @param n input array length
|
|
42
|
+
* @param nclusters number of clusters
|
|
43
|
+
* @param centroids output centroids, size nclusters
|
|
44
|
+
* @return imbalancce factor
|
|
45
|
+
*/
|
|
46
|
+
double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids);
|
|
47
|
+
|
|
48
|
+
} // namespace faiss
|