faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef FAISS_AUTO_TUNE_H
|
|
11
|
+
#define FAISS_AUTO_TUNE_H
|
|
12
|
+
|
|
13
|
+
#include <vector>
|
|
14
|
+
#include <unordered_map>
|
|
15
|
+
#include <stdint.h>
|
|
16
|
+
|
|
17
|
+
#include <faiss/Index.h>
|
|
18
|
+
#include <faiss/IndexBinary.h>
|
|
19
|
+
|
|
20
|
+
namespace faiss {
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
/**
|
|
24
|
+
* Evaluation criterion. Returns a performance measure in [0,1],
|
|
25
|
+
* higher is better.
|
|
26
|
+
*/
|
|
27
|
+
struct AutoTuneCriterion {
|
|
28
|
+
typedef Index::idx_t idx_t;
|
|
29
|
+
idx_t nq; ///< nb of queries this criterion is evaluated on
|
|
30
|
+
idx_t nnn; ///< nb of NNs that the query should request
|
|
31
|
+
idx_t gt_nnn; ///< nb of GT NNs required to evaluate crterion
|
|
32
|
+
|
|
33
|
+
std::vector<float> gt_D; ///< Ground-truth distances (size nq * gt_nnn)
|
|
34
|
+
std::vector<idx_t> gt_I; ///< Ground-truth indexes (size nq * gt_nnn)
|
|
35
|
+
|
|
36
|
+
AutoTuneCriterion (idx_t nq, idx_t nnn);
|
|
37
|
+
|
|
38
|
+
/** Intitializes the gt_D and gt_I vectors. Must be called before evaluating
|
|
39
|
+
*
|
|
40
|
+
* @param gt_D_in size nq * gt_nnn
|
|
41
|
+
* @param gt_I_in size nq * gt_nnn
|
|
42
|
+
*/
|
|
43
|
+
void set_groundtruth (int gt_nnn, const float *gt_D_in,
|
|
44
|
+
const idx_t *gt_I_in);
|
|
45
|
+
|
|
46
|
+
/** Evaluate the criterion.
|
|
47
|
+
*
|
|
48
|
+
* @param D size nq * nnn
|
|
49
|
+
* @param I size nq * nnn
|
|
50
|
+
* @return the criterion, between 0 and 1. Larger is better.
|
|
51
|
+
*/
|
|
52
|
+
virtual double evaluate (const float *D, const idx_t *I) const = 0;
|
|
53
|
+
|
|
54
|
+
virtual ~AutoTuneCriterion () {}
|
|
55
|
+
|
|
56
|
+
};
|
|
57
|
+
|
|
58
|
+
struct OneRecallAtRCriterion: AutoTuneCriterion {
|
|
59
|
+
|
|
60
|
+
idx_t R;
|
|
61
|
+
|
|
62
|
+
OneRecallAtRCriterion (idx_t nq, idx_t R);
|
|
63
|
+
|
|
64
|
+
double evaluate(const float* D, const idx_t* I) const override;
|
|
65
|
+
|
|
66
|
+
~OneRecallAtRCriterion() override {}
|
|
67
|
+
};
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
struct IntersectionCriterion: AutoTuneCriterion {
|
|
71
|
+
|
|
72
|
+
idx_t R;
|
|
73
|
+
|
|
74
|
+
IntersectionCriterion (idx_t nq, idx_t R);
|
|
75
|
+
|
|
76
|
+
double evaluate(const float* D, const idx_t* I) const override;
|
|
77
|
+
|
|
78
|
+
~IntersectionCriterion() override {}
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
/**
 * One experimental result. Each operating point is a (perf, t, key)
 * triplet, where higher perf and lower t is better. The key field is an
 * arbitrary identifier for the operating point.
 */
struct OperatingPoint {
    double perf;      ///< performance measure (output of a Criterion)
    double t;         ///< corresponding execution time (ms)
    std::string key;  ///< key that identifies this op pt
    int64_t cno;      ///< integer identifier
};
|
|
93
|
+
|
|
94
|
+
struct OperatingPoints {
|
|
95
|
+
/// all operating points
|
|
96
|
+
std::vector<OperatingPoint> all_pts;
|
|
97
|
+
|
|
98
|
+
/// optimal operating points, sorted by perf
|
|
99
|
+
std::vector<OperatingPoint> optimal_pts;
|
|
100
|
+
|
|
101
|
+
// begins with a single operating point: t=0, perf=0
|
|
102
|
+
OperatingPoints ();
|
|
103
|
+
|
|
104
|
+
/// add operating points from other to this, with a prefix to the keys
|
|
105
|
+
int merge_with (const OperatingPoints &other,
|
|
106
|
+
const std::string & prefix = "");
|
|
107
|
+
|
|
108
|
+
void clear ();
|
|
109
|
+
|
|
110
|
+
/// add a performance measure. Return whether it is an optimal point
|
|
111
|
+
bool add (double perf, double t, const std::string & key, size_t cno = 0);
|
|
112
|
+
|
|
113
|
+
/// get time required to obtain a given performance measure
|
|
114
|
+
double t_for_perf (double perf) const;
|
|
115
|
+
|
|
116
|
+
/// easy-to-read output
|
|
117
|
+
void display (bool only_optimal = true) const;
|
|
118
|
+
|
|
119
|
+
/// output to a format easy to digest by gnuplot
|
|
120
|
+
void all_to_gnuplot (const char *fname) const;
|
|
121
|
+
void optimal_to_gnuplot (const char *fname) const;
|
|
122
|
+
|
|
123
|
+
};
|
|
124
|
+
|
|
125
|
+
/// A named tunable parameter together with its candidate values, sorted
/// from least to most expensive/accurate.
struct ParameterRange {
    std::string name;            ///< parameter name
    std::vector<double> values;  ///< candidate values, cheapest first
};
|
|
130
|
+
|
|
131
|
+
/** Uses a-priori knowledge on the Faiss indexes to extract tunable parameters.
|
|
132
|
+
*/
|
|
133
|
+
struct ParameterSpace {
|
|
134
|
+
/// all tunable parameters
|
|
135
|
+
std::vector<ParameterRange> parameter_ranges;
|
|
136
|
+
|
|
137
|
+
// exploration parameters
|
|
138
|
+
|
|
139
|
+
/// verbosity during exploration
|
|
140
|
+
int verbose;
|
|
141
|
+
|
|
142
|
+
/// nb of experiments during optimization (0 = try all combinations)
|
|
143
|
+
int n_experiments;
|
|
144
|
+
|
|
145
|
+
/// maximum number of queries to submit at a time.
|
|
146
|
+
size_t batchsize;
|
|
147
|
+
|
|
148
|
+
/// use multithreading over batches (useful to benchmark
|
|
149
|
+
/// independent single-searches)
|
|
150
|
+
bool thread_over_batches;
|
|
151
|
+
|
|
152
|
+
/// run tests several times until they reach at least this
|
|
153
|
+
/// duration (to avoid jittering in MT mode)
|
|
154
|
+
double min_test_duration;
|
|
155
|
+
|
|
156
|
+
ParameterSpace ();
|
|
157
|
+
|
|
158
|
+
/// nb of combinations, = product of values sizes
|
|
159
|
+
size_t n_combinations () const;
|
|
160
|
+
|
|
161
|
+
/// returns whether combinations c1 >= c2 in the tuple sense
|
|
162
|
+
bool combination_ge (size_t c1, size_t c2) const;
|
|
163
|
+
|
|
164
|
+
/// get string representation of the combination
|
|
165
|
+
std::string combination_name (size_t cno) const;
|
|
166
|
+
|
|
167
|
+
/// print a description on stdout
|
|
168
|
+
void display () const;
|
|
169
|
+
|
|
170
|
+
/// add a new parameter (or return it if it exists)
|
|
171
|
+
ParameterRange &add_range(const char * name);
|
|
172
|
+
|
|
173
|
+
/// initialize with reasonable parameters for the index
|
|
174
|
+
virtual void initialize (const Index * index);
|
|
175
|
+
|
|
176
|
+
/// set a combination of parameters on an index
|
|
177
|
+
void set_index_parameters (Index *index, size_t cno) const;
|
|
178
|
+
|
|
179
|
+
/// set a combination of parameters described by a string
|
|
180
|
+
void set_index_parameters (Index *index, const char *param_string) const;
|
|
181
|
+
|
|
182
|
+
/// set one of the parameters
|
|
183
|
+
virtual void set_index_parameter (
|
|
184
|
+
Index * index, const std::string & name, double val) const;
|
|
185
|
+
|
|
186
|
+
/** find an upper bound on the performance and a lower bound on t
|
|
187
|
+
* for configuration cno given another operating point op */
|
|
188
|
+
void update_bounds (size_t cno, const OperatingPoint & op,
|
|
189
|
+
double *upper_bound_perf,
|
|
190
|
+
double *lower_bound_t) const;
|
|
191
|
+
|
|
192
|
+
/** explore operating points
|
|
193
|
+
* @param index index to run on
|
|
194
|
+
* @param xq query vectors (size nq * index.d)
|
|
195
|
+
* @param crit selection criterion
|
|
196
|
+
* @param ops resulting operating points
|
|
197
|
+
*/
|
|
198
|
+
void explore (Index *index,
|
|
199
|
+
size_t nq, const float *xq,
|
|
200
|
+
const AutoTuneCriterion & crit,
|
|
201
|
+
OperatingPoints * ops) const;
|
|
202
|
+
|
|
203
|
+
virtual ~ParameterSpace () {}
|
|
204
|
+
};
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
} // namespace faiss
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
#endif
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/Clustering.h>
|
|
11
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#include <cmath>
|
|
15
|
+
#include <cstdio>
|
|
16
|
+
#include <cstring>
|
|
17
|
+
|
|
18
|
+
#include <faiss/utils/utils.h>
|
|
19
|
+
#include <faiss/utils/random.h>
|
|
20
|
+
#include <faiss/utils/distances.h>
|
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
|
22
|
+
#include <faiss/IndexFlat.h>
|
|
23
|
+
|
|
24
|
+
namespace faiss {
|
|
25
|
+
|
|
26
|
+
// Default k-means settings.
// min_points_per_centroid = 39 corresponds to 10000 / 256 -> to avoid
// warnings on PQ tests with randu10k.
ClusteringParameters::ClusteringParameters ()
    : niter (25)
    , nredo (1)
    , verbose (false)
    , spherical (false)
    , int_centroids (false)
    , update_index (false)
    , frozen_centroids (false)
    , min_points_per_centroid (39)
    , max_points_per_centroid (256)
    , seed (1234)
{
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
// Clustering of d-dimensional vectors into k centroids, using the default
// ClusteringParameters.
Clustering::Clustering (int d, int k):
    Clustering (d, k, ClusteringParameters ())
{
}

// Same as above, but with the k-means settings copied from cp.
Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
    ClusteringParameters (cp),
    d (d),
    k (k)
{
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
/* Imbalance factor of an assignment of n points to k lists:
 *
 *     k * sum_i h_i^2 / (sum_i h_i)^2
 *
 * where h_i is the number of points assigned to list i. It equals 1 when
 * all lists have the same size and grows as the sizes become uneven.
 * assign[j] must be in [0, k) for every j < n.
 */
static double imbalance_factor (int n, int k, int64_t *assign) {
    std::vector<int> counts (k, 0);
    for (int pt = 0; pt < n; ++pt) {
        ++counts[assign[pt]];
    }

    double total = 0;
    double sum_sq = 0;
    for (int c : counts) {
        total += c;
        sum_sq += c * (double) c;
    }
    return sum_sq * k / (total * total);
}
|
|
64
|
+
|
|
65
|
+
void Clustering::post_process_centroids ()
|
|
66
|
+
{
|
|
67
|
+
|
|
68
|
+
if (spherical) {
|
|
69
|
+
fvec_renorm_L2 (d, k, centroids.data());
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
if (int_centroids) {
|
|
73
|
+
for (size_t i = 0; i < centroids.size(); i++)
|
|
74
|
+
centroids[i] = roundf (centroids[i]);
|
|
75
|
+
}
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
/* Run k-means on the training set.
 *
 * On return, `centroids` holds k * d floats and `index` contains exactly
 * those k centroids. One objective value (the sum of the distances
 * returned by index.search) is appended to `obj` per iteration.
 *
 * @param nx     number of training vectors, must be >= k
 * @param x_in   training vectors, size nx * d
 * @param index  index used for the assignment step; it is reset (and
 *               trained if needed) with the centroids
 *
 * Throws (FAISS_THROW_*) if nx < k, if the input contains NaN/Inf, or if
 * pre-set centroids are not a multiple of d.
 */
void Clustering::train (idx_t nx, const float *x_in, Index & index) {
    FAISS_THROW_IF_NOT_FMT (nx >= idx_t(k),
            "Number of training points (%lld) should be at least "
            "as large as number of clusters (%lld)",
            (long long)nx, (long long)k);

    double t0 = getmillisecs();

    // yes it is the user's responsibility, but it may spare us some
    // hard-to-debug reports.
    for (size_t i = 0; i < nx * d; i++) {
        FAISS_THROW_IF_NOT_MSG (std::isfinite (x_in[i]),
                                "input contains NaN's or Inf's");
    }

    const float *x = x_in;
    // owns the subsampled copy of the training set, when one is made
    std::vector<float> x_subsampled;

    if (nx > idx_t(k) * max_points_per_centroid) {
        idx_t n_sample = idx_t(k) * max_points_per_centroid;
        if (verbose)
            printf("Sampling a subset of %lld / %lld for training\n",
                   (long long)n_sample, (long long)nx);
        std::vector<int> perm (nx);
        rand_perm (perm.data (), nx, seed);
        x_subsampled.resize (size_t(n_sample) * d);
        for (idx_t i = 0; i < n_sample; i++)
            memcpy (x_subsampled.data() + i * d,
                    x_in + size_t(perm[i]) * d,
                    sizeof(float) * d);
        x = x_subsampled.data();
        nx = n_sample;
    } else if (nx < idx_t(k) * min_points_per_centroid) {
        fprintf (stderr,
                 "WARNING clustering %lld points to %lld centroids: "
                 "please provide at least %lld training points\n",
                 (long long)nx, (long long)k,
                 (long long)k * min_points_per_centroid);
    }

    if (nx == idx_t(k)) {
        if (verbose) {
            printf("Number of training points (%lld) same as number of "
                   "clusters, just copying\n", (long long)nx);
        }
        // corner case: every training point becomes a centroid
        centroids.resize (d * k);
        memcpy (centroids.data(), x, sizeof (*x) * d * k);
        // each point coincides with its own centroid, so the objective is 0;
        // recording it keeps obj non-empty for callers reading obj.back()
        obj.push_back (0);
        index.reset();
        if (!index.is_trained) {
            index.train (k, centroids.data());
        }
        index.add(k, centroids.data());
        return;
    }

    if (verbose)
        printf("Clustering %lld points in %lldD to %lld clusters, "
               "redo %d times, %d iterations\n",
               (long long)nx, (long long)d, (long long)k, nredo, niter);

    std::vector<idx_t> assign (nx);
    std::vector<float> dis (nx);

    // for redo
    float best_err = HUGE_VALF;
    std::vector<float> best_obj;
    std::vector<float> best_centroids;

    // support input centroids

    FAISS_THROW_IF_NOT_MSG (
        centroids.size() % d == 0,
        "size of provided input centroids not a multiple of dimension");

    size_t n_input_centroids = centroids.size() / d;

    if (verbose && n_input_centroids > 0) {
        printf ("  Using %zd centroids provided as input (%sfrozen)\n",
                n_input_centroids, frozen_centroids ? "" : "not ");
    }

    double t_search_tot = 0;
    if (verbose) {
        printf("  Preprocessing in %.2f s\n",
               (getmillisecs() - t0) / 1000.);
    }
    t0 = getmillisecs();

    for (int redo = 0; redo < nredo; redo++) {

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        // initialize remaining centroids with random points from the dataset
        centroids.resize (d * k);
        std::vector<int> perm (nx);

        rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
        for (size_t i = n_input_centroids; i < k; i++)
            memcpy (&centroids[i * d], x + size_t(perm[i]) * d,
                    d * sizeof (float));

        post_process_centroids ();

        if (index.ntotal != 0) {
            index.reset();
        }

        if (!index.is_trained) {
            index.train (k, centroids.data());
        }

        index.add (k, centroids.data());
        float err = 0;
        for (int i = 0; i < niter; i++) {
            double t0s = getmillisecs();
            index.search (nx, x, 1, dis.data(), assign.data());
            InterruptCallback::check();
            t_search_tot += getmillisecs() - t0s;

            err = 0;
            for (idx_t j = 0; j < nx; j++)
                err += dis[j];
            obj.push_back (err);

            int nsplit = km_update_centroids (
                  x, centroids.data(),
                  assign.data(), d, k, nx,
                  frozen_centroids ? n_input_centroids : 0);

            if (verbose) {
                printf ("  Iteration %d (%.2f s, search %.2f s): "
                        "objective=%g imbalance=%.3f nsplit=%d       \r",
                        i, (getmillisecs() - t0) / 1000.0,
                        t_search_tot / 1000,
                        err, imbalance_factor (nx, k, assign.data()),
                        nsplit);
                fflush (stdout);
            }

            post_process_centroids ();

            index.reset ();
            if (update_index)
                index.train (k, centroids.data());

            assert (index.ntotal == 0);
            index.add (k, centroids.data());
            InterruptCallback::check ();
        }
        if (verbose) printf("\n");
        if (nredo > 1) {
            // always keep the first run, so best_centroids is never left
            // empty even if err is not a finite, comparable value
            if (redo == 0 || err < best_err) {
                if (verbose)
                    printf ("Objective improved: keep new clusters\n");
                best_centroids = centroids;
                best_obj = obj;
                best_err = err;
            }
            index.reset ();
        }
    }
    if (nredo > 1) {
        centroids = best_centroids;
        obj = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }

}
|
|
247
|
+
|
|
248
|
+
float kmeans_clustering (size_t d, size_t n, size_t k,
|
|
249
|
+
const float *x,
|
|
250
|
+
float *centroids)
|
|
251
|
+
{
|
|
252
|
+
Clustering clus (d, k);
|
|
253
|
+
clus.verbose = d * n * k > (1L << 30);
|
|
254
|
+
// display logs if > 1Gflop per iteration
|
|
255
|
+
IndexFlatL2 index (d);
|
|
256
|
+
clus.train (n, x, index);
|
|
257
|
+
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
|
|
258
|
+
return clus.obj.back();
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
} // namespace faiss
|