faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -5,21 +5,15 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/impl/ResidualQuantizer.h>
|
|
11
9
|
|
|
12
10
|
#include <algorithm>
|
|
11
|
+
#include <cmath>
|
|
13
12
|
#include <cstddef>
|
|
14
13
|
#include <cstdio>
|
|
15
14
|
#include <cstring>
|
|
16
15
|
#include <memory>
|
|
17
16
|
|
|
18
|
-
#include <faiss/impl/FaissAssert.h>
|
|
19
|
-
#include <faiss/impl/ResidualQuantizer.h>
|
|
20
|
-
#include <faiss/utils/utils.h>
|
|
21
|
-
|
|
22
|
-
#include <faiss/Clustering.h>
|
|
23
17
|
#include <faiss/IndexFlat.h>
|
|
24
18
|
#include <faiss/VectorTransform.h>
|
|
25
19
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
@@ -27,7 +21,6 @@
|
|
|
27
21
|
#include <faiss/utils/Heap.h>
|
|
28
22
|
#include <faiss/utils/distances.h>
|
|
29
23
|
#include <faiss/utils/hamming.h>
|
|
30
|
-
#include <faiss/utils/simdlib.h>
|
|
31
24
|
#include <faiss/utils/utils.h>
|
|
32
25
|
|
|
33
26
|
extern "C" {
|
|
@@ -47,15 +40,34 @@ int sgemm_(
|
|
|
47
40
|
float* beta,
|
|
48
41
|
float* c,
|
|
49
42
|
FINTEGER* ldc);
|
|
43
|
+
|
|
44
|
+
// http://www.netlib.org/clapack/old/single/sgels.c
|
|
45
|
+
// solve least squares
|
|
46
|
+
|
|
47
|
+
// LAPACK single-precision least-squares solver (SVD based), as called by
// ResidualQuantizer::retrain_AQ_codebook below.
int sgelsd_(
        FINTEGER* m,     // number of rows of A and B
        FINTEGER* n,     // number of columns of A
        FINTEGER* nrhs,  // number of right-hand sides (columns of B)
        float* a,        // matrix A, column-major, leading dimension lda
        FINTEGER* lda,   // leading dimension of A
        float* b,        // right-hand sides on entry, solution on exit
        FINTEGER* ldb,   // leading dimension of B
        float* s,        // output singular values of A
        float* rcond,    // threshold used to determine the effective rank
        FINTEGER* rank,  // output effective rank of A
        float* work,     // float workspace
        FINTEGER* lwork, // size of work; -1 requests a workspace-size query
        FINTEGER* iwork, // integer workspace
        FINTEGER* info); // output status, 0 on success
|
|
50
62
|
}
|
|
51
63
|
|
|
52
64
|
namespace faiss {
|
|
53
65
|
|
|
54
66
|
ResidualQuantizer::ResidualQuantizer()
|
|
55
67
|
: train_type(Train_progressive_dim),
|
|
68
|
+
niter_codebook_refine(5),
|
|
56
69
|
max_beam_size(5),
|
|
57
70
|
use_beam_LUT(0),
|
|
58
|
-
max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
|
|
59
71
|
assign_index_factory(nullptr) {
|
|
60
72
|
d = 0;
|
|
61
73
|
M = 0;
|
|
@@ -81,6 +93,39 @@ ResidualQuantizer::ResidualQuantizer(
|
|
|
81
93
|
Search_type_t search_type)
|
|
82
94
|
: ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
|
|
83
95
|
|
|
96
|
+
void ResidualQuantizer::initialize_from(
|
|
97
|
+
const ResidualQuantizer& other,
|
|
98
|
+
int skip_M) {
|
|
99
|
+
FAISS_THROW_IF_NOT(M + skip_M <= other.M);
|
|
100
|
+
FAISS_THROW_IF_NOT(skip_M >= 0);
|
|
101
|
+
|
|
102
|
+
Search_type_t this_search_type = search_type;
|
|
103
|
+
int this_M = M;
|
|
104
|
+
|
|
105
|
+
// a first good approximation: override everything
|
|
106
|
+
*this = other;
|
|
107
|
+
|
|
108
|
+
// adjust derived values
|
|
109
|
+
M = this_M;
|
|
110
|
+
search_type = this_search_type;
|
|
111
|
+
nbits.resize(M);
|
|
112
|
+
memcpy(nbits.data(),
|
|
113
|
+
other.nbits.data() + skip_M,
|
|
114
|
+
nbits.size() * sizeof(nbits[0]));
|
|
115
|
+
|
|
116
|
+
set_derived_values();
|
|
117
|
+
|
|
118
|
+
// resize codebooks if trained
|
|
119
|
+
if (codebooks.size() > 0) {
|
|
120
|
+
FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
|
|
121
|
+
codebooks.resize(total_codebook_size * d);
|
|
122
|
+
memcpy(codebooks.data(),
|
|
123
|
+
other.codebooks.data() + other.codebook_offsets[skip_M] * d,
|
|
124
|
+
codebooks.size() * sizeof(codebooks[0]));
|
|
125
|
+
// TODO: norm_tabs?
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
|
|
84
129
|
void beam_search_encode_step(
|
|
85
130
|
size_t d,
|
|
86
131
|
size_t K,
|
|
@@ -245,8 +290,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
245
290
|
}
|
|
246
291
|
train_residuals = residuals1;
|
|
247
292
|
}
|
|
248
|
-
train_type_t tt = train_type_t(train_type & 1023);
|
|
249
|
-
|
|
250
293
|
std::vector<float> codebooks;
|
|
251
294
|
float obj = 0;
|
|
252
295
|
|
|
@@ -259,7 +302,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
259
302
|
|
|
260
303
|
double t1 = getmillisecs();
|
|
261
304
|
|
|
262
|
-
if (
|
|
305
|
+
if (!(train_type & Train_progressive_dim)) { // regular kmeans
|
|
263
306
|
Clustering clus(d, K, cp);
|
|
264
307
|
clus.train(
|
|
265
308
|
train_residuals.size() / d,
|
|
@@ -268,7 +311,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
268
311
|
codebooks.swap(clus.centroids);
|
|
269
312
|
assign_index->reset();
|
|
270
313
|
obj = clus.iteration_stats.back().obj;
|
|
271
|
-
} else
|
|
314
|
+
} else { // progressive dim clustering
|
|
272
315
|
ProgressiveDimClustering clus(d, K, cp);
|
|
273
316
|
ProgressiveDimIndexFactory default_fac;
|
|
274
317
|
clus.train(
|
|
@@ -277,8 +320,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
277
320
|
assign_index_factory ? *assign_index_factory : default_fac);
|
|
278
321
|
codebooks.swap(clus.centroids);
|
|
279
322
|
obj = clus.iteration_stats.back().obj;
|
|
280
|
-
} else {
|
|
281
|
-
FAISS_THROW_MSG("train type not supported");
|
|
282
323
|
}
|
|
283
324
|
clustering_time += (getmillisecs() - t1) / 1000;
|
|
284
325
|
|
|
@@ -350,6 +391,19 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
350
391
|
cur_beam_size = new_beam_size;
|
|
351
392
|
}
|
|
352
393
|
|
|
394
|
+
is_trained = true;
|
|
395
|
+
|
|
396
|
+
if (train_type & Train_refine_codebook) {
|
|
397
|
+
for (int iter = 0; iter < niter_codebook_refine; iter++) {
|
|
398
|
+
if (verbose) {
|
|
399
|
+
printf("re-estimating the codebooks to minimize "
|
|
400
|
+
"quantization errors (iter %d).\n",
|
|
401
|
+
iter);
|
|
402
|
+
}
|
|
403
|
+
retrain_AQ_codebook(n, x);
|
|
404
|
+
}
|
|
405
|
+
}
|
|
406
|
+
|
|
353
407
|
// find min and max norms
|
|
354
408
|
std::vector<float> norms(n);
|
|
355
409
|
|
|
@@ -359,33 +413,128 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
359
413
|
}
|
|
360
414
|
|
|
361
415
|
// fvec_norms_L2sqr(norms.data(), x, d, n);
|
|
416
|
+
train_norm(n, norms.data());
|
|
417
|
+
|
|
418
|
+
if (!(train_type & Skip_codebook_tables)) {
|
|
419
|
+
compute_codebook_tables();
|
|
420
|
+
}
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
|
|
424
|
+
FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
|
|
425
|
+
|
|
426
|
+
if (verbose) {
|
|
427
|
+
printf(" encoding %zd training vectors\n", n);
|
|
428
|
+
}
|
|
429
|
+
std::vector<uint8_t> codes(n * code_size);
|
|
430
|
+
compute_codes(x, codes.data(), n);
|
|
362
431
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
432
|
+
// compute reconstruction error
|
|
433
|
+
float input_recons_error;
|
|
434
|
+
{
|
|
435
|
+
std::vector<float> x_recons(n * d);
|
|
436
|
+
decode(codes.data(), x_recons.data(), n);
|
|
437
|
+
input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
|
|
438
|
+
if (verbose) {
|
|
439
|
+
printf(" input quantization error %g\n", input_recons_error);
|
|
368
440
|
}
|
|
369
|
-
|
|
370
|
-
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
// build matrix of the linear system
|
|
444
|
+
std::vector<float> C(n * total_codebook_size);
|
|
445
|
+
for (size_t i = 0; i < n; i++) {
|
|
446
|
+
BitstringReader bsr(codes.data() + i * code_size, code_size);
|
|
447
|
+
for (int m = 0; m < M; m++) {
|
|
448
|
+
int idx = bsr.read(nbits[m]);
|
|
449
|
+
C[i + (codebook_offsets[m] + idx) * n] = 1;
|
|
450
|
+
}
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
// transpose training vectors
|
|
454
|
+
std::vector<float> xt(n * d);
|
|
455
|
+
|
|
456
|
+
for (size_t i = 0; i < n; i++) {
|
|
457
|
+
for (size_t j = 0; j < d; j++) {
|
|
458
|
+
xt[j * n + i] = x[i * d + j];
|
|
371
459
|
}
|
|
372
460
|
}
|
|
373
461
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
462
|
+
{ // solve least squares
|
|
463
|
+
FINTEGER lwork = -1;
|
|
464
|
+
FINTEGER di = d, ni = n, tcsi = total_codebook_size;
|
|
465
|
+
FINTEGER info = -1, rank = -1;
|
|
466
|
+
|
|
467
|
+
float rcond = 1e-4; // this is an important parameter because the code
|
|
468
|
+
// matrix can be rank deficient for small problems,
|
|
469
|
+
// the default rcond=-1 does not work
|
|
470
|
+
float worksize;
|
|
471
|
+
std::vector<float> sing_vals(total_codebook_size);
|
|
472
|
+
FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
|
|
473
|
+
// upper bound
|
|
474
|
+
std::vector<FINTEGER> iwork(
|
|
475
|
+
3 * total_codebook_size * nlvl + 11 * total_codebook_size);
|
|
476
|
+
|
|
477
|
+
// worksize query
|
|
478
|
+
sgelsd_(&ni,
|
|
479
|
+
&tcsi,
|
|
480
|
+
&di,
|
|
481
|
+
C.data(),
|
|
482
|
+
&ni,
|
|
483
|
+
xt.data(),
|
|
484
|
+
&ni,
|
|
485
|
+
sing_vals.data(),
|
|
486
|
+
&rcond,
|
|
487
|
+
&rank,
|
|
488
|
+
&worksize,
|
|
489
|
+
&lwork,
|
|
490
|
+
iwork.data(),
|
|
491
|
+
&info);
|
|
492
|
+
FAISS_THROW_IF_NOT(info == 0);
|
|
493
|
+
|
|
494
|
+
lwork = worksize;
|
|
495
|
+
std::vector<float> work(lwork);
|
|
496
|
+
// actual call
|
|
497
|
+
sgelsd_(&ni,
|
|
498
|
+
&tcsi,
|
|
499
|
+
&di,
|
|
500
|
+
C.data(),
|
|
501
|
+
&ni,
|
|
502
|
+
xt.data(),
|
|
503
|
+
&ni,
|
|
504
|
+
sing_vals.data(),
|
|
505
|
+
&rcond,
|
|
506
|
+
&rank,
|
|
507
|
+
work.data(),
|
|
508
|
+
&lwork,
|
|
509
|
+
iwork.data(),
|
|
510
|
+
&info);
|
|
511
|
+
FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
|
|
512
|
+
if (verbose) {
|
|
513
|
+
printf(" sgelsd rank=%d/%d\n",
|
|
514
|
+
int(rank),
|
|
515
|
+
int(total_codebook_size));
|
|
378
516
|
}
|
|
379
|
-
Clustering1D clus(k);
|
|
380
|
-
clus.train_exact(n, norms.data());
|
|
381
|
-
qnorm.add(clus.k, clus.centroids.data());
|
|
382
517
|
}
|
|
383
518
|
|
|
384
|
-
|
|
519
|
+
// result is in xt, re-transpose to codebook
|
|
385
520
|
|
|
386
|
-
|
|
387
|
-
|
|
521
|
+
for (size_t i = 0; i < total_codebook_size; i++) {
|
|
522
|
+
for (size_t j = 0; j < d; j++) {
|
|
523
|
+
codebooks[i * d + j] = xt[j * n + i];
|
|
524
|
+
FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
|
|
525
|
+
}
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
float output_recons_error = 0;
|
|
529
|
+
for (size_t j = 0; j < d; j++) {
|
|
530
|
+
output_recons_error += fvec_norm_L2sqr(
|
|
531
|
+
xt.data() + total_codebook_size + n * j,
|
|
532
|
+
n - total_codebook_size);
|
|
388
533
|
}
|
|
534
|
+
if (verbose) {
|
|
535
|
+
printf(" output quantization error %g\n", output_recons_error);
|
|
536
|
+
}
|
|
537
|
+
return output_recons_error;
|
|
389
538
|
}
|
|
390
539
|
|
|
391
540
|
size_t ResidualQuantizer::memory_per_point(int beam_size) const {
|
|
@@ -400,10 +549,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
|
|
|
400
549
|
return mem;
|
|
401
550
|
}
|
|
402
551
|
|
|
403
|
-
void ResidualQuantizer::
|
|
552
|
+
void ResidualQuantizer::compute_codes_add_centroids(
|
|
404
553
|
const float* x,
|
|
405
554
|
uint8_t* codes_out,
|
|
406
|
-
size_t n
|
|
555
|
+
size_t n,
|
|
556
|
+
const float* centroids) const {
|
|
407
557
|
FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
|
|
408
558
|
|
|
409
559
|
size_t mem = memory_per_point();
|
|
@@ -415,7 +565,12 @@ void ResidualQuantizer::compute_codes(
|
|
|
415
565
|
}
|
|
416
566
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
417
567
|
size_t i1 = std::min(n, i0 + bs);
|
|
418
|
-
|
|
568
|
+
const float* cent = nullptr;
|
|
569
|
+
if (centroids != nullptr) {
|
|
570
|
+
cent = centroids + i0 * d;
|
|
571
|
+
}
|
|
572
|
+
compute_codes_add_centroids(
|
|
573
|
+
x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
|
|
419
574
|
}
|
|
420
575
|
return;
|
|
421
576
|
}
|
|
@@ -489,7 +644,8 @@ void ResidualQuantizer::compute_codes(
|
|
|
489
644
|
codes.data(),
|
|
490
645
|
codes_out,
|
|
491
646
|
M * max_beam_size,
|
|
492
|
-
norms.size() > 0 ? norms.data() : nullptr
|
|
647
|
+
norms.size() > 0 ? norms.data() : nullptr,
|
|
648
|
+
centroids);
|
|
493
649
|
}
|
|
494
650
|
|
|
495
651
|
void ResidualQuantizer::refine_beam(
|
|
@@ -24,25 +24,31 @@ namespace faiss {
|
|
|
24
24
|
|
|
25
25
|
struct ResidualQuantizer : AdditiveQuantizer {
|
|
26
26
|
/// initialization
|
|
27
|
-
enum train_type_t {
|
|
28
|
-
Train_default = 0, ///< regular k-means
|
|
29
|
-
Train_progressive_dim = 1, ///< progressive dim clustering
|
|
30
|
-
Train_default_Train_top_beam = 1024,
|
|
31
|
-
Train_progressive_dim_Train_top_beam = 1025,
|
|
32
|
-
Train_default_Skip_codebook_tables = 2048,
|
|
33
|
-
Train_progressive_dim_Skip_codebook_tables = 2049,
|
|
34
|
-
Train_default_Train_top_beam_Skip_codebook_tables = 3072,
|
|
35
|
-
Train_progressive_dim_Train_top_beam_Skip_codebook_tables = 3073,
|
|
36
|
-
};
|
|
37
27
|
|
|
28
|
+
// Was enum but that does not work so well with bitmasks
|
|
29
|
+
using train_type_t = int;
|
|
30
|
+
|
|
31
|
+
/// Binary or of the Train_* flags below
|
|
38
32
|
train_type_t train_type;
|
|
39
33
|
|
|
40
|
-
|
|
41
|
-
|
|
34
|
+
/// regular k-means (minimal amount of computation)
|
|
35
|
+
static const int Train_default = 0;
|
|
36
|
+
|
|
37
|
+
/// progressive dim clustering (set by default)
|
|
38
|
+
static const int Train_progressive_dim = 1;
|
|
39
|
+
|
|
40
|
+
/// do a few iterations of codebook refinement after first level estimation
|
|
41
|
+
static const int Train_refine_codebook = 2;
|
|
42
|
+
|
|
43
|
+
/// number of iterations for codebook refinement.
|
|
44
|
+
int niter_codebook_refine;
|
|
45
|
+
|
|
46
|
+
/** set this bit on train_type if beam is to be trained only on the
|
|
47
|
+
* first element of the beam (faster but less accurate) */
|
|
42
48
|
static const int Train_top_beam = 1024;
|
|
43
49
|
|
|
44
|
-
|
|
45
|
-
|
|
50
|
+
/** set this bit to *not* automatically compute the codebook tables
|
|
51
|
+
* after training */
|
|
46
52
|
static const int Skip_codebook_tables = 2048;
|
|
47
53
|
|
|
48
54
|
/// beam size used for training and for encoding
|
|
@@ -51,10 +57,6 @@ struct ResidualQuantizer : AdditiveQuantizer {
|
|
|
51
57
|
/// use LUT for beam search
|
|
52
58
|
int use_beam_LUT;
|
|
53
59
|
|
|
54
|
-
/// distance matrixes with beam search can get large, so use this
|
|
55
|
-
/// to batch computations at encoding time.
|
|
56
|
-
size_t max_mem_distances;
|
|
57
|
-
|
|
58
60
|
/// clustering parameters
|
|
59
61
|
ProgressiveDimClusteringParameters cp;
|
|
60
62
|
|
|
@@ -74,15 +76,33 @@ struct ResidualQuantizer : AdditiveQuantizer {
|
|
|
74
76
|
|
|
75
77
|
ResidualQuantizer();
|
|
76
78
|
|
|
77
|
-
|
|
79
|
+
/// Train the residual quantizer
|
|
78
80
|
void train(size_t n, const float* x) override;
|
|
79
81
|
|
|
82
|
+
/// Copy the M codebook levels from other, starting from skip_M
|
|
83
|
+
void initialize_from(const ResidualQuantizer& other, int skip_M = 0);
|
|
84
|
+
|
|
85
|
+
/** Encode the vectors and compute codebook that minimizes the quantization
|
|
86
|
+
* error on these codes
|
|
87
|
+
*
|
|
88
|
+
* @param x training vectors, size n * d
|
|
89
|
+
* @param n nb of training vectors, n >= total_codebook_size
|
|
90
|
+
* @return returns quantization error for the new codebook with old
|
|
91
|
+
* codes
|
|
92
|
+
*/
|
|
93
|
+
float retrain_AQ_codebook(size_t n, const float* x);
|
|
94
|
+
|
|
80
95
|
/** Encode a set of vectors
|
|
81
96
|
*
|
|
82
97
|
* @param x vectors to encode, size n * d
|
|
83
98
|
* @param codes output codes, size n * code_size
|
|
99
|
+
* @param centroids centroids to be added to x, size n * d
|
|
84
100
|
*/
|
|
85
|
-
void
|
|
101
|
+
void compute_codes_add_centroids(
|
|
102
|
+
const float* x,
|
|
103
|
+
uint8_t* codes,
|
|
104
|
+
size_t n,
|
|
105
|
+
const float* centroids = nullptr) const override;
|
|
86
106
|
|
|
87
107
|
/** lower-level encode function
|
|
88
108
|
*
|
|
@@ -413,4 +413,100 @@ struct RangeSearchResultHandler {
|
|
|
413
413
|
}
|
|
414
414
|
};
|
|
415
415
|
|
|
416
|
+
/*****************************************************************
 * Single best result handler.
 * Tracks only the best result for each query, so no per-query heap
 * or temporary result buffer has to be stored in memory.
 *
 * Requirements on the comparator type C:
 *   - C::T  (distance type) and C::TI (id type),
 *   - C::cmp(cur, cand): true when cand should replace the current best,
 *   - C::neutral(): the worst possible value, used as the initial best so
 *     that the handler works for both min- and max-oriented comparators.
 * Queries that receive no result are reported with id -1.
 *****************************************************************/

template <class C>
struct SingleBestResultHandler {
    using T = typename C::T;
    using TI = typename C::TI;

    int nq;
    // contains exactly nq elements
    T* dis_tab;
    // contains exactly nq elements
    TI* ids_tab;

    SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
            : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}

    /// Collects results for one query at a time.
    struct SingleResultHandler {
        SingleBestResultHandler& hr;

        T min_dis;
        TI min_idx;
        size_t current_idx = 0;

        SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}

        /// begin results for query # i
        void begin(const size_t current_idx) {
            this->current_idx = current_idx;
            // worst possible value, whatever the direction of C
            min_dis = C::neutral();
            min_idx = -1; // "no result" until something is added
        }

        /// add one result for query i
        void add_result(T dis, TI idx) {
            if (C::cmp(min_dis, dis)) {
                min_dis = dis;
                min_idx = idx;
            }
        }

        /// series of results for query i is done
        void end() {
            hr.dis_tab[current_idx] = min_dis;
            hr.ids_tab[current_idx] = min_idx;
        }
    };

    size_t i0, i1;

    /// begin results for queries i0..i1
    void begin_multiple(size_t i0, size_t i1) {
        this->i0 = i0;
        this->i1 = i1;

        for (size_t i = i0; i < i1; i++) {
            this->dis_tab[i] = C::neutral();
            this->ids_tab[i] = -1; // "no result" until something is added
        }
    }

    /// add results for query i0..i1 and j0..j1; dis_tab is a row-major
    /// (i1 - i0) x (j1 - j0) block of distances
    void add_results(size_t j0, size_t j1, const T* dis_tab) {
        for (size_t i = i0; i < i1; i++) {
            // offset by -j0 so the row can be indexed with absolute j
            const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;

            auto& min_distance = this->dis_tab[i];
            auto& min_index = this->ids_tab[i];

            for (size_t j = j0; j < j1; j++) {
                const T distance = dis_tab_i[j];

                if (C::cmp(min_distance, distance)) {
                    min_distance = distance;
                    min_index = j;
                }
            }
        }
    }

    /// add one result for query i
    void add_result(const size_t i, const T dis, const TI idx) {
        auto& min_distance = this->dis_tab[i];
        auto& min_index = this->ids_tab[i];

        if (C::cmp(min_distance, dis)) {
            min_distance = dis;
            min_index = idx;
        }
    }

    /// series of results for queries i0..i1 is done
    void end_multiple() {}
};
|
|
511
|
+
|
|
416
512
|
} // namespace faiss
|