faiss 0.2.3 → 0.2.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -5,20 +5,15 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
|
-
#include "faiss/impl/ResidualQuantizer.h"
|
11
|
-
#include <faiss/impl/FaissAssert.h>
|
12
8
|
#include <faiss/impl/ResidualQuantizer.h>
|
13
|
-
#include "faiss/utils/utils.h"
|
14
9
|
|
10
|
+
#include <algorithm>
|
11
|
+
#include <cmath>
|
15
12
|
#include <cstddef>
|
16
13
|
#include <cstdio>
|
17
14
|
#include <cstring>
|
18
15
|
#include <memory>
|
19
16
|
|
20
|
-
#include <algorithm>
|
21
|
-
|
22
17
|
#include <faiss/IndexFlat.h>
|
23
18
|
#include <faiss/VectorTransform.h>
|
24
19
|
#include <faiss/impl/AuxIndexStructures.h>
|
@@ -28,39 +23,109 @@
|
|
28
23
|
#include <faiss/utils/hamming.h>
|
29
24
|
#include <faiss/utils/utils.h>
|
30
25
|
|
26
|
+
extern "C" {
|
27
|
+
|
28
|
+
// general matrix multiplication
|
29
|
+
int sgemm_(
|
30
|
+
const char* transa,
|
31
|
+
const char* transb,
|
32
|
+
FINTEGER* m,
|
33
|
+
FINTEGER* n,
|
34
|
+
FINTEGER* k,
|
35
|
+
const float* alpha,
|
36
|
+
const float* a,
|
37
|
+
FINTEGER* lda,
|
38
|
+
const float* b,
|
39
|
+
FINTEGER* ldb,
|
40
|
+
float* beta,
|
41
|
+
float* c,
|
42
|
+
FINTEGER* ldc);
|
43
|
+
|
44
|
+
// http://www.netlib.org/clapack/old/single/sgels.c
|
45
|
+
// solve least squares
|
46
|
+
|
47
|
+
int sgelsd_(
|
48
|
+
FINTEGER* m,
|
49
|
+
FINTEGER* n,
|
50
|
+
FINTEGER* nrhs,
|
51
|
+
float* a,
|
52
|
+
FINTEGER* lda,
|
53
|
+
float* b,
|
54
|
+
FINTEGER* ldb,
|
55
|
+
float* s,
|
56
|
+
float* rcond,
|
57
|
+
FINTEGER* rank,
|
58
|
+
float* work,
|
59
|
+
FINTEGER* lwork,
|
60
|
+
FINTEGER* iwork,
|
61
|
+
FINTEGER* info);
|
62
|
+
}
|
63
|
+
|
31
64
|
namespace faiss {
|
32
65
|
|
33
66
|
ResidualQuantizer::ResidualQuantizer()
|
34
67
|
: train_type(Train_progressive_dim),
|
35
|
-
|
36
|
-
|
68
|
+
niter_codebook_refine(5),
|
69
|
+
max_beam_size(5),
|
70
|
+
use_beam_LUT(0),
|
37
71
|
assign_index_factory(nullptr) {
|
38
72
|
d = 0;
|
39
73
|
M = 0;
|
40
74
|
verbose = false;
|
41
75
|
}
|
42
76
|
|
43
|
-
ResidualQuantizer::ResidualQuantizer(
|
77
|
+
ResidualQuantizer::ResidualQuantizer(
|
78
|
+
size_t d,
|
79
|
+
const std::vector<size_t>& nbits,
|
80
|
+
Search_type_t search_type)
|
44
81
|
: ResidualQuantizer() {
|
82
|
+
this->search_type = search_type;
|
45
83
|
this->d = d;
|
46
84
|
M = nbits.size();
|
47
85
|
this->nbits = nbits;
|
48
86
|
set_derived_values();
|
49
87
|
}
|
50
88
|
|
51
|
-
ResidualQuantizer::ResidualQuantizer(
|
52
|
-
|
89
|
+
ResidualQuantizer::ResidualQuantizer(
|
90
|
+
size_t d,
|
91
|
+
size_t M,
|
92
|
+
size_t nbits,
|
93
|
+
Search_type_t search_type)
|
94
|
+
: ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
|
95
|
+
|
96
|
+
void ResidualQuantizer::initialize_from(
|
97
|
+
const ResidualQuantizer& other,
|
98
|
+
int skip_M) {
|
99
|
+
FAISS_THROW_IF_NOT(M + skip_M <= other.M);
|
100
|
+
FAISS_THROW_IF_NOT(skip_M >= 0);
|
101
|
+
|
102
|
+
Search_type_t this_search_type = search_type;
|
103
|
+
int this_M = M;
|
104
|
+
|
105
|
+
// a first good approximation: override everything
|
106
|
+
*this = other;
|
107
|
+
|
108
|
+
// adjust derived values
|
109
|
+
M = this_M;
|
110
|
+
search_type = this_search_type;
|
111
|
+
nbits.resize(M);
|
112
|
+
memcpy(nbits.data(),
|
113
|
+
other.nbits.data() + skip_M,
|
114
|
+
nbits.size() * sizeof(nbits[0]));
|
53
115
|
|
54
|
-
|
116
|
+
set_derived_values();
|
55
117
|
|
56
|
-
|
57
|
-
|
58
|
-
|
118
|
+
// resize codebooks if trained
|
119
|
+
if (codebooks.size() > 0) {
|
120
|
+
FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
|
121
|
+
codebooks.resize(total_codebook_size * d);
|
122
|
+
memcpy(codebooks.data(),
|
123
|
+
other.codebooks.data() + other.codebook_offsets[skip_M] * d,
|
124
|
+
codebooks.size() * sizeof(codebooks[0]));
|
125
|
+
// TODO: norm_tabs?
|
59
126
|
}
|
60
127
|
}
|
61
128
|
|
62
|
-
} // anonymous namespace
|
63
|
-
|
64
129
|
void beam_search_encode_step(
|
65
130
|
size_t d,
|
66
131
|
size_t K,
|
@@ -90,7 +155,7 @@ void beam_search_encode_step(
|
|
90
155
|
cent_ids.resize(n * beam_size * new_beam_size);
|
91
156
|
if (assign_index->ntotal != 0) {
|
92
157
|
// then we assume the codebooks are already added to the index
|
93
|
-
FAISS_THROW_IF_NOT(assign_index->ntotal
|
158
|
+
FAISS_THROW_IF_NOT(assign_index->ntotal == K);
|
94
159
|
} else {
|
95
160
|
assign_index->add(K, cent);
|
96
161
|
}
|
@@ -208,6 +273,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
208
273
|
std::vector<int32_t> codes;
|
209
274
|
std::vector<float> distances;
|
210
275
|
double t0 = getmillisecs();
|
276
|
+
double clustering_time = 0;
|
211
277
|
|
212
278
|
for (int m = 0; m < M; m++) {
|
213
279
|
int K = 1 << nbits[m];
|
@@ -224,8 +290,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
224
290
|
}
|
225
291
|
train_residuals = residuals1;
|
226
292
|
}
|
227
|
-
train_type_t tt = train_type_t(train_type & ~Train_top_beam);
|
228
|
-
|
229
293
|
std::vector<float> codebooks;
|
230
294
|
float obj = 0;
|
231
295
|
|
@@ -235,7 +299,10 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
235
299
|
} else {
|
236
300
|
assign_index.reset(new IndexFlatL2(d));
|
237
301
|
}
|
238
|
-
|
302
|
+
|
303
|
+
double t1 = getmillisecs();
|
304
|
+
|
305
|
+
if (!(train_type & Train_progressive_dim)) { // regular kmeans
|
239
306
|
Clustering clus(d, K, cp);
|
240
307
|
clus.train(
|
241
308
|
train_residuals.size() / d,
|
@@ -244,7 +311,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
244
311
|
codebooks.swap(clus.centroids);
|
245
312
|
assign_index->reset();
|
246
313
|
obj = clus.iteration_stats.back().obj;
|
247
|
-
} else
|
314
|
+
} else { // progressive dim clustering
|
248
315
|
ProgressiveDimClustering clus(d, K, cp);
|
249
316
|
ProgressiveDimIndexFactory default_fac;
|
250
317
|
clus.train(
|
@@ -253,9 +320,8 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
253
320
|
assign_index_factory ? *assign_index_factory : default_fac);
|
254
321
|
codebooks.swap(clus.centroids);
|
255
322
|
obj = clus.iteration_stats.back().obj;
|
256
|
-
} else {
|
257
|
-
FAISS_THROW_MSG("train type not supported");
|
258
323
|
}
|
324
|
+
clustering_time += (getmillisecs() - t1) / 1000;
|
259
325
|
|
260
326
|
memcpy(this->codebooks.data() + codebook_offsets[m] * d,
|
261
327
|
codebooks.data(),
|
@@ -268,21 +334,38 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
268
334
|
std::vector<float> new_residuals(n * new_beam_size * d);
|
269
335
|
std::vector<float> new_distances(n * new_beam_size);
|
270
336
|
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
new_codes.data(),
|
282
|
-
new_residuals.data(),
|
283
|
-
new_distances.data(),
|
284
|
-
assign_index.get());
|
337
|
+
size_t bs;
|
338
|
+
{ // determine batch size
|
339
|
+
size_t mem = memory_per_point();
|
340
|
+
if (n > 1 && mem * n > max_mem_distances) {
|
341
|
+
// then split queries to reduce temp memory
|
342
|
+
bs = std::max(max_mem_distances / mem, size_t(1));
|
343
|
+
} else {
|
344
|
+
bs = n;
|
345
|
+
}
|
346
|
+
}
|
285
347
|
|
348
|
+
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
349
|
+
size_t i1 = std::min(i0 + bs, n);
|
350
|
+
|
351
|
+
/* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
|
352
|
+
i0, i1, K, assign_index->ntotal); */
|
353
|
+
|
354
|
+
beam_search_encode_step(
|
355
|
+
d,
|
356
|
+
K,
|
357
|
+
codebooks.data(),
|
358
|
+
i1 - i0,
|
359
|
+
cur_beam_size,
|
360
|
+
residuals.data() + i0 * cur_beam_size * d,
|
361
|
+
m,
|
362
|
+
codes.data() + i0 * cur_beam_size * m,
|
363
|
+
new_beam_size,
|
364
|
+
new_codes.data() + i0 * new_beam_size * (m + 1),
|
365
|
+
new_residuals.data() + i0 * new_beam_size * d,
|
366
|
+
new_distances.data() + i0 * new_beam_size,
|
367
|
+
assign_index.get());
|
368
|
+
}
|
286
369
|
codes.swap(new_codes);
|
287
370
|
residuals.swap(new_residuals);
|
288
371
|
distances.swap(new_distances);
|
@@ -293,20 +376,165 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
293
376
|
}
|
294
377
|
|
295
378
|
if (verbose) {
|
296
|
-
printf("[%.3f s] train stage %d, %d bits, kmeans objective %g, "
|
297
|
-
"total distance %g, beam_size %d->%d\n",
|
379
|
+
printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
|
380
|
+
"total distance %g, beam_size %d->%d (batch size %zd)\n",
|
298
381
|
(getmillisecs() - t0) / 1000,
|
382
|
+
clustering_time,
|
299
383
|
m,
|
300
384
|
int(nbits[m]),
|
301
385
|
obj,
|
302
386
|
sum_distances,
|
303
387
|
cur_beam_size,
|
304
|
-
new_beam_size
|
388
|
+
new_beam_size,
|
389
|
+
bs);
|
305
390
|
}
|
306
391
|
cur_beam_size = new_beam_size;
|
307
392
|
}
|
308
393
|
|
309
394
|
is_trained = true;
|
395
|
+
|
396
|
+
if (train_type & Train_refine_codebook) {
|
397
|
+
for (int iter = 0; iter < niter_codebook_refine; iter++) {
|
398
|
+
if (verbose) {
|
399
|
+
printf("re-estimating the codebooks to minimize "
|
400
|
+
"quantization errors (iter %d).\n",
|
401
|
+
iter);
|
402
|
+
}
|
403
|
+
retrain_AQ_codebook(n, x);
|
404
|
+
}
|
405
|
+
}
|
406
|
+
|
407
|
+
// find min and max norms
|
408
|
+
std::vector<float> norms(n);
|
409
|
+
|
410
|
+
for (size_t i = 0; i < n; i++) {
|
411
|
+
norms[i] = fvec_L2sqr(
|
412
|
+
x + i * d, residuals.data() + i * cur_beam_size * d, d);
|
413
|
+
}
|
414
|
+
|
415
|
+
// fvec_norms_L2sqr(norms.data(), x, d, n);
|
416
|
+
train_norm(n, norms.data());
|
417
|
+
|
418
|
+
if (!(train_type & Skip_codebook_tables)) {
|
419
|
+
compute_codebook_tables();
|
420
|
+
}
|
421
|
+
}
|
422
|
+
|
423
|
+
float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
|
424
|
+
FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
|
425
|
+
|
426
|
+
if (verbose) {
|
427
|
+
printf(" encoding %zd training vectors\n", n);
|
428
|
+
}
|
429
|
+
std::vector<uint8_t> codes(n * code_size);
|
430
|
+
compute_codes(x, codes.data(), n);
|
431
|
+
|
432
|
+
// compute reconstruction error
|
433
|
+
float input_recons_error;
|
434
|
+
{
|
435
|
+
std::vector<float> x_recons(n * d);
|
436
|
+
decode(codes.data(), x_recons.data(), n);
|
437
|
+
input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
|
438
|
+
if (verbose) {
|
439
|
+
printf(" input quantization error %g\n", input_recons_error);
|
440
|
+
}
|
441
|
+
}
|
442
|
+
|
443
|
+
// build matrix of the linear system
|
444
|
+
std::vector<float> C(n * total_codebook_size);
|
445
|
+
for (size_t i = 0; i < n; i++) {
|
446
|
+
BitstringReader bsr(codes.data() + i * code_size, code_size);
|
447
|
+
for (int m = 0; m < M; m++) {
|
448
|
+
int idx = bsr.read(nbits[m]);
|
449
|
+
C[i + (codebook_offsets[m] + idx) * n] = 1;
|
450
|
+
}
|
451
|
+
}
|
452
|
+
|
453
|
+
// transpose training vectors
|
454
|
+
std::vector<float> xt(n * d);
|
455
|
+
|
456
|
+
for (size_t i = 0; i < n; i++) {
|
457
|
+
for (size_t j = 0; j < d; j++) {
|
458
|
+
xt[j * n + i] = x[i * d + j];
|
459
|
+
}
|
460
|
+
}
|
461
|
+
|
462
|
+
{ // solve least squares
|
463
|
+
FINTEGER lwork = -1;
|
464
|
+
FINTEGER di = d, ni = n, tcsi = total_codebook_size;
|
465
|
+
FINTEGER info = -1, rank = -1;
|
466
|
+
|
467
|
+
float rcond = 1e-4; // this is an important parameter because the code
|
468
|
+
// matrix can be rank deficient for small problems,
|
469
|
+
// the default rcond=-1 does not work
|
470
|
+
float worksize;
|
471
|
+
std::vector<float> sing_vals(total_codebook_size);
|
472
|
+
FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
|
473
|
+
// upper bound
|
474
|
+
std::vector<FINTEGER> iwork(
|
475
|
+
3 * total_codebook_size * nlvl + 11 * total_codebook_size);
|
476
|
+
|
477
|
+
// worksize query
|
478
|
+
sgelsd_(&ni,
|
479
|
+
&tcsi,
|
480
|
+
&di,
|
481
|
+
C.data(),
|
482
|
+
&ni,
|
483
|
+
xt.data(),
|
484
|
+
&ni,
|
485
|
+
sing_vals.data(),
|
486
|
+
&rcond,
|
487
|
+
&rank,
|
488
|
+
&worksize,
|
489
|
+
&lwork,
|
490
|
+
iwork.data(),
|
491
|
+
&info);
|
492
|
+
FAISS_THROW_IF_NOT(info == 0);
|
493
|
+
|
494
|
+
lwork = worksize;
|
495
|
+
std::vector<float> work(lwork);
|
496
|
+
// actual call
|
497
|
+
sgelsd_(&ni,
|
498
|
+
&tcsi,
|
499
|
+
&di,
|
500
|
+
C.data(),
|
501
|
+
&ni,
|
502
|
+
xt.data(),
|
503
|
+
&ni,
|
504
|
+
sing_vals.data(),
|
505
|
+
&rcond,
|
506
|
+
&rank,
|
507
|
+
work.data(),
|
508
|
+
&lwork,
|
509
|
+
iwork.data(),
|
510
|
+
&info);
|
511
|
+
FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
|
512
|
+
if (verbose) {
|
513
|
+
printf(" sgelsd rank=%d/%d\n",
|
514
|
+
int(rank),
|
515
|
+
int(total_codebook_size));
|
516
|
+
}
|
517
|
+
}
|
518
|
+
|
519
|
+
// result is in xt, re-transpose to codebook
|
520
|
+
|
521
|
+
for (size_t i = 0; i < total_codebook_size; i++) {
|
522
|
+
for (size_t j = 0; j < d; j++) {
|
523
|
+
codebooks[i * d + j] = xt[j * n + i];
|
524
|
+
FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
|
525
|
+
}
|
526
|
+
}
|
527
|
+
|
528
|
+
float output_recons_error = 0;
|
529
|
+
for (size_t j = 0; j < d; j++) {
|
530
|
+
output_recons_error += fvec_norm_L2sqr(
|
531
|
+
xt.data() + total_codebook_size + n * j,
|
532
|
+
n - total_codebook_size);
|
533
|
+
}
|
534
|
+
if (verbose) {
|
535
|
+
printf(" output quantization error %g\n", output_recons_error);
|
536
|
+
}
|
537
|
+
return output_recons_error;
|
310
538
|
}
|
311
539
|
|
312
540
|
size_t ResidualQuantizer::memory_per_point(int beam_size) const {
|
@@ -321,10 +549,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
|
|
321
549
|
return mem;
|
322
550
|
}
|
323
551
|
|
324
|
-
void ResidualQuantizer::
|
552
|
+
void ResidualQuantizer::compute_codes_add_centroids(
|
325
553
|
const float* x,
|
326
554
|
uint8_t* codes_out,
|
327
|
-
size_t n
|
555
|
+
size_t n,
|
556
|
+
const float* centroids) const {
|
328
557
|
FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
|
329
558
|
|
330
559
|
size_t mem = memory_per_point();
|
@@ -336,27 +565,87 @@ void ResidualQuantizer::compute_codes(
|
|
336
565
|
}
|
337
566
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
338
567
|
size_t i1 = std::min(n, i0 + bs);
|
339
|
-
|
568
|
+
const float* cent = nullptr;
|
569
|
+
if (centroids != nullptr) {
|
570
|
+
cent = centroids + i0 * d;
|
571
|
+
}
|
572
|
+
compute_codes_add_centroids(
|
573
|
+
x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
|
340
574
|
}
|
341
575
|
return;
|
342
576
|
}
|
343
577
|
|
344
|
-
std::vector<float> residuals(max_beam_size * n * d);
|
345
578
|
std::vector<int32_t> codes(max_beam_size * M * n);
|
579
|
+
std::vector<float> norms;
|
346
580
|
std::vector<float> distances(max_beam_size * n);
|
347
581
|
|
348
|
-
|
349
|
-
|
350
|
-
1,
|
351
|
-
x,
|
352
|
-
max_beam_size,
|
353
|
-
codes.data(),
|
354
|
-
residuals.data(),
|
355
|
-
distances.data());
|
582
|
+
if (use_beam_LUT == 0) {
|
583
|
+
std::vector<float> residuals(max_beam_size * n * d);
|
356
584
|
|
585
|
+
refine_beam(
|
586
|
+
n,
|
587
|
+
1,
|
588
|
+
x,
|
589
|
+
max_beam_size,
|
590
|
+
codes.data(),
|
591
|
+
residuals.data(),
|
592
|
+
distances.data());
|
593
|
+
|
594
|
+
if (search_type == ST_norm_float || search_type == ST_norm_qint8 ||
|
595
|
+
search_type == ST_norm_qint4) {
|
596
|
+
norms.resize(n);
|
597
|
+
// recover the norms of reconstruction as
|
598
|
+
// || original_vector - residual ||^2
|
599
|
+
for (size_t i = 0; i < n; i++) {
|
600
|
+
norms[i] = fvec_L2sqr(
|
601
|
+
x + i * d, residuals.data() + i * max_beam_size * d, d);
|
602
|
+
}
|
603
|
+
}
|
604
|
+
} else if (use_beam_LUT == 1) {
|
605
|
+
FAISS_THROW_IF_NOT_MSG(
|
606
|
+
codebook_cross_products.size() ==
|
607
|
+
total_codebook_size * total_codebook_size,
|
608
|
+
"call compute_codebook_tables first");
|
609
|
+
|
610
|
+
std::vector<float> query_norms(n);
|
611
|
+
fvec_norms_L2sqr(query_norms.data(), x, d, n);
|
612
|
+
|
613
|
+
std::vector<float> query_cp(n * total_codebook_size);
|
614
|
+
{
|
615
|
+
FINTEGER ti = total_codebook_size, di = d, ni = n;
|
616
|
+
float zero = 0, one = 1;
|
617
|
+
sgemm_("Transposed",
|
618
|
+
"Not transposed",
|
619
|
+
&ti,
|
620
|
+
&ni,
|
621
|
+
&di,
|
622
|
+
&one,
|
623
|
+
codebooks.data(),
|
624
|
+
&di,
|
625
|
+
x,
|
626
|
+
&di,
|
627
|
+
&zero,
|
628
|
+
query_cp.data(),
|
629
|
+
&ti);
|
630
|
+
}
|
631
|
+
|
632
|
+
refine_beam_LUT(
|
633
|
+
n,
|
634
|
+
query_norms.data(),
|
635
|
+
query_cp.data(),
|
636
|
+
max_beam_size,
|
637
|
+
codes.data(),
|
638
|
+
distances.data());
|
639
|
+
}
|
357
640
|
// pack only the first code of the beam (hence the ld_codes=M *
|
358
641
|
// max_beam_size)
|
359
|
-
pack_codes(
|
642
|
+
pack_codes(
|
643
|
+
n,
|
644
|
+
codes.data(),
|
645
|
+
codes_out,
|
646
|
+
M * max_beam_size,
|
647
|
+
norms.size() > 0 ? norms.data() : nullptr,
|
648
|
+
centroids);
|
360
649
|
}
|
361
650
|
|
362
651
|
void ResidualQuantizer::refine_beam(
|
@@ -445,4 +734,181 @@ void ResidualQuantizer::refine_beam(
|
|
445
734
|
}
|
446
735
|
}
|
447
736
|
|
737
|
+
/*******************************************************************
|
738
|
+
* Functions using the dot products between codebook entries
|
739
|
+
*******************************************************************/
|
740
|
+
|
741
|
+
void ResidualQuantizer::compute_codebook_tables() {
|
742
|
+
codebook_cross_products.resize(total_codebook_size * total_codebook_size);
|
743
|
+
cent_norms.resize(total_codebook_size);
|
744
|
+
// stricly speaking we could use ssyrk
|
745
|
+
{
|
746
|
+
FINTEGER ni = total_codebook_size;
|
747
|
+
FINTEGER di = d;
|
748
|
+
float zero = 0, one = 1;
|
749
|
+
sgemm_("Transposed",
|
750
|
+
"Not transposed",
|
751
|
+
&ni,
|
752
|
+
&ni,
|
753
|
+
&di,
|
754
|
+
&one,
|
755
|
+
codebooks.data(),
|
756
|
+
&di,
|
757
|
+
codebooks.data(),
|
758
|
+
&di,
|
759
|
+
&zero,
|
760
|
+
codebook_cross_products.data(),
|
761
|
+
&ni);
|
762
|
+
}
|
763
|
+
for (size_t i = 0; i < total_codebook_size; i++) {
|
764
|
+
cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
|
765
|
+
}
|
766
|
+
}
|
767
|
+
|
768
|
+
void beam_search_encode_step_tab(
|
769
|
+
size_t K,
|
770
|
+
size_t n,
|
771
|
+
size_t beam_size, // input sizes
|
772
|
+
const float* codebook_cross_norms, // size K * ldc
|
773
|
+
size_t ldc, // >= K
|
774
|
+
const uint64_t* codebook_offsets, // m
|
775
|
+
const float* query_cp, // size n * ldqc
|
776
|
+
size_t ldqc, // >= K
|
777
|
+
const float* cent_norms_i, // size K
|
778
|
+
size_t m,
|
779
|
+
const int32_t* codes, // n * beam_size * m
|
780
|
+
const float* distances, // n * beam_size
|
781
|
+
size_t new_beam_size,
|
782
|
+
int32_t* new_codes, // n * new_beam_size * (m + 1)
|
783
|
+
float* new_distances) // n * new_beam_size
|
784
|
+
{
|
785
|
+
FAISS_THROW_IF_NOT(ldc >= K);
|
786
|
+
|
787
|
+
#pragma omp parallel for if (n > 100)
|
788
|
+
for (int64_t i = 0; i < n; i++) {
|
789
|
+
std::vector<float> cent_distances(beam_size * K);
|
790
|
+
std::vector<float> cd_common(K);
|
791
|
+
|
792
|
+
const int32_t* codes_i = codes + i * m * beam_size;
|
793
|
+
const float* query_cp_i = query_cp + i * ldqc;
|
794
|
+
const float* distances_i = distances + i * beam_size;
|
795
|
+
|
796
|
+
for (size_t k = 0; k < K; k++) {
|
797
|
+
cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
|
798
|
+
}
|
799
|
+
|
800
|
+
for (size_t b = 0; b < beam_size; b++) {
|
801
|
+
std::vector<float> dp(K);
|
802
|
+
|
803
|
+
for (size_t m1 = 0; m1 < m; m1++) {
|
804
|
+
size_t c = codes_i[b * m + m1];
|
805
|
+
const float* cb =
|
806
|
+
&codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
|
807
|
+
fvec_add(K, cb, dp.data(), dp.data());
|
808
|
+
}
|
809
|
+
|
810
|
+
for (size_t k = 0; k < K; k++) {
|
811
|
+
cent_distances[b * K + k] =
|
812
|
+
distances_i[b] + cd_common[k] + 2 * dp[k];
|
813
|
+
}
|
814
|
+
}
|
815
|
+
|
816
|
+
using C = CMax<float, int>;
|
817
|
+
int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
|
818
|
+
float* new_distances_i = new_distances + i * new_beam_size;
|
819
|
+
|
820
|
+
const float* cent_distances_i = cent_distances.data();
|
821
|
+
|
822
|
+
// then we have to select the best results
|
823
|
+
for (int i = 0; i < new_beam_size; i++) {
|
824
|
+
new_distances_i[i] = C::neutral();
|
825
|
+
}
|
826
|
+
std::vector<int> perm(new_beam_size, -1);
|
827
|
+
heap_addn<C>(
|
828
|
+
new_beam_size,
|
829
|
+
new_distances_i,
|
830
|
+
perm.data(),
|
831
|
+
cent_distances_i,
|
832
|
+
nullptr,
|
833
|
+
beam_size * K);
|
834
|
+
heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
|
835
|
+
|
836
|
+
for (int j = 0; j < new_beam_size; j++) {
|
837
|
+
int js = perm[j] / K;
|
838
|
+
int ls = perm[j] % K;
|
839
|
+
if (m > 0) {
|
840
|
+
memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
|
841
|
+
}
|
842
|
+
new_codes_i[m] = ls;
|
843
|
+
new_codes_i += m + 1;
|
844
|
+
}
|
845
|
+
}
|
846
|
+
}
|
847
|
+
|
848
|
+
void ResidualQuantizer::refine_beam_LUT(
|
849
|
+
size_t n,
|
850
|
+
const float* query_norms, // size n
|
851
|
+
const float* query_cp, //
|
852
|
+
int out_beam_size,
|
853
|
+
int32_t* out_codes,
|
854
|
+
float* out_distances) const {
|
855
|
+
int beam_size = 1;
|
856
|
+
|
857
|
+
std::vector<int32_t> codes;
|
858
|
+
std::vector<float> distances(query_norms, query_norms + n);
|
859
|
+
double t0 = getmillisecs();
|
860
|
+
|
861
|
+
for (int m = 0; m < M; m++) {
|
862
|
+
int K = 1 << nbits[m];
|
863
|
+
|
864
|
+
int new_beam_size = std::min(beam_size * K, out_beam_size);
|
865
|
+
std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
|
866
|
+
std::vector<float> new_distances(n * new_beam_size);
|
867
|
+
|
868
|
+
beam_search_encode_step_tab(
|
869
|
+
K,
|
870
|
+
n,
|
871
|
+
beam_size,
|
872
|
+
codebook_cross_products.data() + codebook_offsets[m],
|
873
|
+
total_codebook_size,
|
874
|
+
codebook_offsets.data(),
|
875
|
+
query_cp + codebook_offsets[m],
|
876
|
+
total_codebook_size,
|
877
|
+
cent_norms.data() + codebook_offsets[m],
|
878
|
+
m,
|
879
|
+
codes.data(),
|
880
|
+
distances.data(),
|
881
|
+
new_beam_size,
|
882
|
+
new_codes.data(),
|
883
|
+
new_distances.data());
|
884
|
+
|
885
|
+
codes.swap(new_codes);
|
886
|
+
distances.swap(new_distances);
|
887
|
+
beam_size = new_beam_size;
|
888
|
+
|
889
|
+
if (verbose) {
|
890
|
+
float sum_distances = 0;
|
891
|
+
for (int j = 0; j < distances.size(); j++) {
|
892
|
+
sum_distances += distances[j];
|
893
|
+
}
|
894
|
+
printf("[%.3f s] encode stage %d, %d bits, "
|
895
|
+
"total error %g, beam_size %d\n",
|
896
|
+
(getmillisecs() - t0) / 1000,
|
897
|
+
m,
|
898
|
+
int(nbits[m]),
|
899
|
+
sum_distances,
|
900
|
+
beam_size);
|
901
|
+
}
|
902
|
+
}
|
903
|
+
|
904
|
+
if (out_codes) {
|
905
|
+
memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
|
906
|
+
}
|
907
|
+
if (out_distances) {
|
908
|
+
memcpy(out_distances,
|
909
|
+
distances.data(),
|
910
|
+
distances.size() * sizeof(distances[0]));
|
911
|
+
}
|
912
|
+
}
|
913
|
+
|
448
914
|
} // namespace faiss
|