faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
// -*- c++ -*-
|
|
9
9
|
|
|
10
10
|
#include <faiss/Clustering.h>
|
|
11
|
+
#include <faiss/VectorTransform.h>
|
|
11
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
12
13
|
|
|
13
14
|
#include <cinttypes>
|
|
@@ -17,100 +18,101 @@
|
|
|
17
18
|
|
|
18
19
|
#include <omp.h>
|
|
19
20
|
|
|
20
|
-
#include <faiss/utils/utils.h>
|
|
21
|
-
#include <faiss/utils/random.h>
|
|
22
|
-
#include <faiss/utils/distances.h>
|
|
23
|
-
#include <faiss/impl/FaissAssert.h>
|
|
24
21
|
#include <faiss/IndexFlat.h>
|
|
22
|
+
#include <faiss/impl/FaissAssert.h>
|
|
23
|
+
#include <faiss/impl/kmeans1d.h>
|
|
24
|
+
#include <faiss/utils/distances.h>
|
|
25
|
+
#include <faiss/utils/random.h>
|
|
26
|
+
#include <faiss/utils/utils.h>
|
|
25
27
|
|
|
26
28
|
namespace faiss {
|
|
27
29
|
|
|
28
|
-
ClusteringParameters::ClusteringParameters
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
{}
|
|
30
|
+
ClusteringParameters::ClusteringParameters()
|
|
31
|
+
: niter(25),
|
|
32
|
+
nredo(1),
|
|
33
|
+
verbose(false),
|
|
34
|
+
spherical(false),
|
|
35
|
+
int_centroids(false),
|
|
36
|
+
update_index(false),
|
|
37
|
+
frozen_centroids(false),
|
|
38
|
+
min_points_per_centroid(39),
|
|
39
|
+
max_points_per_centroid(256),
|
|
40
|
+
seed(1234),
|
|
41
|
+
decode_block_size(32768) {}
|
|
41
42
|
// 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
|
|
42
43
|
|
|
44
|
+
Clustering::Clustering(int d, int k) : d(d), k(k) {}
|
|
43
45
|
|
|
44
|
-
Clustering::Clustering
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
|
|
48
|
-
ClusteringParameters (cp), d(d), k(k) {}
|
|
46
|
+
Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
|
|
47
|
+
: ClusteringParameters(cp), d(d), k(k) {}
|
|
49
48
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
static double imbalance_factor (int n, int k, int64_t *assign) {
|
|
49
|
+
static double imbalance_factor(int n, int k, int64_t* assign) {
|
|
53
50
|
std::vector<int> hist(k, 0);
|
|
54
51
|
for (int i = 0; i < n; i++)
|
|
55
52
|
hist[assign[i]]++;
|
|
56
53
|
|
|
57
54
|
double tot = 0, uf = 0;
|
|
58
55
|
|
|
59
|
-
for (int i = 0
|
|
56
|
+
for (int i = 0; i < k; i++) {
|
|
60
57
|
tot += hist[i];
|
|
61
|
-
uf += hist[i] * (double)
|
|
58
|
+
uf += hist[i] * (double)hist[i];
|
|
62
59
|
}
|
|
63
60
|
uf = uf * k / (tot * tot);
|
|
64
61
|
|
|
65
62
|
return uf;
|
|
66
63
|
}
|
|
67
64
|
|
|
68
|
-
void Clustering::post_process_centroids
|
|
69
|
-
{
|
|
70
|
-
|
|
65
|
+
void Clustering::post_process_centroids() {
|
|
71
66
|
if (spherical) {
|
|
72
|
-
fvec_renorm_L2
|
|
67
|
+
fvec_renorm_L2(d, k, centroids.data());
|
|
73
68
|
}
|
|
74
69
|
|
|
75
70
|
if (int_centroids) {
|
|
76
71
|
for (size_t i = 0; i < centroids.size(); i++)
|
|
77
|
-
centroids[i] = roundf
|
|
72
|
+
centroids[i] = roundf(centroids[i]);
|
|
78
73
|
}
|
|
79
74
|
}
|
|
80
75
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
76
|
+
void Clustering::train(
|
|
77
|
+
idx_t nx,
|
|
78
|
+
const float* x_in,
|
|
79
|
+
Index& index,
|
|
80
|
+
const float* weights) {
|
|
81
|
+
train_encoded(
|
|
82
|
+
nx,
|
|
83
|
+
reinterpret_cast<const uint8_t*>(x_in),
|
|
84
|
+
nullptr,
|
|
85
|
+
index,
|
|
86
|
+
weights);
|
|
86
87
|
}
|
|
87
88
|
|
|
88
|
-
|
|
89
89
|
namespace {
|
|
90
90
|
|
|
91
91
|
using idx_t = Clustering::idx_t;
|
|
92
92
|
|
|
93
93
|
idx_t subsample_training_set(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
94
|
+
const Clustering& clus,
|
|
95
|
+
idx_t nx,
|
|
96
|
+
const uint8_t* x,
|
|
97
|
+
size_t line_size,
|
|
98
|
+
const float* weights,
|
|
99
|
+
uint8_t** x_out,
|
|
100
|
+
float** weights_out) {
|
|
100
101
|
if (clus.verbose) {
|
|
101
102
|
printf("Sampling a subset of %zd / %" PRId64 " for training\n",
|
|
102
|
-
clus.k * clus.max_points_per_centroid,
|
|
103
|
+
clus.k * clus.max_points_per_centroid,
|
|
104
|
+
nx);
|
|
103
105
|
}
|
|
104
|
-
std::vector<int> perm
|
|
105
|
-
rand_perm
|
|
106
|
+
std::vector<int> perm(nx);
|
|
107
|
+
rand_perm(perm.data(), nx, clus.seed);
|
|
106
108
|
nx = clus.k * clus.max_points_per_centroid;
|
|
107
|
-
uint8_t
|
|
109
|
+
uint8_t* x_new = new uint8_t[nx * line_size];
|
|
108
110
|
*x_out = x_new;
|
|
109
111
|
for (idx_t i = 0; i < nx; i++) {
|
|
110
|
-
memcpy
|
|
112
|
+
memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
|
|
111
113
|
}
|
|
112
114
|
if (weights) {
|
|
113
|
-
float
|
|
115
|
+
float* weights_new = new float[nx];
|
|
114
116
|
for (idx_t i = 0; i < nx; i++) {
|
|
115
117
|
weights_new[i] = weights[perm[i]];
|
|
116
118
|
}
|
|
@@ -134,20 +136,23 @@ idx_t subsample_training_set(
|
|
|
134
136
|
*
|
|
135
137
|
*/
|
|
136
138
|
|
|
137
|
-
void compute_centroids
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
139
|
+
void compute_centroids(
|
|
140
|
+
size_t d,
|
|
141
|
+
size_t k,
|
|
142
|
+
size_t n,
|
|
143
|
+
size_t k_frozen,
|
|
144
|
+
const uint8_t* x,
|
|
145
|
+
const Index* codec,
|
|
146
|
+
const int64_t* assign,
|
|
147
|
+
const float* weights,
|
|
148
|
+
float* hassign,
|
|
149
|
+
float* centroids) {
|
|
145
150
|
k -= k_frozen;
|
|
146
151
|
centroids += k_frozen * d;
|
|
147
152
|
|
|
148
|
-
memset
|
|
153
|
+
memset(centroids, 0, sizeof(*centroids) * d * k);
|
|
149
154
|
|
|
150
|
-
size_t line_size = codec ? codec->sa_code_size() : d * sizeof
|
|
155
|
+
size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
|
|
151
156
|
|
|
152
157
|
#pragma omp parallel
|
|
153
158
|
{
|
|
@@ -157,20 +162,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
157
162
|
// this thread is taking care of centroids c0:c1
|
|
158
163
|
size_t c0 = (k * rank) / nt;
|
|
159
164
|
size_t c1 = (k * (rank + 1)) / nt;
|
|
160
|
-
std::vector<float> decode_buffer
|
|
165
|
+
std::vector<float> decode_buffer(d);
|
|
161
166
|
|
|
162
167
|
for (size_t i = 0; i < n; i++) {
|
|
163
168
|
int64_t ci = assign[i];
|
|
164
|
-
assert
|
|
169
|
+
assert(ci >= 0 && ci < k + k_frozen);
|
|
165
170
|
ci -= k_frozen;
|
|
166
|
-
if (ci >= c0 && ci < c1)
|
|
167
|
-
float
|
|
168
|
-
const float
|
|
171
|
+
if (ci >= c0 && ci < c1) {
|
|
172
|
+
float* c = centroids + ci * d;
|
|
173
|
+
const float* xi;
|
|
169
174
|
if (!codec) {
|
|
170
175
|
xi = reinterpret_cast<const float*>(x + i * line_size);
|
|
171
176
|
} else {
|
|
172
|
-
float
|
|
173
|
-
codec->sa_decode
|
|
177
|
+
float* xif = decode_buffer.data();
|
|
178
|
+
codec->sa_decode(1, x + i * line_size, xif);
|
|
174
179
|
xi = xif;
|
|
175
180
|
}
|
|
176
181
|
if (weights) {
|
|
@@ -187,7 +192,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
187
192
|
}
|
|
188
193
|
}
|
|
189
194
|
}
|
|
190
|
-
|
|
191
195
|
}
|
|
192
196
|
|
|
193
197
|
#pragma omp parallel for
|
|
@@ -196,12 +200,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
196
200
|
continue;
|
|
197
201
|
}
|
|
198
202
|
float norm = 1 / hassign[ci];
|
|
199
|
-
float
|
|
203
|
+
float* c = centroids + ci * d;
|
|
200
204
|
for (size_t j = 0; j < d; j++) {
|
|
201
205
|
c[j] *= norm;
|
|
202
206
|
}
|
|
203
207
|
}
|
|
204
|
-
|
|
205
208
|
}
|
|
206
209
|
|
|
207
210
|
// a bit above machine epsilon for float16
|
|
@@ -214,29 +217,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
214
217
|
*
|
|
215
218
|
* @return nb of spliting operations (larger is worse)
|
|
216
219
|
*/
|
|
217
|
-
int split_clusters
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
220
|
+
int split_clusters(
|
|
221
|
+
size_t d,
|
|
222
|
+
size_t k,
|
|
223
|
+
size_t n,
|
|
224
|
+
size_t k_frozen,
|
|
225
|
+
float* hassign,
|
|
226
|
+
float* centroids) {
|
|
222
227
|
k -= k_frozen;
|
|
223
228
|
centroids += k_frozen * d;
|
|
224
229
|
|
|
225
230
|
/* Take care of void clusters */
|
|
226
231
|
size_t nsplit = 0;
|
|
227
|
-
RandomGenerator rng
|
|
232
|
+
RandomGenerator rng(1234);
|
|
228
233
|
for (size_t ci = 0; ci < k; ci++) {
|
|
229
234
|
if (hassign[ci] == 0) { /* need to redefine a centroid */
|
|
230
235
|
size_t cj;
|
|
231
236
|
for (cj = 0; 1; cj = (cj + 1) % k) {
|
|
232
237
|
/* probability to pick this cluster for split */
|
|
233
|
-
float p = (hassign[cj] - 1.0) / (float)
|
|
234
|
-
float r = rng.rand_float
|
|
238
|
+
float p = (hassign[cj] - 1.0) / (float)(n - k);
|
|
239
|
+
float r = rng.rand_float();
|
|
235
240
|
if (r < p) {
|
|
236
241
|
break; /* found our cluster to be split */
|
|
237
242
|
}
|
|
238
243
|
}
|
|
239
|
-
memcpy
|
|
244
|
+
memcpy(centroids + ci * d,
|
|
245
|
+
centroids + cj * d,
|
|
246
|
+
sizeof(*centroids) * d);
|
|
240
247
|
|
|
241
248
|
/* small symmetric pertubation */
|
|
242
249
|
for (size_t j = 0; j < d; j++) {
|
|
@@ -257,30 +264,35 @@ int split_clusters (size_t d, size_t k, size_t n,
|
|
|
257
264
|
}
|
|
258
265
|
|
|
259
266
|
return nsplit;
|
|
260
|
-
|
|
261
267
|
}
|
|
262
268
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
269
|
+
}; // namespace
|
|
270
|
+
|
|
271
|
+
void Clustering::train_encoded(
|
|
272
|
+
idx_t nx,
|
|
273
|
+
const uint8_t* x_in,
|
|
274
|
+
const Index* codec,
|
|
275
|
+
Index& index,
|
|
276
|
+
const float* weights) {
|
|
277
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
278
|
+
nx >= k,
|
|
279
|
+
"Number of training points (%" PRId64
|
|
280
|
+
") should be at least "
|
|
281
|
+
"as large as number of clusters (%zd)",
|
|
282
|
+
nx,
|
|
283
|
+
k);
|
|
284
|
+
|
|
285
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
286
|
+
(!codec || codec->d == d),
|
|
287
|
+
"Codec dimension %d not the same as data dimension %d",
|
|
288
|
+
int(codec->d),
|
|
289
|
+
int(d));
|
|
290
|
+
|
|
291
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
292
|
+
index.d == d,
|
|
282
293
|
"Index dimension %d not the same as data dimension %d",
|
|
283
|
-
int(index.d),
|
|
294
|
+
int(index.d),
|
|
295
|
+
int(d));
|
|
284
296
|
|
|
285
297
|
double t0 = getmillisecs();
|
|
286
298
|
|
|
@@ -288,67 +300,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
288
300
|
// Check for NaNs in input data. Normally it is the user's
|
|
289
301
|
// responsibility, but it may spare us some hard-to-debug
|
|
290
302
|
// reports.
|
|
291
|
-
const float
|
|
303
|
+
const float* x = reinterpret_cast<const float*>(x_in);
|
|
292
304
|
for (size_t i = 0; i < nx * d; i++) {
|
|
293
|
-
FAISS_THROW_IF_NOT_MSG
|
|
294
|
-
|
|
305
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
306
|
+
std::isfinite(x[i]), "input contains NaN's or Inf's");
|
|
295
307
|
}
|
|
296
308
|
}
|
|
297
309
|
|
|
298
|
-
const uint8_t
|
|
299
|
-
std::unique_ptr<uint8_t
|
|
300
|
-
std::unique_ptr<float
|
|
310
|
+
const uint8_t* x = x_in;
|
|
311
|
+
std::unique_ptr<uint8_t[]> del1;
|
|
312
|
+
std::unique_ptr<float[]> del3;
|
|
301
313
|
size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
|
|
302
314
|
|
|
303
315
|
if (nx > k * max_points_per_centroid) {
|
|
304
|
-
uint8_t
|
|
305
|
-
float
|
|
306
|
-
nx = subsample_training_set
|
|
307
|
-
|
|
308
|
-
del1.reset
|
|
309
|
-
|
|
316
|
+
uint8_t* x_new;
|
|
317
|
+
float* weights_new;
|
|
318
|
+
nx = subsample_training_set(
|
|
319
|
+
*this, nx, x, line_size, weights, &x_new, &weights_new);
|
|
320
|
+
del1.reset(x_new);
|
|
321
|
+
x = x_new;
|
|
322
|
+
del3.reset(weights_new);
|
|
323
|
+
weights = weights_new;
|
|
310
324
|
} else if (nx < k * min_points_per_centroid) {
|
|
311
|
-
fprintf
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
325
|
+
fprintf(stderr,
|
|
326
|
+
"WARNING clustering %" PRId64
|
|
327
|
+
" points to %zd centroids: "
|
|
328
|
+
"please provide at least %" PRId64 " training points\n",
|
|
329
|
+
nx,
|
|
330
|
+
k,
|
|
331
|
+
idx_t(k) * min_points_per_centroid);
|
|
315
332
|
}
|
|
316
333
|
|
|
317
334
|
if (nx == k) {
|
|
318
335
|
// this is a corner case, just copy training set to clusters
|
|
319
336
|
if (verbose) {
|
|
320
|
-
printf("Number of training points (%" PRId64
|
|
321
|
-
"
|
|
337
|
+
printf("Number of training points (%" PRId64
|
|
338
|
+
") same as number of "
|
|
339
|
+
"clusters, just copying\n",
|
|
340
|
+
nx);
|
|
322
341
|
}
|
|
323
|
-
centroids.resize
|
|
342
|
+
centroids.resize(d * k);
|
|
324
343
|
if (!codec) {
|
|
325
|
-
memcpy
|
|
344
|
+
memcpy(centroids.data(), x_in, sizeof(float) * d * k);
|
|
326
345
|
} else {
|
|
327
|
-
codec->sa_decode
|
|
346
|
+
codec->sa_decode(nx, x_in, centroids.data());
|
|
328
347
|
}
|
|
329
348
|
|
|
330
349
|
// one fake iteration...
|
|
331
|
-
ClusteringIterationStats stats = {
|
|
332
|
-
iteration_stats.push_back
|
|
350
|
+
ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
|
|
351
|
+
iteration_stats.push_back(stats);
|
|
333
352
|
|
|
334
353
|
index.reset();
|
|
335
354
|
index.add(k, centroids.data());
|
|
336
355
|
return;
|
|
337
356
|
}
|
|
338
357
|
|
|
339
|
-
|
|
340
358
|
if (verbose) {
|
|
341
|
-
printf("Clustering %" PRId64
|
|
359
|
+
printf("Clustering %" PRId64
|
|
360
|
+
" points in %zdD to %zd clusters, "
|
|
342
361
|
"redo %d times, %d iterations\n",
|
|
343
|
-
nx,
|
|
362
|
+
nx,
|
|
363
|
+
d,
|
|
364
|
+
k,
|
|
365
|
+
nredo,
|
|
366
|
+
niter);
|
|
344
367
|
if (codec) {
|
|
345
368
|
printf("Input data encoded in %zd bytes per vector\n",
|
|
346
|
-
codec->sa_code_size
|
|
369
|
+
codec->sa_code_size());
|
|
347
370
|
}
|
|
348
371
|
}
|
|
349
372
|
|
|
350
|
-
std::unique_ptr<idx_t
|
|
351
|
-
std::unique_ptr<float
|
|
373
|
+
std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
|
|
374
|
+
std::unique_ptr<float[]> dis(new float[nx]);
|
|
352
375
|
|
|
353
376
|
// remember best iteration for redo
|
|
354
377
|
bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
|
|
@@ -358,52 +381,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
358
381
|
|
|
359
382
|
// support input centroids
|
|
360
383
|
|
|
361
|
-
FAISS_THROW_IF_NOT_MSG
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
);
|
|
384
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
385
|
+
centroids.size() % d == 0,
|
|
386
|
+
"size of provided input centroids not a multiple of dimension");
|
|
365
387
|
|
|
366
388
|
size_t n_input_centroids = centroids.size() / d;
|
|
367
389
|
|
|
368
390
|
if (verbose && n_input_centroids > 0) {
|
|
369
|
-
printf
|
|
370
|
-
|
|
391
|
+
printf(" Using %zd centroids provided as input (%sfrozen)\n",
|
|
392
|
+
n_input_centroids,
|
|
393
|
+
frozen_centroids ? "" : "not ");
|
|
371
394
|
}
|
|
372
395
|
|
|
373
396
|
double t_search_tot = 0;
|
|
374
397
|
if (verbose) {
|
|
375
|
-
printf(" Preprocessing in %.2f s\n",
|
|
376
|
-
(getmillisecs() - t0) / 1000.);
|
|
398
|
+
printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
|
|
377
399
|
}
|
|
378
400
|
t0 = getmillisecs();
|
|
379
401
|
|
|
380
402
|
// temporary buffer to decode vectors during the optimization
|
|
381
|
-
std::vector<float> decode_buffer
|
|
382
|
-
(codec ? d * decode_block_size : 0);
|
|
403
|
+
std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
|
|
383
404
|
|
|
384
405
|
for (int redo = 0; redo < nredo; redo++) {
|
|
385
|
-
|
|
386
406
|
if (verbose && nredo > 1) {
|
|
387
407
|
printf("Outer iteration %d / %d\n", redo, nredo);
|
|
388
408
|
}
|
|
389
409
|
|
|
390
410
|
// initialize (remaining) centroids with random points from the dataset
|
|
391
|
-
centroids.resize
|
|
392
|
-
std::vector<int> perm
|
|
411
|
+
centroids.resize(d * k);
|
|
412
|
+
std::vector<int> perm(nx);
|
|
393
413
|
|
|
394
|
-
rand_perm
|
|
414
|
+
rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
|
|
395
415
|
|
|
396
416
|
if (!codec) {
|
|
397
|
-
for (int i = n_input_centroids; i < k
|
|
398
|
-
memcpy
|
|
417
|
+
for (int i = n_input_centroids; i < k; i++) {
|
|
418
|
+
memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
|
|
399
419
|
}
|
|
400
420
|
} else {
|
|
401
|
-
for (int i = n_input_centroids; i < k
|
|
402
|
-
codec->sa_decode
|
|
421
|
+
for (int i = n_input_centroids; i < k; i++) {
|
|
422
|
+
codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
|
|
403
423
|
}
|
|
404
424
|
}
|
|
405
425
|
|
|
406
|
-
post_process_centroids
|
|
426
|
+
post_process_centroids();
|
|
407
427
|
|
|
408
428
|
// prepare the index
|
|
409
429
|
|
|
@@ -412,10 +432,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
412
432
|
}
|
|
413
433
|
|
|
414
434
|
if (!index.is_trained) {
|
|
415
|
-
index.train
|
|
435
|
+
index.train(k, centroids.data());
|
|
416
436
|
}
|
|
417
437
|
|
|
418
|
-
index.add
|
|
438
|
+
index.add(k, centroids.data());
|
|
419
439
|
|
|
420
440
|
// k-means iterations
|
|
421
441
|
|
|
@@ -424,18 +444,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
424
444
|
double t0s = getmillisecs();
|
|
425
445
|
|
|
426
446
|
if (!codec) {
|
|
427
|
-
index.search
|
|
428
|
-
|
|
447
|
+
index.search(
|
|
448
|
+
nx,
|
|
449
|
+
reinterpret_cast<const float*>(x),
|
|
450
|
+
1,
|
|
451
|
+
dis.get(),
|
|
452
|
+
assign.get());
|
|
429
453
|
} else {
|
|
430
454
|
// search by blocks of decode_block_size vectors
|
|
431
|
-
size_t code_size = codec->sa_code_size
|
|
455
|
+
size_t code_size = codec->sa_code_size();
|
|
432
456
|
for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
|
|
433
457
|
size_t i1 = i0 + decode_block_size;
|
|
434
|
-
if (i1 > nx) {
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
458
|
+
if (i1 > nx) {
|
|
459
|
+
i1 = nx;
|
|
460
|
+
}
|
|
461
|
+
codec->sa_decode(
|
|
462
|
+
i1 - i0, x + code_size * i0, decode_buffer.data());
|
|
463
|
+
index.search(
|
|
464
|
+
i1 - i0,
|
|
465
|
+
decode_buffer.data(),
|
|
466
|
+
1,
|
|
467
|
+
dis.get() + i0,
|
|
468
|
+
assign.get() + i0);
|
|
439
469
|
}
|
|
440
470
|
}
|
|
441
471
|
|
|
@@ -449,61 +479,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
449
479
|
}
|
|
450
480
|
|
|
451
481
|
// update the centroids
|
|
452
|
-
std::vector<float> hassign
|
|
482
|
+
std::vector<float> hassign(k);
|
|
453
483
|
|
|
454
484
|
size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
|
|
455
|
-
compute_centroids
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
485
|
+
compute_centroids(
|
|
486
|
+
d,
|
|
487
|
+
k,
|
|
488
|
+
nx,
|
|
489
|
+
k_frozen,
|
|
490
|
+
x,
|
|
491
|
+
codec,
|
|
492
|
+
assign.get(),
|
|
493
|
+
weights,
|
|
494
|
+
hassign.data(),
|
|
495
|
+
centroids.data());
|
|
496
|
+
|
|
497
|
+
int nsplit = split_clusters(
|
|
498
|
+
d, k, nx, k_frozen, hassign.data(), centroids.data());
|
|
465
499
|
|
|
466
500
|
// collect statistics
|
|
467
|
-
ClusteringIterationStats stats =
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
501
|
+
ClusteringIterationStats stats = {
|
|
502
|
+
obj,
|
|
503
|
+
(getmillisecs() - t0) / 1000.0,
|
|
504
|
+
t_search_tot / 1000,
|
|
505
|
+
imbalance_factor(nx, k, assign.get()),
|
|
506
|
+
nsplit};
|
|
472
507
|
iteration_stats.push_back(stats);
|
|
473
508
|
|
|
474
509
|
if (verbose) {
|
|
475
|
-
printf
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
510
|
+
printf(" Iteration %d (%.2f s, search %.2f s): "
|
|
511
|
+
"objective=%g imbalance=%.3f nsplit=%d \r",
|
|
512
|
+
i,
|
|
513
|
+
stats.time,
|
|
514
|
+
stats.time_search,
|
|
515
|
+
stats.obj,
|
|
516
|
+
stats.imbalance_factor,
|
|
517
|
+
nsplit);
|
|
518
|
+
fflush(stdout);
|
|
480
519
|
}
|
|
481
520
|
|
|
482
|
-
post_process_centroids
|
|
521
|
+
post_process_centroids();
|
|
483
522
|
|
|
484
523
|
// add centroids to index for the next iteration (or for output)
|
|
485
524
|
|
|
486
|
-
index.reset
|
|
525
|
+
index.reset();
|
|
487
526
|
if (update_index) {
|
|
488
|
-
index.train
|
|
527
|
+
index.train(k, centroids.data());
|
|
489
528
|
}
|
|
490
529
|
|
|
491
|
-
index.add
|
|
492
|
-
InterruptCallback::check
|
|
530
|
+
index.add(k, centroids.data());
|
|
531
|
+
InterruptCallback::check();
|
|
493
532
|
}
|
|
494
533
|
|
|
495
|
-
if (verbose)
|
|
534
|
+
if (verbose)
|
|
535
|
+
printf("\n");
|
|
496
536
|
if (nredo > 1) {
|
|
497
537
|
if ((lower_is_better && obj < best_obj) ||
|
|
498
538
|
(!lower_is_better && obj > best_obj)) {
|
|
499
539
|
if (verbose) {
|
|
500
|
-
printf
|
|
540
|
+
printf("Objective improved: keep new clusters\n");
|
|
501
541
|
}
|
|
502
542
|
best_centroids = centroids;
|
|
503
543
|
best_iteration_stats = iteration_stats;
|
|
504
544
|
best_obj = obj;
|
|
505
545
|
}
|
|
506
|
-
index.reset
|
|
546
|
+
index.reset();
|
|
507
547
|
}
|
|
508
548
|
}
|
|
509
549
|
if (nredo > 1) {
|
|
@@ -512,20 +552,151 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
512
552
|
index.reset();
|
|
513
553
|
index.add(k, best_centroids.data());
|
|
514
554
|
}
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
Clustering1D::Clustering1D(int k) : Clustering(1, k) {}
|
|
558
|
+
|
|
559
|
+
Clustering1D::Clustering1D(int k, const ClusteringParameters& cp)
|
|
560
|
+
: Clustering(1, k, cp) {}
|
|
561
|
+
|
|
562
|
+
void Clustering1D::train_exact(idx_t n, const float* x) {
|
|
563
|
+
const float* xt = x;
|
|
564
|
+
|
|
565
|
+
std::unique_ptr<uint8_t[]> del;
|
|
566
|
+
if (n > k * max_points_per_centroid) {
|
|
567
|
+
uint8_t* x_new;
|
|
568
|
+
float* weights_new;
|
|
569
|
+
n = subsample_training_set(
|
|
570
|
+
*this,
|
|
571
|
+
n,
|
|
572
|
+
(uint8_t*)x,
|
|
573
|
+
sizeof(float) * d,
|
|
574
|
+
nullptr,
|
|
575
|
+
&x_new,
|
|
576
|
+
&weights_new);
|
|
577
|
+
del.reset(x_new);
|
|
578
|
+
xt = (float*)x_new;
|
|
579
|
+
}
|
|
515
580
|
|
|
581
|
+
centroids.resize(k);
|
|
582
|
+
double uf = kmeans1d(xt, n, k, centroids.data());
|
|
583
|
+
|
|
584
|
+
ClusteringIterationStats stats = {0.0, 0.0, 0.0, uf, 0};
|
|
585
|
+
iteration_stats.push_back(stats);
|
|
516
586
|
}
|
|
517
587
|
|
|
518
|
-
float kmeans_clustering
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
588
|
+
float kmeans_clustering(
|
|
589
|
+
size_t d,
|
|
590
|
+
size_t n,
|
|
591
|
+
size_t k,
|
|
592
|
+
const float* x,
|
|
593
|
+
float* centroids) {
|
|
594
|
+
Clustering clus(d, k);
|
|
523
595
|
clus.verbose = d * n * k > (1L << 30);
|
|
524
596
|
// display logs if > 1Gflop per iteration
|
|
525
|
-
IndexFlatL2 index
|
|
526
|
-
clus.train
|
|
597
|
+
IndexFlatL2 index(d);
|
|
598
|
+
clus.train(n, x, index);
|
|
527
599
|
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
|
|
528
600
|
return clus.iteration_stats.back().obj;
|
|
529
601
|
}
|
|
530
602
|
|
|
603
|
+
/******************************************************************************
|
|
604
|
+
* ProgressiveDimClustering implementation
|
|
605
|
+
******************************************************************************/
|
|
606
|
+
|
|
607
|
+
ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
|
|
608
|
+
progressive_dim_steps = 10;
|
|
609
|
+
apply_pca = true; // seems a good idea to do this by default
|
|
610
|
+
niter = 10; // reduce nb of iterations per step
|
|
611
|
+
}
|
|
612
|
+
|
|
613
|
+
Index* ProgressiveDimIndexFactory::operator()(int dim) {
|
|
614
|
+
return new IndexFlatL2(dim);
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
|
|
618
|
+
|
|
619
|
+
ProgressiveDimClustering::ProgressiveDimClustering(
|
|
620
|
+
int d,
|
|
621
|
+
int k,
|
|
622
|
+
const ProgressiveDimClusteringParameters& cp)
|
|
623
|
+
: ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
|
|
624
|
+
|
|
625
|
+
namespace {
|
|
626
|
+
|
|
627
|
+
using idx_t = Index::idx_t;
|
|
628
|
+
|
|
629
|
+
void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
|
|
630
|
+
idx_t d = std::min(d1, d2);
|
|
631
|
+
for (idx_t i = 0; i < n; i++) {
|
|
632
|
+
memcpy(dest, src, sizeof(float) * d);
|
|
633
|
+
src += d1;
|
|
634
|
+
dest += d2;
|
|
635
|
+
}
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
}; // namespace
|
|
639
|
+
|
|
640
|
+
void ProgressiveDimClustering::train(
|
|
641
|
+
idx_t n,
|
|
642
|
+
const float* x,
|
|
643
|
+
ProgressiveDimIndexFactory& factory) {
|
|
644
|
+
int d_prev = 0;
|
|
645
|
+
|
|
646
|
+
PCAMatrix pca(d, d);
|
|
647
|
+
|
|
648
|
+
std::vector<float> xbuf;
|
|
649
|
+
if (apply_pca) {
|
|
650
|
+
if (verbose) {
|
|
651
|
+
printf("Training PCA transform\n");
|
|
652
|
+
}
|
|
653
|
+
pca.train(n, x);
|
|
654
|
+
if (verbose) {
|
|
655
|
+
printf("Apply PCA\n");
|
|
656
|
+
}
|
|
657
|
+
xbuf.resize(n * d);
|
|
658
|
+
pca.apply_noalloc(n, x, xbuf.data());
|
|
659
|
+
x = xbuf.data();
|
|
660
|
+
}
|
|
661
|
+
|
|
662
|
+
for (int iter = 0; iter < progressive_dim_steps; iter++) {
|
|
663
|
+
int di = int(pow(d, (1. + iter) / progressive_dim_steps));
|
|
664
|
+
if (verbose) {
|
|
665
|
+
printf("Progressive dim step %d: cluster in dimension %d\n",
|
|
666
|
+
iter,
|
|
667
|
+
di);
|
|
668
|
+
}
|
|
669
|
+
std::unique_ptr<Index> clustering_index(factory(di));
|
|
670
|
+
|
|
671
|
+
Clustering clus(di, k, *this);
|
|
672
|
+
if (d_prev > 0) {
|
|
673
|
+
// copy warm-start centroids (padded with 0s)
|
|
674
|
+
clus.centroids.resize(k * di);
|
|
675
|
+
copy_columns(
|
|
676
|
+
k, d_prev, centroids.data(), di, clus.centroids.data());
|
|
677
|
+
}
|
|
678
|
+
std::vector<float> xsub(n * di);
|
|
679
|
+
copy_columns(n, d, x, di, xsub.data());
|
|
680
|
+
|
|
681
|
+
clus.train(n, xsub.data(), *clustering_index.get());
|
|
682
|
+
|
|
683
|
+
centroids = clus.centroids;
|
|
684
|
+
iteration_stats.insert(
|
|
685
|
+
iteration_stats.end(),
|
|
686
|
+
clus.iteration_stats.begin(),
|
|
687
|
+
clus.iteration_stats.end());
|
|
688
|
+
|
|
689
|
+
d_prev = di;
|
|
690
|
+
}
|
|
691
|
+
|
|
692
|
+
if (apply_pca) {
|
|
693
|
+
if (verbose) {
|
|
694
|
+
printf("Revert PCA transform on centroids\n");
|
|
695
|
+
}
|
|
696
|
+
std::vector<float> cent_transformed(d * k);
|
|
697
|
+
pca.reverse_transform(k, centroids.data(), cent_transformed.data());
|
|
698
|
+
cent_transformed.swap(centroids);
|
|
699
|
+
}
|
|
700
|
+
}
|
|
701
|
+
|
|
531
702
|
} // namespace faiss
|