faiss 0.1.5 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/README.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +6 -2
- data/ext/faiss/index.cpp +114 -43
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss.rb +0 -5
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +24 -10
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
|
@@ -10,38 +10,39 @@
|
|
|
10
10
|
#ifndef FAISS_AUTO_TUNE_H
|
|
11
11
|
#define FAISS_AUTO_TUNE_H
|
|
12
12
|
|
|
13
|
-
#include <vector>
|
|
14
|
-
#include <unordered_map>
|
|
15
13
|
#include <stdint.h>
|
|
14
|
+
#include <unordered_map>
|
|
15
|
+
#include <vector>
|
|
16
16
|
|
|
17
17
|
#include <faiss/Index.h>
|
|
18
18
|
#include <faiss/IndexBinary.h>
|
|
19
19
|
|
|
20
20
|
namespace faiss {
|
|
21
21
|
|
|
22
|
-
|
|
23
22
|
/**
|
|
24
23
|
* Evaluation criterion. Returns a performance measure in [0,1],
|
|
25
24
|
* higher is better.
|
|
26
25
|
*/
|
|
27
26
|
struct AutoTuneCriterion {
|
|
28
27
|
typedef Index::idx_t idx_t;
|
|
29
|
-
idx_t nq;
|
|
30
|
-
idx_t nnn;
|
|
31
|
-
idx_t gt_nnn; ///< nb of GT NNs required to evaluate
|
|
28
|
+
idx_t nq; ///< nb of queries this criterion is evaluated on
|
|
29
|
+
idx_t nnn; ///< nb of NNs that the query should request
|
|
30
|
+
idx_t gt_nnn; ///< nb of GT NNs required to evaluate criterion
|
|
32
31
|
|
|
33
|
-
std::vector<float> gt_D;
|
|
34
|
-
std::vector<idx_t> gt_I;
|
|
32
|
+
std::vector<float> gt_D; ///< Ground-truth distances (size nq * gt_nnn)
|
|
33
|
+
std::vector<idx_t> gt_I; ///< Ground-truth indexes (size nq * gt_nnn)
|
|
35
34
|
|
|
36
|
-
AutoTuneCriterion
|
|
35
|
+
AutoTuneCriterion(idx_t nq, idx_t nnn);
|
|
37
36
|
|
|
38
37
|
/** Initializes the gt_D and gt_I vectors. Must be called before evaluating
|
|
39
38
|
*
|
|
40
39
|
* @param gt_D_in size nq * gt_nnn
|
|
41
40
|
* @param gt_I_in size nq * gt_nnn
|
|
42
41
|
*/
|
|
43
|
-
void set_groundtruth
|
|
44
|
-
|
|
42
|
+
void set_groundtruth(
|
|
43
|
+
int gt_nnn,
|
|
44
|
+
const float* gt_D_in,
|
|
45
|
+
const idx_t* gt_I_in);
|
|
45
46
|
|
|
46
47
|
/** Evaluate the criterion.
|
|
47
48
|
*
|
|
@@ -49,29 +50,25 @@ struct AutoTuneCriterion {
|
|
|
49
50
|
* @param I size nq * nnn
|
|
50
51
|
* @return the criterion, between 0 and 1. Larger is better.
|
|
51
52
|
*/
|
|
52
|
-
virtual double evaluate
|
|
53
|
-
|
|
54
|
-
virtual ~AutoTuneCriterion () {}
|
|
53
|
+
virtual double evaluate(const float* D, const idx_t* I) const = 0;
|
|
55
54
|
|
|
55
|
+
virtual ~AutoTuneCriterion() {}
|
|
56
56
|
};
|
|
57
57
|
|
|
58
|
-
struct OneRecallAtRCriterion: AutoTuneCriterion {
|
|
59
|
-
|
|
58
|
+
struct OneRecallAtRCriterion : AutoTuneCriterion {
|
|
60
59
|
idx_t R;
|
|
61
60
|
|
|
62
|
-
OneRecallAtRCriterion
|
|
61
|
+
OneRecallAtRCriterion(idx_t nq, idx_t R);
|
|
63
62
|
|
|
64
63
|
double evaluate(const float* D, const idx_t* I) const override;
|
|
65
64
|
|
|
66
65
|
~OneRecallAtRCriterion() override {}
|
|
67
66
|
};
|
|
68
67
|
|
|
69
|
-
|
|
70
|
-
struct IntersectionCriterion: AutoTuneCriterion {
|
|
71
|
-
|
|
68
|
+
struct IntersectionCriterion : AutoTuneCriterion {
|
|
72
69
|
idx_t R;
|
|
73
70
|
|
|
74
|
-
IntersectionCriterion
|
|
71
|
+
IntersectionCriterion(idx_t nq, idx_t R);
|
|
75
72
|
|
|
76
73
|
double evaluate(const float* D, const idx_t* I) const override;
|
|
77
74
|
|
|
@@ -91,7 +88,7 @@ struct OperatingPoint {
|
|
|
91
88
|
double perf; ///< performance measure (output of a Criterion)
|
|
92
89
|
double t; ///< corresponding execution time (ms)
|
|
93
90
|
std::string key; ///< key that identifies this op pt
|
|
94
|
-
int64_t cno;
|
|
91
|
+
int64_t cno; ///< integer identifier
|
|
95
92
|
};
|
|
96
93
|
|
|
97
94
|
struct OperatingPoints {
|
|
@@ -102,27 +99,27 @@ struct OperatingPoints {
|
|
|
102
99
|
std::vector<OperatingPoint> optimal_pts;
|
|
103
100
|
|
|
104
101
|
// begins with a single operating point: t=0, perf=0
|
|
105
|
-
OperatingPoints
|
|
102
|
+
OperatingPoints();
|
|
106
103
|
|
|
107
104
|
/// add operating points from other to this, with a prefix to the keys
|
|
108
|
-
int merge_with
|
|
109
|
-
|
|
105
|
+
int merge_with(
|
|
106
|
+
const OperatingPoints& other,
|
|
107
|
+
const std::string& prefix = "");
|
|
110
108
|
|
|
111
|
-
void clear
|
|
109
|
+
void clear();
|
|
112
110
|
|
|
113
111
|
/// add a performance measure. Return whether it is an optimal point
|
|
114
|
-
bool add
|
|
112
|
+
bool add(double perf, double t, const std::string& key, size_t cno = 0);
|
|
115
113
|
|
|
116
114
|
/// get time required to obtain a given performance measure
|
|
117
|
-
double t_for_perf
|
|
115
|
+
double t_for_perf(double perf) const;
|
|
118
116
|
|
|
119
117
|
/// easy-to-read output
|
|
120
|
-
void display
|
|
118
|
+
void display(bool only_optimal = true) const;
|
|
121
119
|
|
|
122
120
|
/// output to a format easy to digest by gnuplot
|
|
123
|
-
void all_to_gnuplot
|
|
124
|
-
void optimal_to_gnuplot
|
|
125
|
-
|
|
121
|
+
void all_to_gnuplot(const char* fname) const;
|
|
122
|
+
void optimal_to_gnuplot(const char* fname) const;
|
|
126
123
|
};
|
|
127
124
|
|
|
128
125
|
/// possible values of a parameter, sorted from least to most expensive/accurate
|
|
@@ -156,41 +153,45 @@ struct ParameterSpace {
|
|
|
156
153
|
/// duration (to avoid jittering in MT mode)
|
|
157
154
|
double min_test_duration;
|
|
158
155
|
|
|
159
|
-
ParameterSpace
|
|
156
|
+
ParameterSpace();
|
|
160
157
|
|
|
161
158
|
/// nb of combinations, = product of values sizes
|
|
162
|
-
size_t n_combinations
|
|
159
|
+
size_t n_combinations() const;
|
|
163
160
|
|
|
164
161
|
/// returns whether combinations c1 >= c2 in the tuple sense
|
|
165
|
-
bool combination_ge
|
|
162
|
+
bool combination_ge(size_t c1, size_t c2) const;
|
|
166
163
|
|
|
167
164
|
/// get string representation of the combination
|
|
168
|
-
std::string combination_name
|
|
165
|
+
std::string combination_name(size_t cno) const;
|
|
169
166
|
|
|
170
167
|
/// print a description on stdout
|
|
171
|
-
void display
|
|
168
|
+
void display() const;
|
|
172
169
|
|
|
173
170
|
/// add a new parameter (or return it if it exists)
|
|
174
|
-
ParameterRange
|
|
171
|
+
ParameterRange& add_range(const std::string& name);
|
|
175
172
|
|
|
176
173
|
/// initialize with reasonable parameters for the index
|
|
177
|
-
virtual void initialize
|
|
174
|
+
virtual void initialize(const Index* index);
|
|
178
175
|
|
|
179
176
|
/// set a combination of parameters on an index
|
|
180
|
-
void set_index_parameters
|
|
177
|
+
void set_index_parameters(Index* index, size_t cno) const;
|
|
181
178
|
|
|
182
179
|
/// set a combination of parameters described by a string
|
|
183
|
-
void set_index_parameters
|
|
180
|
+
void set_index_parameters(Index* index, const char* param_string) const;
|
|
184
181
|
|
|
185
182
|
/// set one of the parameters, returns whether setting was successful
|
|
186
|
-
virtual void set_index_parameter
|
|
187
|
-
|
|
183
|
+
virtual void set_index_parameter(
|
|
184
|
+
Index* index,
|
|
185
|
+
const std::string& name,
|
|
186
|
+
double val) const;
|
|
188
187
|
|
|
189
188
|
/** find an upper bound on the performance and a lower bound on t
|
|
190
189
|
* for configuration cno given another operating point op */
|
|
191
|
-
void update_bounds
|
|
192
|
-
|
|
193
|
-
|
|
190
|
+
void update_bounds(
|
|
191
|
+
size_t cno,
|
|
192
|
+
const OperatingPoint& op,
|
|
193
|
+
double* upper_bound_perf,
|
|
194
|
+
double* lower_bound_t) const;
|
|
194
195
|
|
|
195
196
|
/** explore operating points
|
|
196
197
|
* @param index index to run on
|
|
@@ -198,18 +199,16 @@ struct ParameterSpace {
|
|
|
198
199
|
* @param crit selection criterion
|
|
199
200
|
* @param ops resulting operating points
|
|
200
201
|
*/
|
|
201
|
-
void explore
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
202
|
+
void explore(
|
|
203
|
+
Index* index,
|
|
204
|
+
size_t nq,
|
|
205
|
+
const float* xq,
|
|
206
|
+
const AutoTuneCriterion& crit,
|
|
207
|
+
OperatingPoints* ops) const;
|
|
208
|
+
|
|
209
|
+
virtual ~ParameterSpace() {}
|
|
207
210
|
};
|
|
208
211
|
|
|
209
|
-
|
|
210
|
-
|
|
211
212
|
} // namespace faiss
|
|
212
213
|
|
|
213
|
-
|
|
214
|
-
|
|
215
214
|
#endif
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
// -*- c++ -*-
|
|
9
9
|
|
|
10
10
|
#include <faiss/Clustering.h>
|
|
11
|
+
#include <faiss/VectorTransform.h>
|
|
11
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
12
13
|
|
|
13
14
|
#include <cinttypes>
|
|
@@ -17,100 +18,100 @@
|
|
|
17
18
|
|
|
18
19
|
#include <omp.h>
|
|
19
20
|
|
|
20
|
-
#include <faiss/utils/utils.h>
|
|
21
|
-
#include <faiss/utils/random.h>
|
|
22
|
-
#include <faiss/utils/distances.h>
|
|
23
|
-
#include <faiss/impl/FaissAssert.h>
|
|
24
21
|
#include <faiss/IndexFlat.h>
|
|
22
|
+
#include <faiss/impl/FaissAssert.h>
|
|
23
|
+
#include <faiss/utils/distances.h>
|
|
24
|
+
#include <faiss/utils/random.h>
|
|
25
|
+
#include <faiss/utils/utils.h>
|
|
25
26
|
|
|
26
27
|
namespace faiss {
|
|
27
28
|
|
|
28
|
-
ClusteringParameters::ClusteringParameters
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
{}
|
|
29
|
+
ClusteringParameters::ClusteringParameters()
|
|
30
|
+
: niter(25),
|
|
31
|
+
nredo(1),
|
|
32
|
+
verbose(false),
|
|
33
|
+
spherical(false),
|
|
34
|
+
int_centroids(false),
|
|
35
|
+
update_index(false),
|
|
36
|
+
frozen_centroids(false),
|
|
37
|
+
min_points_per_centroid(39),
|
|
38
|
+
max_points_per_centroid(256),
|
|
39
|
+
seed(1234),
|
|
40
|
+
decode_block_size(32768) {}
|
|
41
41
|
// 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
|
|
42
42
|
|
|
43
|
+
Clustering::Clustering(int d, int k) : d(d), k(k) {}
|
|
43
44
|
|
|
44
|
-
Clustering::Clustering
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
|
|
48
|
-
ClusteringParameters (cp), d(d), k(k) {}
|
|
45
|
+
Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
|
|
46
|
+
: ClusteringParameters(cp), d(d), k(k) {}
|
|
49
47
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
static double imbalance_factor (int n, int k, int64_t *assign) {
|
|
48
|
+
static double imbalance_factor(int n, int k, int64_t* assign) {
|
|
53
49
|
std::vector<int> hist(k, 0);
|
|
54
50
|
for (int i = 0; i < n; i++)
|
|
55
51
|
hist[assign[i]]++;
|
|
56
52
|
|
|
57
53
|
double tot = 0, uf = 0;
|
|
58
54
|
|
|
59
|
-
for (int i = 0
|
|
55
|
+
for (int i = 0; i < k; i++) {
|
|
60
56
|
tot += hist[i];
|
|
61
|
-
uf += hist[i] * (double)
|
|
57
|
+
uf += hist[i] * (double)hist[i];
|
|
62
58
|
}
|
|
63
59
|
uf = uf * k / (tot * tot);
|
|
64
60
|
|
|
65
61
|
return uf;
|
|
66
62
|
}
|
|
67
63
|
|
|
68
|
-
void Clustering::post_process_centroids
|
|
69
|
-
{
|
|
70
|
-
|
|
64
|
+
void Clustering::post_process_centroids() {
|
|
71
65
|
if (spherical) {
|
|
72
|
-
fvec_renorm_L2
|
|
66
|
+
fvec_renorm_L2(d, k, centroids.data());
|
|
73
67
|
}
|
|
74
68
|
|
|
75
69
|
if (int_centroids) {
|
|
76
70
|
for (size_t i = 0; i < centroids.size(); i++)
|
|
77
|
-
centroids[i] = roundf
|
|
71
|
+
centroids[i] = roundf(centroids[i]);
|
|
78
72
|
}
|
|
79
73
|
}
|
|
80
74
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
75
|
+
void Clustering::train(
|
|
76
|
+
idx_t nx,
|
|
77
|
+
const float* x_in,
|
|
78
|
+
Index& index,
|
|
79
|
+
const float* weights) {
|
|
80
|
+
train_encoded(
|
|
81
|
+
nx,
|
|
82
|
+
reinterpret_cast<const uint8_t*>(x_in),
|
|
83
|
+
nullptr,
|
|
84
|
+
index,
|
|
85
|
+
weights);
|
|
86
86
|
}
|
|
87
87
|
|
|
88
|
-
|
|
89
88
|
namespace {
|
|
90
89
|
|
|
91
90
|
using idx_t = Clustering::idx_t;
|
|
92
91
|
|
|
93
92
|
idx_t subsample_training_set(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
93
|
+
const Clustering& clus,
|
|
94
|
+
idx_t nx,
|
|
95
|
+
const uint8_t* x,
|
|
96
|
+
size_t line_size,
|
|
97
|
+
const float* weights,
|
|
98
|
+
uint8_t** x_out,
|
|
99
|
+
float** weights_out) {
|
|
100
100
|
if (clus.verbose) {
|
|
101
101
|
printf("Sampling a subset of %zd / %" PRId64 " for training\n",
|
|
102
|
-
clus.k * clus.max_points_per_centroid,
|
|
102
|
+
clus.k * clus.max_points_per_centroid,
|
|
103
|
+
nx);
|
|
103
104
|
}
|
|
104
|
-
std::vector<int> perm
|
|
105
|
-
rand_perm
|
|
105
|
+
std::vector<int> perm(nx);
|
|
106
|
+
rand_perm(perm.data(), nx, clus.seed);
|
|
106
107
|
nx = clus.k * clus.max_points_per_centroid;
|
|
107
|
-
uint8_t
|
|
108
|
+
uint8_t* x_new = new uint8_t[nx * line_size];
|
|
108
109
|
*x_out = x_new;
|
|
109
110
|
for (idx_t i = 0; i < nx; i++) {
|
|
110
|
-
memcpy
|
|
111
|
+
memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
|
|
111
112
|
}
|
|
112
113
|
if (weights) {
|
|
113
|
-
float
|
|
114
|
+
float* weights_new = new float[nx];
|
|
114
115
|
for (idx_t i = 0; i < nx; i++) {
|
|
115
116
|
weights_new[i] = weights[perm[i]];
|
|
116
117
|
}
|
|
@@ -134,20 +135,23 @@ idx_t subsample_training_set(
|
|
|
134
135
|
*
|
|
135
136
|
*/
|
|
136
137
|
|
|
137
|
-
void compute_centroids
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
138
|
+
void compute_centroids(
|
|
139
|
+
size_t d,
|
|
140
|
+
size_t k,
|
|
141
|
+
size_t n,
|
|
142
|
+
size_t k_frozen,
|
|
143
|
+
const uint8_t* x,
|
|
144
|
+
const Index* codec,
|
|
145
|
+
const int64_t* assign,
|
|
146
|
+
const float* weights,
|
|
147
|
+
float* hassign,
|
|
148
|
+
float* centroids) {
|
|
145
149
|
k -= k_frozen;
|
|
146
150
|
centroids += k_frozen * d;
|
|
147
151
|
|
|
148
|
-
memset
|
|
152
|
+
memset(centroids, 0, sizeof(*centroids) * d * k);
|
|
149
153
|
|
|
150
|
-
size_t line_size = codec ? codec->sa_code_size() : d * sizeof
|
|
154
|
+
size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
|
|
151
155
|
|
|
152
156
|
#pragma omp parallel
|
|
153
157
|
{
|
|
@@ -157,20 +161,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
157
161
|
// this thread is taking care of centroids c0:c1
|
|
158
162
|
size_t c0 = (k * rank) / nt;
|
|
159
163
|
size_t c1 = (k * (rank + 1)) / nt;
|
|
160
|
-
std::vector<float> decode_buffer
|
|
164
|
+
std::vector<float> decode_buffer(d);
|
|
161
165
|
|
|
162
166
|
for (size_t i = 0; i < n; i++) {
|
|
163
167
|
int64_t ci = assign[i];
|
|
164
|
-
assert
|
|
168
|
+
assert(ci >= 0 && ci < k + k_frozen);
|
|
165
169
|
ci -= k_frozen;
|
|
166
|
-
if (ci >= c0 && ci < c1)
|
|
167
|
-
float
|
|
168
|
-
const float
|
|
170
|
+
if (ci >= c0 && ci < c1) {
|
|
171
|
+
float* c = centroids + ci * d;
|
|
172
|
+
const float* xi;
|
|
169
173
|
if (!codec) {
|
|
170
174
|
xi = reinterpret_cast<const float*>(x + i * line_size);
|
|
171
175
|
} else {
|
|
172
|
-
float
|
|
173
|
-
codec->sa_decode
|
|
176
|
+
float* xif = decode_buffer.data();
|
|
177
|
+
codec->sa_decode(1, x + i * line_size, xif);
|
|
174
178
|
xi = xif;
|
|
175
179
|
}
|
|
176
180
|
if (weights) {
|
|
@@ -187,7 +191,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
187
191
|
}
|
|
188
192
|
}
|
|
189
193
|
}
|
|
190
|
-
|
|
191
194
|
}
|
|
192
195
|
|
|
193
196
|
#pragma omp parallel for
|
|
@@ -196,12 +199,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
196
199
|
continue;
|
|
197
200
|
}
|
|
198
201
|
float norm = 1 / hassign[ci];
|
|
199
|
-
float
|
|
202
|
+
float* c = centroids + ci * d;
|
|
200
203
|
for (size_t j = 0; j < d; j++) {
|
|
201
204
|
c[j] *= norm;
|
|
202
205
|
}
|
|
203
206
|
}
|
|
204
|
-
|
|
205
207
|
}
|
|
206
208
|
|
|
207
209
|
// a bit above machine epsilon for float16
|
|
@@ -214,29 +216,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
|
|
|
214
216
|
*
|
|
215
217
|
* @return nb of splitting operations (larger is worse)
|
|
216
218
|
*/
|
|
217
|
-
int split_clusters
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
219
|
+
int split_clusters(
|
|
220
|
+
size_t d,
|
|
221
|
+
size_t k,
|
|
222
|
+
size_t n,
|
|
223
|
+
size_t k_frozen,
|
|
224
|
+
float* hassign,
|
|
225
|
+
float* centroids) {
|
|
222
226
|
k -= k_frozen;
|
|
223
227
|
centroids += k_frozen * d;
|
|
224
228
|
|
|
225
229
|
/* Take care of void clusters */
|
|
226
230
|
size_t nsplit = 0;
|
|
227
|
-
RandomGenerator rng
|
|
231
|
+
RandomGenerator rng(1234);
|
|
228
232
|
for (size_t ci = 0; ci < k; ci++) {
|
|
229
233
|
if (hassign[ci] == 0) { /* need to redefine a centroid */
|
|
230
234
|
size_t cj;
|
|
231
235
|
for (cj = 0; 1; cj = (cj + 1) % k) {
|
|
232
236
|
/* probability to pick this cluster for split */
|
|
233
|
-
float p = (hassign[cj] - 1.0) / (float)
|
|
234
|
-
float r = rng.rand_float
|
|
237
|
+
float p = (hassign[cj] - 1.0) / (float)(n - k);
|
|
238
|
+
float r = rng.rand_float();
|
|
235
239
|
if (r < p) {
|
|
236
240
|
break; /* found our cluster to be split */
|
|
237
241
|
}
|
|
238
242
|
}
|
|
239
|
-
memcpy
|
|
243
|
+
memcpy(centroids + ci * d,
|
|
244
|
+
centroids + cj * d,
|
|
245
|
+
sizeof(*centroids) * d);
|
|
240
246
|
|
|
241
247
|
/* small symmetric perturbation */
|
|
242
248
|
for (size_t j = 0; j < d; j++) {
|
|
@@ -257,30 +263,35 @@ int split_clusters (size_t d, size_t k, size_t n,
|
|
|
257
263
|
}
|
|
258
264
|
|
|
259
265
|
return nsplit;
|
|
260
|
-
|
|
261
266
|
}
|
|
262
267
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
268
|
+
}; // namespace
|
|
269
|
+
|
|
270
|
+
void Clustering::train_encoded(
|
|
271
|
+
idx_t nx,
|
|
272
|
+
const uint8_t* x_in,
|
|
273
|
+
const Index* codec,
|
|
274
|
+
Index& index,
|
|
275
|
+
const float* weights) {
|
|
276
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
277
|
+
nx >= k,
|
|
278
|
+
"Number of training points (%" PRId64
|
|
279
|
+
") should be at least "
|
|
280
|
+
"as large as number of clusters (%zd)",
|
|
281
|
+
nx,
|
|
282
|
+
k);
|
|
283
|
+
|
|
284
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
285
|
+
(!codec || codec->d == d),
|
|
286
|
+
"Codec dimension %d not the same as data dimension %d",
|
|
287
|
+
int(codec->d),
|
|
288
|
+
int(d));
|
|
289
|
+
|
|
290
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
291
|
+
index.d == d,
|
|
282
292
|
"Index dimension %d not the same as data dimension %d",
|
|
283
|
-
int(index.d),
|
|
293
|
+
int(index.d),
|
|
294
|
+
int(d));
|
|
284
295
|
|
|
285
296
|
double t0 = getmillisecs();
|
|
286
297
|
|
|
@@ -288,67 +299,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
288
299
|
// Check for NaNs in input data. Normally it is the user's
|
|
289
300
|
// responsibility, but it may spare us some hard-to-debug
|
|
290
301
|
// reports.
|
|
291
|
-
const float
|
|
302
|
+
const float* x = reinterpret_cast<const float*>(x_in);
|
|
292
303
|
for (size_t i = 0; i < nx * d; i++) {
|
|
293
|
-
FAISS_THROW_IF_NOT_MSG
|
|
294
|
-
|
|
304
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
305
|
+
std::isfinite(x[i]), "input contains NaN's or Inf's");
|
|
295
306
|
}
|
|
296
307
|
}
|
|
297
308
|
|
|
298
|
-
const uint8_t
|
|
299
|
-
std::unique_ptr<uint8_t
|
|
300
|
-
std::unique_ptr<float
|
|
309
|
+
const uint8_t* x = x_in;
|
|
310
|
+
std::unique_ptr<uint8_t[]> del1;
|
|
311
|
+
std::unique_ptr<float[]> del3;
|
|
301
312
|
size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
|
|
302
313
|
|
|
303
314
|
if (nx > k * max_points_per_centroid) {
|
|
304
|
-
uint8_t
|
|
305
|
-
float
|
|
306
|
-
nx = subsample_training_set
|
|
307
|
-
|
|
308
|
-
del1.reset
|
|
309
|
-
|
|
315
|
+
uint8_t* x_new;
|
|
316
|
+
float* weights_new;
|
|
317
|
+
nx = subsample_training_set(
|
|
318
|
+
*this, nx, x, line_size, weights, &x_new, &weights_new);
|
|
319
|
+
del1.reset(x_new);
|
|
320
|
+
x = x_new;
|
|
321
|
+
del3.reset(weights_new);
|
|
322
|
+
weights = weights_new;
|
|
310
323
|
} else if (nx < k * min_points_per_centroid) {
|
|
311
|
-
fprintf
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
324
|
+
fprintf(stderr,
|
|
325
|
+
"WARNING clustering %" PRId64
|
|
326
|
+
" points to %zd centroids: "
|
|
327
|
+
"please provide at least %" PRId64 " training points\n",
|
|
328
|
+
nx,
|
|
329
|
+
k,
|
|
330
|
+
idx_t(k) * min_points_per_centroid);
|
|
315
331
|
}
|
|
316
332
|
|
|
317
333
|
if (nx == k) {
|
|
318
334
|
// this is a corner case, just copy training set to clusters
|
|
319
335
|
if (verbose) {
|
|
320
|
-
printf("Number of training points (%" PRId64
|
|
321
|
-
"
|
|
336
|
+
printf("Number of training points (%" PRId64
|
|
337
|
+
") same as number of "
|
|
338
|
+
"clusters, just copying\n",
|
|
339
|
+
nx);
|
|
322
340
|
}
|
|
323
|
-
centroids.resize
|
|
341
|
+
centroids.resize(d * k);
|
|
324
342
|
if (!codec) {
|
|
325
|
-
memcpy
|
|
343
|
+
memcpy(centroids.data(), x_in, sizeof(float) * d * k);
|
|
326
344
|
} else {
|
|
327
|
-
codec->sa_decode
|
|
345
|
+
codec->sa_decode(nx, x_in, centroids.data());
|
|
328
346
|
}
|
|
329
347
|
|
|
330
348
|
// one fake iteration...
|
|
331
|
-
ClusteringIterationStats stats = {
|
|
332
|
-
iteration_stats.push_back
|
|
349
|
+
ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
|
|
350
|
+
iteration_stats.push_back(stats);
|
|
333
351
|
|
|
334
352
|
index.reset();
|
|
335
353
|
index.add(k, centroids.data());
|
|
336
354
|
return;
|
|
337
355
|
}
|
|
338
356
|
|
|
339
|
-
|
|
340
357
|
if (verbose) {
|
|
341
|
-
printf("Clustering %" PRId64
|
|
358
|
+
printf("Clustering %" PRId64
|
|
359
|
+
" points in %zdD to %zd clusters, "
|
|
342
360
|
"redo %d times, %d iterations\n",
|
|
343
|
-
nx,
|
|
361
|
+
nx,
|
|
362
|
+
d,
|
|
363
|
+
k,
|
|
364
|
+
nredo,
|
|
365
|
+
niter);
|
|
344
366
|
if (codec) {
|
|
345
367
|
printf("Input data encoded in %zd bytes per vector\n",
|
|
346
|
-
codec->sa_code_size
|
|
368
|
+
codec->sa_code_size());
|
|
347
369
|
}
|
|
348
370
|
}
|
|
349
371
|
|
|
350
|
-
std::unique_ptr<idx_t
|
|
351
|
-
std::unique_ptr<float
|
|
372
|
+
std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
|
|
373
|
+
std::unique_ptr<float[]> dis(new float[nx]);
|
|
352
374
|
|
|
353
375
|
// remember best iteration for redo
|
|
354
376
|
bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
|
|
@@ -358,52 +380,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
358
380
|
|
|
359
381
|
// support input centroids
|
|
360
382
|
|
|
361
|
-
FAISS_THROW_IF_NOT_MSG
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
);
|
|
383
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
384
|
+
centroids.size() % d == 0,
|
|
385
|
+
"size of provided input centroids not a multiple of dimension");
|
|
365
386
|
|
|
366
387
|
size_t n_input_centroids = centroids.size() / d;
|
|
367
388
|
|
|
368
389
|
if (verbose && n_input_centroids > 0) {
|
|
369
|
-
printf
|
|
370
|
-
|
|
390
|
+
printf(" Using %zd centroids provided as input (%sfrozen)\n",
|
|
391
|
+
n_input_centroids,
|
|
392
|
+
frozen_centroids ? "" : "not ");
|
|
371
393
|
}
|
|
372
394
|
|
|
373
395
|
double t_search_tot = 0;
|
|
374
396
|
if (verbose) {
|
|
375
|
-
printf(" Preprocessing in %.2f s\n",
|
|
376
|
-
(getmillisecs() - t0) / 1000.);
|
|
397
|
+
printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
|
|
377
398
|
}
|
|
378
399
|
t0 = getmillisecs();
|
|
379
400
|
|
|
380
401
|
// temporary buffer to decode vectors during the optimization
|
|
381
|
-
std::vector<float> decode_buffer
|
|
382
|
-
(codec ? d * decode_block_size : 0);
|
|
402
|
+
std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
|
|
383
403
|
|
|
384
404
|
for (int redo = 0; redo < nredo; redo++) {
|
|
385
|
-
|
|
386
405
|
if (verbose && nredo > 1) {
|
|
387
406
|
printf("Outer iteration %d / %d\n", redo, nredo);
|
|
388
407
|
}
|
|
389
408
|
|
|
390
409
|
// initialize (remaining) centroids with random points from the dataset
|
|
391
|
-
centroids.resize
|
|
392
|
-
std::vector<int> perm
|
|
410
|
+
centroids.resize(d * k);
|
|
411
|
+
std::vector<int> perm(nx);
|
|
393
412
|
|
|
394
|
-
rand_perm
|
|
413
|
+
rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
|
|
395
414
|
|
|
396
415
|
if (!codec) {
|
|
397
|
-
for (int i = n_input_centroids; i < k
|
|
398
|
-
memcpy
|
|
416
|
+
for (int i = n_input_centroids; i < k; i++) {
|
|
417
|
+
memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
|
|
399
418
|
}
|
|
400
419
|
} else {
|
|
401
|
-
for (int i = n_input_centroids; i < k
|
|
402
|
-
codec->sa_decode
|
|
420
|
+
for (int i = n_input_centroids; i < k; i++) {
|
|
421
|
+
codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
|
|
403
422
|
}
|
|
404
423
|
}
|
|
405
424
|
|
|
406
|
-
post_process_centroids
|
|
425
|
+
post_process_centroids();
|
|
407
426
|
|
|
408
427
|
// prepare the index
|
|
409
428
|
|
|
@@ -412,10 +431,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
412
431
|
}
|
|
413
432
|
|
|
414
433
|
if (!index.is_trained) {
|
|
415
|
-
index.train
|
|
434
|
+
index.train(k, centroids.data());
|
|
416
435
|
}
|
|
417
436
|
|
|
418
|
-
index.add
|
|
437
|
+
index.add(k, centroids.data());
|
|
419
438
|
|
|
420
439
|
// k-means iterations
|
|
421
440
|
|
|
@@ -424,18 +443,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
424
443
|
double t0s = getmillisecs();
|
|
425
444
|
|
|
426
445
|
if (!codec) {
|
|
427
|
-
index.search
|
|
428
|
-
|
|
446
|
+
index.search(
|
|
447
|
+
nx,
|
|
448
|
+
reinterpret_cast<const float*>(x),
|
|
449
|
+
1,
|
|
450
|
+
dis.get(),
|
|
451
|
+
assign.get());
|
|
429
452
|
} else {
|
|
430
453
|
// search by blocks of decode_block_size vectors
|
|
431
|
-
size_t code_size = codec->sa_code_size
|
|
454
|
+
size_t code_size = codec->sa_code_size();
|
|
432
455
|
for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
|
|
433
456
|
size_t i1 = i0 + decode_block_size;
|
|
434
|
-
if (i1 > nx) {
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
457
|
+
if (i1 > nx) {
|
|
458
|
+
i1 = nx;
|
|
459
|
+
}
|
|
460
|
+
codec->sa_decode(
|
|
461
|
+
i1 - i0, x + code_size * i0, decode_buffer.data());
|
|
462
|
+
index.search(
|
|
463
|
+
i1 - i0,
|
|
464
|
+
decode_buffer.data(),
|
|
465
|
+
1,
|
|
466
|
+
dis.get() + i0,
|
|
467
|
+
assign.get() + i0);
|
|
439
468
|
}
|
|
440
469
|
}
|
|
441
470
|
|
|
@@ -449,61 +478,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
449
478
|
}
|
|
450
479
|
|
|
451
480
|
// update the centroids
|
|
452
|
-
std::vector<float> hassign
|
|
481
|
+
std::vector<float> hassign(k);
|
|
453
482
|
|
|
454
483
|
size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
|
|
455
|
-
compute_centroids
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
484
|
+
compute_centroids(
|
|
485
|
+
d,
|
|
486
|
+
k,
|
|
487
|
+
nx,
|
|
488
|
+
k_frozen,
|
|
489
|
+
x,
|
|
490
|
+
codec,
|
|
491
|
+
assign.get(),
|
|
492
|
+
weights,
|
|
493
|
+
hassign.data(),
|
|
494
|
+
centroids.data());
|
|
495
|
+
|
|
496
|
+
int nsplit = split_clusters(
|
|
497
|
+
d, k, nx, k_frozen, hassign.data(), centroids.data());
|
|
465
498
|
|
|
466
499
|
// collect statistics
|
|
467
|
-
ClusteringIterationStats stats =
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
500
|
+
ClusteringIterationStats stats = {
|
|
501
|
+
obj,
|
|
502
|
+
(getmillisecs() - t0) / 1000.0,
|
|
503
|
+
t_search_tot / 1000,
|
|
504
|
+
imbalance_factor(nx, k, assign.get()),
|
|
505
|
+
nsplit};
|
|
472
506
|
iteration_stats.push_back(stats);
|
|
473
507
|
|
|
474
508
|
if (verbose) {
|
|
475
|
-
printf
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
509
|
+
printf(" Iteration %d (%.2f s, search %.2f s): "
|
|
510
|
+
"objective=%g imbalance=%.3f nsplit=%d \r",
|
|
511
|
+
i,
|
|
512
|
+
stats.time,
|
|
513
|
+
stats.time_search,
|
|
514
|
+
stats.obj,
|
|
515
|
+
stats.imbalance_factor,
|
|
516
|
+
nsplit);
|
|
517
|
+
fflush(stdout);
|
|
480
518
|
}
|
|
481
519
|
|
|
482
|
-
post_process_centroids
|
|
520
|
+
post_process_centroids();
|
|
483
521
|
|
|
484
522
|
// add centroids to index for the next iteration (or for output)
|
|
485
523
|
|
|
486
|
-
index.reset
|
|
524
|
+
index.reset();
|
|
487
525
|
if (update_index) {
|
|
488
|
-
index.train
|
|
526
|
+
index.train(k, centroids.data());
|
|
489
527
|
}
|
|
490
528
|
|
|
491
|
-
index.add
|
|
492
|
-
InterruptCallback::check
|
|
529
|
+
index.add(k, centroids.data());
|
|
530
|
+
InterruptCallback::check();
|
|
493
531
|
}
|
|
494
532
|
|
|
495
|
-
if (verbose)
|
|
533
|
+
if (verbose)
|
|
534
|
+
printf("\n");
|
|
496
535
|
if (nredo > 1) {
|
|
497
536
|
if ((lower_is_better && obj < best_obj) ||
|
|
498
537
|
(!lower_is_better && obj > best_obj)) {
|
|
499
538
|
if (verbose) {
|
|
500
|
-
printf
|
|
539
|
+
printf("Objective improved: keep new clusters\n");
|
|
501
540
|
}
|
|
502
541
|
best_centroids = centroids;
|
|
503
542
|
best_iteration_stats = iteration_stats;
|
|
504
543
|
best_obj = obj;
|
|
505
544
|
}
|
|
506
|
-
index.reset
|
|
545
|
+
index.reset();
|
|
507
546
|
}
|
|
508
547
|
}
|
|
509
548
|
if (nredo > 1) {
|
|
@@ -512,20 +551,120 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
|
|
512
551
|
index.reset();
|
|
513
552
|
index.add(k, best_centroids.data());
|
|
514
553
|
}
|
|
515
|
-
|
|
516
554
|
}
|
|
517
555
|
|
|
518
|
-
float kmeans_clustering
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
556
|
+
float kmeans_clustering(
|
|
557
|
+
size_t d,
|
|
558
|
+
size_t n,
|
|
559
|
+
size_t k,
|
|
560
|
+
const float* x,
|
|
561
|
+
float* centroids) {
|
|
562
|
+
Clustering clus(d, k);
|
|
523
563
|
clus.verbose = d * n * k > (1L << 30);
|
|
524
564
|
// display logs if > 1Gflop per iteration
|
|
525
|
-
IndexFlatL2 index
|
|
526
|
-
clus.train
|
|
565
|
+
IndexFlatL2 index(d);
|
|
566
|
+
clus.train(n, x, index);
|
|
527
567
|
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
|
|
528
568
|
return clus.iteration_stats.back().obj;
|
|
529
569
|
}
|
|
530
570
|
|
|
571
|
+
/******************************************************************************
|
|
572
|
+
* ProgressiveDimClustering implementation
|
|
573
|
+
******************************************************************************/
|
|
574
|
+
|
|
575
|
+
ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
|
|
576
|
+
progressive_dim_steps = 10;
|
|
577
|
+
apply_pca = true; // seems a good idea to do this by default
|
|
578
|
+
niter = 10; // reduce nb of iterations per step
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
Index* ProgressiveDimIndexFactory::operator()(int dim) {
|
|
582
|
+
return new IndexFlatL2(dim);
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
|
|
586
|
+
|
|
587
|
+
ProgressiveDimClustering::ProgressiveDimClustering(
|
|
588
|
+
int d,
|
|
589
|
+
int k,
|
|
590
|
+
const ProgressiveDimClusteringParameters& cp)
|
|
591
|
+
: ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
|
|
592
|
+
|
|
593
|
+
namespace {
|
|
594
|
+
|
|
595
|
+
using idx_t = Index::idx_t;
|
|
596
|
+
|
|
597
|
+
void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
|
|
598
|
+
idx_t d = std::min(d1, d2);
|
|
599
|
+
for (idx_t i = 0; i < n; i++) {
|
|
600
|
+
memcpy(dest, src, sizeof(float) * d);
|
|
601
|
+
src += d1;
|
|
602
|
+
dest += d2;
|
|
603
|
+
}
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
}; // namespace
|
|
607
|
+
|
|
608
|
+
void ProgressiveDimClustering::train(
|
|
609
|
+
idx_t n,
|
|
610
|
+
const float* x,
|
|
611
|
+
ProgressiveDimIndexFactory& factory) {
|
|
612
|
+
int d_prev = 0;
|
|
613
|
+
|
|
614
|
+
PCAMatrix pca(d, d);
|
|
615
|
+
|
|
616
|
+
std::vector<float> xbuf;
|
|
617
|
+
if (apply_pca) {
|
|
618
|
+
if (verbose) {
|
|
619
|
+
printf("Training PCA transform\n");
|
|
620
|
+
}
|
|
621
|
+
pca.train(n, x);
|
|
622
|
+
if (verbose) {
|
|
623
|
+
printf("Apply PCA\n");
|
|
624
|
+
}
|
|
625
|
+
xbuf.resize(n * d);
|
|
626
|
+
pca.apply_noalloc(n, x, xbuf.data());
|
|
627
|
+
x = xbuf.data();
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
for (int iter = 0; iter < progressive_dim_steps; iter++) {
|
|
631
|
+
int di = int(pow(d, (1. + iter) / progressive_dim_steps));
|
|
632
|
+
if (verbose) {
|
|
633
|
+
printf("Progressive dim step %d: cluster in dimension %d\n",
|
|
634
|
+
iter,
|
|
635
|
+
di);
|
|
636
|
+
}
|
|
637
|
+
std::unique_ptr<Index> clustering_index(factory(di));
|
|
638
|
+
|
|
639
|
+
Clustering clus(di, k, *this);
|
|
640
|
+
if (d_prev > 0) {
|
|
641
|
+
// copy warm-start centroids (padded with 0s)
|
|
642
|
+
clus.centroids.resize(k * di);
|
|
643
|
+
copy_columns(
|
|
644
|
+
k, d_prev, centroids.data(), di, clus.centroids.data());
|
|
645
|
+
}
|
|
646
|
+
std::vector<float> xsub(n * di);
|
|
647
|
+
copy_columns(n, d, x, di, xsub.data());
|
|
648
|
+
|
|
649
|
+
clus.train(n, xsub.data(), *clustering_index.get());
|
|
650
|
+
|
|
651
|
+
centroids = clus.centroids;
|
|
652
|
+
iteration_stats.insert(
|
|
653
|
+
iteration_stats.end(),
|
|
654
|
+
clus.iteration_stats.begin(),
|
|
655
|
+
clus.iteration_stats.end());
|
|
656
|
+
|
|
657
|
+
d_prev = di;
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
if (apply_pca) {
|
|
661
|
+
if (verbose) {
|
|
662
|
+
printf("Revert PCA transform on centroids\n");
|
|
663
|
+
}
|
|
664
|
+
std::vector<float> cent_transformed(d * k);
|
|
665
|
+
pca.reverse_transform(k, centroids.data(), cent_transformed.data());
|
|
666
|
+
cent_transformed.swap(centroids);
|
|
667
|
+
}
|
|
668
|
+
}
|
|
669
|
+
|
|
531
670
|
} // namespace faiss
|