faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
|
@@ -5,20 +5,15 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
|
-
#include "faiss/impl/ResidualQuantizer.h"
|
|
11
|
-
#include <faiss/impl/FaissAssert.h>
|
|
12
8
|
#include <faiss/impl/ResidualQuantizer.h>
|
|
13
|
-
#include "faiss/utils/utils.h"
|
|
14
9
|
|
|
10
|
+
#include <algorithm>
|
|
11
|
+
#include <cmath>
|
|
15
12
|
#include <cstddef>
|
|
16
13
|
#include <cstdio>
|
|
17
14
|
#include <cstring>
|
|
18
15
|
#include <memory>
|
|
19
16
|
|
|
20
|
-
#include <algorithm>
|
|
21
|
-
|
|
22
17
|
#include <faiss/IndexFlat.h>
|
|
23
18
|
#include <faiss/VectorTransform.h>
|
|
24
19
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
@@ -28,39 +23,109 @@
|
|
|
28
23
|
#include <faiss/utils/hamming.h>
|
|
29
24
|
#include <faiss/utils/utils.h>
|
|
30
25
|
|
|
26
|
+
extern "C" {
|
|
27
|
+
|
|
28
|
+
// general matrix multiplication
|
|
29
|
+
int sgemm_(
|
|
30
|
+
const char* transa,
|
|
31
|
+
const char* transb,
|
|
32
|
+
FINTEGER* m,
|
|
33
|
+
FINTEGER* n,
|
|
34
|
+
FINTEGER* k,
|
|
35
|
+
const float* alpha,
|
|
36
|
+
const float* a,
|
|
37
|
+
FINTEGER* lda,
|
|
38
|
+
const float* b,
|
|
39
|
+
FINTEGER* ldb,
|
|
40
|
+
float* beta,
|
|
41
|
+
float* c,
|
|
42
|
+
FINTEGER* ldc);
|
|
43
|
+
|
|
44
|
+
// http://www.netlib.org/clapack/old/single/sgels.c
|
|
45
|
+
// solve least squares
|
|
46
|
+
|
|
47
|
+
int sgelsd_(
|
|
48
|
+
FINTEGER* m,
|
|
49
|
+
FINTEGER* n,
|
|
50
|
+
FINTEGER* nrhs,
|
|
51
|
+
float* a,
|
|
52
|
+
FINTEGER* lda,
|
|
53
|
+
float* b,
|
|
54
|
+
FINTEGER* ldb,
|
|
55
|
+
float* s,
|
|
56
|
+
float* rcond,
|
|
57
|
+
FINTEGER* rank,
|
|
58
|
+
float* work,
|
|
59
|
+
FINTEGER* lwork,
|
|
60
|
+
FINTEGER* iwork,
|
|
61
|
+
FINTEGER* info);
|
|
62
|
+
}
|
|
63
|
+
|
|
31
64
|
namespace faiss {
|
|
32
65
|
|
|
33
66
|
ResidualQuantizer::ResidualQuantizer()
|
|
34
67
|
: train_type(Train_progressive_dim),
|
|
35
|
-
|
|
36
|
-
|
|
68
|
+
niter_codebook_refine(5),
|
|
69
|
+
max_beam_size(5),
|
|
70
|
+
use_beam_LUT(0),
|
|
37
71
|
assign_index_factory(nullptr) {
|
|
38
72
|
d = 0;
|
|
39
73
|
M = 0;
|
|
40
74
|
verbose = false;
|
|
41
75
|
}
|
|
42
76
|
|
|
43
|
-
ResidualQuantizer::ResidualQuantizer(
|
|
77
|
+
ResidualQuantizer::ResidualQuantizer(
|
|
78
|
+
size_t d,
|
|
79
|
+
const std::vector<size_t>& nbits,
|
|
80
|
+
Search_type_t search_type)
|
|
44
81
|
: ResidualQuantizer() {
|
|
82
|
+
this->search_type = search_type;
|
|
45
83
|
this->d = d;
|
|
46
84
|
M = nbits.size();
|
|
47
85
|
this->nbits = nbits;
|
|
48
86
|
set_derived_values();
|
|
49
87
|
}
|
|
50
88
|
|
|
51
|
-
ResidualQuantizer::ResidualQuantizer(
|
|
52
|
-
|
|
89
|
+
ResidualQuantizer::ResidualQuantizer(
|
|
90
|
+
size_t d,
|
|
91
|
+
size_t M,
|
|
92
|
+
size_t nbits,
|
|
93
|
+
Search_type_t search_type)
|
|
94
|
+
: ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
|
|
95
|
+
|
|
96
|
+
void ResidualQuantizer::initialize_from(
|
|
97
|
+
const ResidualQuantizer& other,
|
|
98
|
+
int skip_M) {
|
|
99
|
+
FAISS_THROW_IF_NOT(M + skip_M <= other.M);
|
|
100
|
+
FAISS_THROW_IF_NOT(skip_M >= 0);
|
|
101
|
+
|
|
102
|
+
Search_type_t this_search_type = search_type;
|
|
103
|
+
int this_M = M;
|
|
104
|
+
|
|
105
|
+
// a first good approximation: override everything
|
|
106
|
+
*this = other;
|
|
107
|
+
|
|
108
|
+
// adjust derived values
|
|
109
|
+
M = this_M;
|
|
110
|
+
search_type = this_search_type;
|
|
111
|
+
nbits.resize(M);
|
|
112
|
+
memcpy(nbits.data(),
|
|
113
|
+
other.nbits.data() + skip_M,
|
|
114
|
+
nbits.size() * sizeof(nbits[0]));
|
|
53
115
|
|
|
54
|
-
|
|
116
|
+
set_derived_values();
|
|
55
117
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
118
|
+
// resize codebooks if trained
|
|
119
|
+
if (codebooks.size() > 0) {
|
|
120
|
+
FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
|
|
121
|
+
codebooks.resize(total_codebook_size * d);
|
|
122
|
+
memcpy(codebooks.data(),
|
|
123
|
+
other.codebooks.data() + other.codebook_offsets[skip_M] * d,
|
|
124
|
+
codebooks.size() * sizeof(codebooks[0]));
|
|
125
|
+
// TODO: norm_tabs?
|
|
59
126
|
}
|
|
60
127
|
}
|
|
61
128
|
|
|
62
|
-
} // anonymous namespace
|
|
63
|
-
|
|
64
129
|
void beam_search_encode_step(
|
|
65
130
|
size_t d,
|
|
66
131
|
size_t K,
|
|
@@ -90,7 +155,7 @@ void beam_search_encode_step(
|
|
|
90
155
|
cent_ids.resize(n * beam_size * new_beam_size);
|
|
91
156
|
if (assign_index->ntotal != 0) {
|
|
92
157
|
// then we assume the codebooks are already added to the index
|
|
93
|
-
FAISS_THROW_IF_NOT(assign_index->ntotal
|
|
158
|
+
FAISS_THROW_IF_NOT(assign_index->ntotal == K);
|
|
94
159
|
} else {
|
|
95
160
|
assign_index->add(K, cent);
|
|
96
161
|
}
|
|
@@ -208,6 +273,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
208
273
|
std::vector<int32_t> codes;
|
|
209
274
|
std::vector<float> distances;
|
|
210
275
|
double t0 = getmillisecs();
|
|
276
|
+
double clustering_time = 0;
|
|
211
277
|
|
|
212
278
|
for (int m = 0; m < M; m++) {
|
|
213
279
|
int K = 1 << nbits[m];
|
|
@@ -224,8 +290,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
224
290
|
}
|
|
225
291
|
train_residuals = residuals1;
|
|
226
292
|
}
|
|
227
|
-
train_type_t tt = train_type_t(train_type & ~Train_top_beam);
|
|
228
|
-
|
|
229
293
|
std::vector<float> codebooks;
|
|
230
294
|
float obj = 0;
|
|
231
295
|
|
|
@@ -235,7 +299,10 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
235
299
|
} else {
|
|
236
300
|
assign_index.reset(new IndexFlatL2(d));
|
|
237
301
|
}
|
|
238
|
-
|
|
302
|
+
|
|
303
|
+
double t1 = getmillisecs();
|
|
304
|
+
|
|
305
|
+
if (!(train_type & Train_progressive_dim)) { // regular kmeans
|
|
239
306
|
Clustering clus(d, K, cp);
|
|
240
307
|
clus.train(
|
|
241
308
|
train_residuals.size() / d,
|
|
@@ -244,7 +311,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
244
311
|
codebooks.swap(clus.centroids);
|
|
245
312
|
assign_index->reset();
|
|
246
313
|
obj = clus.iteration_stats.back().obj;
|
|
247
|
-
} else
|
|
314
|
+
} else { // progressive dim clustering
|
|
248
315
|
ProgressiveDimClustering clus(d, K, cp);
|
|
249
316
|
ProgressiveDimIndexFactory default_fac;
|
|
250
317
|
clus.train(
|
|
@@ -253,9 +320,8 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
253
320
|
assign_index_factory ? *assign_index_factory : default_fac);
|
|
254
321
|
codebooks.swap(clus.centroids);
|
|
255
322
|
obj = clus.iteration_stats.back().obj;
|
|
256
|
-
} else {
|
|
257
|
-
FAISS_THROW_MSG("train type not supported");
|
|
258
323
|
}
|
|
324
|
+
clustering_time += (getmillisecs() - t1) / 1000;
|
|
259
325
|
|
|
260
326
|
memcpy(this->codebooks.data() + codebook_offsets[m] * d,
|
|
261
327
|
codebooks.data(),
|
|
@@ -268,21 +334,38 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
268
334
|
std::vector<float> new_residuals(n * new_beam_size * d);
|
|
269
335
|
std::vector<float> new_distances(n * new_beam_size);
|
|
270
336
|
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
new_codes.data(),
|
|
282
|
-
new_residuals.data(),
|
|
283
|
-
new_distances.data(),
|
|
284
|
-
assign_index.get());
|
|
337
|
+
size_t bs;
|
|
338
|
+
{ // determine batch size
|
|
339
|
+
size_t mem = memory_per_point();
|
|
340
|
+
if (n > 1 && mem * n > max_mem_distances) {
|
|
341
|
+
// then split queries to reduce temp memory
|
|
342
|
+
bs = std::max(max_mem_distances / mem, size_t(1));
|
|
343
|
+
} else {
|
|
344
|
+
bs = n;
|
|
345
|
+
}
|
|
346
|
+
}
|
|
285
347
|
|
|
348
|
+
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
349
|
+
size_t i1 = std::min(i0 + bs, n);
|
|
350
|
+
|
|
351
|
+
/* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
|
|
352
|
+
i0, i1, K, assign_index->ntotal); */
|
|
353
|
+
|
|
354
|
+
beam_search_encode_step(
|
|
355
|
+
d,
|
|
356
|
+
K,
|
|
357
|
+
codebooks.data(),
|
|
358
|
+
i1 - i0,
|
|
359
|
+
cur_beam_size,
|
|
360
|
+
residuals.data() + i0 * cur_beam_size * d,
|
|
361
|
+
m,
|
|
362
|
+
codes.data() + i0 * cur_beam_size * m,
|
|
363
|
+
new_beam_size,
|
|
364
|
+
new_codes.data() + i0 * new_beam_size * (m + 1),
|
|
365
|
+
new_residuals.data() + i0 * new_beam_size * d,
|
|
366
|
+
new_distances.data() + i0 * new_beam_size,
|
|
367
|
+
assign_index.get());
|
|
368
|
+
}
|
|
286
369
|
codes.swap(new_codes);
|
|
287
370
|
residuals.swap(new_residuals);
|
|
288
371
|
distances.swap(new_distances);
|
|
@@ -293,20 +376,165 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
293
376
|
}
|
|
294
377
|
|
|
295
378
|
if (verbose) {
|
|
296
|
-
printf("[%.3f s] train stage %d, %d bits, kmeans objective %g, "
|
|
297
|
-
"total distance %g, beam_size %d->%d\n",
|
|
379
|
+
printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
|
|
380
|
+
"total distance %g, beam_size %d->%d (batch size %zd)\n",
|
|
298
381
|
(getmillisecs() - t0) / 1000,
|
|
382
|
+
clustering_time,
|
|
299
383
|
m,
|
|
300
384
|
int(nbits[m]),
|
|
301
385
|
obj,
|
|
302
386
|
sum_distances,
|
|
303
387
|
cur_beam_size,
|
|
304
|
-
new_beam_size
|
|
388
|
+
new_beam_size,
|
|
389
|
+
bs);
|
|
305
390
|
}
|
|
306
391
|
cur_beam_size = new_beam_size;
|
|
307
392
|
}
|
|
308
393
|
|
|
309
394
|
is_trained = true;
|
|
395
|
+
|
|
396
|
+
if (train_type & Train_refine_codebook) {
|
|
397
|
+
for (int iter = 0; iter < niter_codebook_refine; iter++) {
|
|
398
|
+
if (verbose) {
|
|
399
|
+
printf("re-estimating the codebooks to minimize "
|
|
400
|
+
"quantization errors (iter %d).\n",
|
|
401
|
+
iter);
|
|
402
|
+
}
|
|
403
|
+
retrain_AQ_codebook(n, x);
|
|
404
|
+
}
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
// find min and max norms
|
|
408
|
+
std::vector<float> norms(n);
|
|
409
|
+
|
|
410
|
+
for (size_t i = 0; i < n; i++) {
|
|
411
|
+
norms[i] = fvec_L2sqr(
|
|
412
|
+
x + i * d, residuals.data() + i * cur_beam_size * d, d);
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
// fvec_norms_L2sqr(norms.data(), x, d, n);
|
|
416
|
+
train_norm(n, norms.data());
|
|
417
|
+
|
|
418
|
+
if (!(train_type & Skip_codebook_tables)) {
|
|
419
|
+
compute_codebook_tables();
|
|
420
|
+
}
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
|
|
424
|
+
FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
|
|
425
|
+
|
|
426
|
+
if (verbose) {
|
|
427
|
+
printf(" encoding %zd training vectors\n", n);
|
|
428
|
+
}
|
|
429
|
+
std::vector<uint8_t> codes(n * code_size);
|
|
430
|
+
compute_codes(x, codes.data(), n);
|
|
431
|
+
|
|
432
|
+
// compute reconstruction error
|
|
433
|
+
float input_recons_error;
|
|
434
|
+
{
|
|
435
|
+
std::vector<float> x_recons(n * d);
|
|
436
|
+
decode(codes.data(), x_recons.data(), n);
|
|
437
|
+
input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
|
|
438
|
+
if (verbose) {
|
|
439
|
+
printf(" input quantization error %g\n", input_recons_error);
|
|
440
|
+
}
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
// build matrix of the linear system
|
|
444
|
+
std::vector<float> C(n * total_codebook_size);
|
|
445
|
+
for (size_t i = 0; i < n; i++) {
|
|
446
|
+
BitstringReader bsr(codes.data() + i * code_size, code_size);
|
|
447
|
+
for (int m = 0; m < M; m++) {
|
|
448
|
+
int idx = bsr.read(nbits[m]);
|
|
449
|
+
C[i + (codebook_offsets[m] + idx) * n] = 1;
|
|
450
|
+
}
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
// transpose training vectors
|
|
454
|
+
std::vector<float> xt(n * d);
|
|
455
|
+
|
|
456
|
+
for (size_t i = 0; i < n; i++) {
|
|
457
|
+
for (size_t j = 0; j < d; j++) {
|
|
458
|
+
xt[j * n + i] = x[i * d + j];
|
|
459
|
+
}
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
{ // solve least squares
|
|
463
|
+
FINTEGER lwork = -1;
|
|
464
|
+
FINTEGER di = d, ni = n, tcsi = total_codebook_size;
|
|
465
|
+
FINTEGER info = -1, rank = -1;
|
|
466
|
+
|
|
467
|
+
float rcond = 1e-4; // this is an important parameter because the code
|
|
468
|
+
// matrix can be rank deficient for small problems,
|
|
469
|
+
// the default rcond=-1 does not work
|
|
470
|
+
float worksize;
|
|
471
|
+
std::vector<float> sing_vals(total_codebook_size);
|
|
472
|
+
FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
|
|
473
|
+
// upper bound
|
|
474
|
+
std::vector<FINTEGER> iwork(
|
|
475
|
+
3 * total_codebook_size * nlvl + 11 * total_codebook_size);
|
|
476
|
+
|
|
477
|
+
// worksize query
|
|
478
|
+
sgelsd_(&ni,
|
|
479
|
+
&tcsi,
|
|
480
|
+
&di,
|
|
481
|
+
C.data(),
|
|
482
|
+
&ni,
|
|
483
|
+
xt.data(),
|
|
484
|
+
&ni,
|
|
485
|
+
sing_vals.data(),
|
|
486
|
+
&rcond,
|
|
487
|
+
&rank,
|
|
488
|
+
&worksize,
|
|
489
|
+
&lwork,
|
|
490
|
+
iwork.data(),
|
|
491
|
+
&info);
|
|
492
|
+
FAISS_THROW_IF_NOT(info == 0);
|
|
493
|
+
|
|
494
|
+
lwork = worksize;
|
|
495
|
+
std::vector<float> work(lwork);
|
|
496
|
+
// actual call
|
|
497
|
+
sgelsd_(&ni,
|
|
498
|
+
&tcsi,
|
|
499
|
+
&di,
|
|
500
|
+
C.data(),
|
|
501
|
+
&ni,
|
|
502
|
+
xt.data(),
|
|
503
|
+
&ni,
|
|
504
|
+
sing_vals.data(),
|
|
505
|
+
&rcond,
|
|
506
|
+
&rank,
|
|
507
|
+
work.data(),
|
|
508
|
+
&lwork,
|
|
509
|
+
iwork.data(),
|
|
510
|
+
&info);
|
|
511
|
+
FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
|
|
512
|
+
if (verbose) {
|
|
513
|
+
printf(" sgelsd rank=%d/%d\n",
|
|
514
|
+
int(rank),
|
|
515
|
+
int(total_codebook_size));
|
|
516
|
+
}
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
// result is in xt, re-transpose to codebook
|
|
520
|
+
|
|
521
|
+
for (size_t i = 0; i < total_codebook_size; i++) {
|
|
522
|
+
for (size_t j = 0; j < d; j++) {
|
|
523
|
+
codebooks[i * d + j] = xt[j * n + i];
|
|
524
|
+
FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
|
|
525
|
+
}
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
float output_recons_error = 0;
|
|
529
|
+
for (size_t j = 0; j < d; j++) {
|
|
530
|
+
output_recons_error += fvec_norm_L2sqr(
|
|
531
|
+
xt.data() + total_codebook_size + n * j,
|
|
532
|
+
n - total_codebook_size);
|
|
533
|
+
}
|
|
534
|
+
if (verbose) {
|
|
535
|
+
printf(" output quantization error %g\n", output_recons_error);
|
|
536
|
+
}
|
|
537
|
+
return output_recons_error;
|
|
310
538
|
}
|
|
311
539
|
|
|
312
540
|
size_t ResidualQuantizer::memory_per_point(int beam_size) const {
|
|
@@ -321,10 +549,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
|
|
|
321
549
|
return mem;
|
|
322
550
|
}
|
|
323
551
|
|
|
324
|
-
void ResidualQuantizer::
|
|
552
|
+
void ResidualQuantizer::compute_codes_add_centroids(
|
|
325
553
|
const float* x,
|
|
326
554
|
uint8_t* codes_out,
|
|
327
|
-
size_t n
|
|
555
|
+
size_t n,
|
|
556
|
+
const float* centroids) const {
|
|
328
557
|
FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
|
|
329
558
|
|
|
330
559
|
size_t mem = memory_per_point();
|
|
@@ -336,27 +565,87 @@ void ResidualQuantizer::compute_codes(
|
|
|
336
565
|
}
|
|
337
566
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
338
567
|
size_t i1 = std::min(n, i0 + bs);
|
|
339
|
-
|
|
568
|
+
const float* cent = nullptr;
|
|
569
|
+
if (centroids != nullptr) {
|
|
570
|
+
cent = centroids + i0 * d;
|
|
571
|
+
}
|
|
572
|
+
compute_codes_add_centroids(
|
|
573
|
+
x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
|
|
340
574
|
}
|
|
341
575
|
return;
|
|
342
576
|
}
|
|
343
577
|
|
|
344
|
-
std::vector<float> residuals(max_beam_size * n * d);
|
|
345
578
|
std::vector<int32_t> codes(max_beam_size * M * n);
|
|
579
|
+
std::vector<float> norms;
|
|
346
580
|
std::vector<float> distances(max_beam_size * n);
|
|
347
581
|
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
1,
|
|
351
|
-
x,
|
|
352
|
-
max_beam_size,
|
|
353
|
-
codes.data(),
|
|
354
|
-
residuals.data(),
|
|
355
|
-
distances.data());
|
|
582
|
+
if (use_beam_LUT == 0) {
|
|
583
|
+
std::vector<float> residuals(max_beam_size * n * d);
|
|
356
584
|
|
|
585
|
+
refine_beam(
|
|
586
|
+
n,
|
|
587
|
+
1,
|
|
588
|
+
x,
|
|
589
|
+
max_beam_size,
|
|
590
|
+
codes.data(),
|
|
591
|
+
residuals.data(),
|
|
592
|
+
distances.data());
|
|
593
|
+
|
|
594
|
+
if (search_type == ST_norm_float || search_type == ST_norm_qint8 ||
|
|
595
|
+
search_type == ST_norm_qint4) {
|
|
596
|
+
norms.resize(n);
|
|
597
|
+
// recover the norms of reconstruction as
|
|
598
|
+
// || original_vector - residual ||^2
|
|
599
|
+
for (size_t i = 0; i < n; i++) {
|
|
600
|
+
norms[i] = fvec_L2sqr(
|
|
601
|
+
x + i * d, residuals.data() + i * max_beam_size * d, d);
|
|
602
|
+
}
|
|
603
|
+
}
|
|
604
|
+
} else if (use_beam_LUT == 1) {
|
|
605
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
606
|
+
codebook_cross_products.size() ==
|
|
607
|
+
total_codebook_size * total_codebook_size,
|
|
608
|
+
"call compute_codebook_tables first");
|
|
609
|
+
|
|
610
|
+
std::vector<float> query_norms(n);
|
|
611
|
+
fvec_norms_L2sqr(query_norms.data(), x, d, n);
|
|
612
|
+
|
|
613
|
+
std::vector<float> query_cp(n * total_codebook_size);
|
|
614
|
+
{
|
|
615
|
+
FINTEGER ti = total_codebook_size, di = d, ni = n;
|
|
616
|
+
float zero = 0, one = 1;
|
|
617
|
+
sgemm_("Transposed",
|
|
618
|
+
"Not transposed",
|
|
619
|
+
&ti,
|
|
620
|
+
&ni,
|
|
621
|
+
&di,
|
|
622
|
+
&one,
|
|
623
|
+
codebooks.data(),
|
|
624
|
+
&di,
|
|
625
|
+
x,
|
|
626
|
+
&di,
|
|
627
|
+
&zero,
|
|
628
|
+
query_cp.data(),
|
|
629
|
+
&ti);
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
refine_beam_LUT(
|
|
633
|
+
n,
|
|
634
|
+
query_norms.data(),
|
|
635
|
+
query_cp.data(),
|
|
636
|
+
max_beam_size,
|
|
637
|
+
codes.data(),
|
|
638
|
+
distances.data());
|
|
639
|
+
}
|
|
357
640
|
// pack only the first code of the beam (hence the ld_codes=M *
|
|
358
641
|
// max_beam_size)
|
|
359
|
-
pack_codes(
|
|
642
|
+
pack_codes(
|
|
643
|
+
n,
|
|
644
|
+
codes.data(),
|
|
645
|
+
codes_out,
|
|
646
|
+
M * max_beam_size,
|
|
647
|
+
norms.size() > 0 ? norms.data() : nullptr,
|
|
648
|
+
centroids);
|
|
360
649
|
}
|
|
361
650
|
|
|
362
651
|
void ResidualQuantizer::refine_beam(
|
|
@@ -445,4 +734,181 @@ void ResidualQuantizer::refine_beam(
|
|
|
445
734
|
}
|
|
446
735
|
}
|
|
447
736
|
|
|
737
|
+
/*******************************************************************
|
|
738
|
+
* Functions using the dot products between codebook entries
|
|
739
|
+
*******************************************************************/
|
|
740
|
+
|
|
741
|
+
void ResidualQuantizer::compute_codebook_tables() {
|
|
742
|
+
codebook_cross_products.resize(total_codebook_size * total_codebook_size);
|
|
743
|
+
cent_norms.resize(total_codebook_size);
|
|
744
|
+
// strictly speaking we could use ssyrk
|
|
745
|
+
{
|
|
746
|
+
FINTEGER ni = total_codebook_size;
|
|
747
|
+
FINTEGER di = d;
|
|
748
|
+
float zero = 0, one = 1;
|
|
749
|
+
sgemm_("Transposed",
|
|
750
|
+
"Not transposed",
|
|
751
|
+
&ni,
|
|
752
|
+
&ni,
|
|
753
|
+
&di,
|
|
754
|
+
&one,
|
|
755
|
+
codebooks.data(),
|
|
756
|
+
&di,
|
|
757
|
+
codebooks.data(),
|
|
758
|
+
&di,
|
|
759
|
+
&zero,
|
|
760
|
+
codebook_cross_products.data(),
|
|
761
|
+
&ni);
|
|
762
|
+
}
|
|
763
|
+
for (size_t i = 0; i < total_codebook_size; i++) {
|
|
764
|
+
cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
|
|
765
|
+
}
|
|
766
|
+
}
|
|
767
|
+
|
|
768
|
+
void beam_search_encode_step_tab(
|
|
769
|
+
size_t K,
|
|
770
|
+
size_t n,
|
|
771
|
+
size_t beam_size, // input sizes
|
|
772
|
+
const float* codebook_cross_norms, // size K * ldc
|
|
773
|
+
size_t ldc, // >= K
|
|
774
|
+
const uint64_t* codebook_offsets, // m
|
|
775
|
+
const float* query_cp, // size n * ldqc
|
|
776
|
+
size_t ldqc, // >= K
|
|
777
|
+
const float* cent_norms_i, // size K
|
|
778
|
+
size_t m,
|
|
779
|
+
const int32_t* codes, // n * beam_size * m
|
|
780
|
+
const float* distances, // n * beam_size
|
|
781
|
+
size_t new_beam_size,
|
|
782
|
+
int32_t* new_codes, // n * new_beam_size * (m + 1)
|
|
783
|
+
float* new_distances) // n * new_beam_size
|
|
784
|
+
{
|
|
785
|
+
FAISS_THROW_IF_NOT(ldc >= K);
|
|
786
|
+
|
|
787
|
+
#pragma omp parallel for if (n > 100)
|
|
788
|
+
for (int64_t i = 0; i < n; i++) {
|
|
789
|
+
std::vector<float> cent_distances(beam_size * K);
|
|
790
|
+
std::vector<float> cd_common(K);
|
|
791
|
+
|
|
792
|
+
const int32_t* codes_i = codes + i * m * beam_size;
|
|
793
|
+
const float* query_cp_i = query_cp + i * ldqc;
|
|
794
|
+
const float* distances_i = distances + i * beam_size;
|
|
795
|
+
|
|
796
|
+
for (size_t k = 0; k < K; k++) {
|
|
797
|
+
cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
|
|
798
|
+
}
|
|
799
|
+
|
|
800
|
+
for (size_t b = 0; b < beam_size; b++) {
|
|
801
|
+
std::vector<float> dp(K);
|
|
802
|
+
|
|
803
|
+
for (size_t m1 = 0; m1 < m; m1++) {
|
|
804
|
+
size_t c = codes_i[b * m + m1];
|
|
805
|
+
const float* cb =
|
|
806
|
+
&codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
|
|
807
|
+
fvec_add(K, cb, dp.data(), dp.data());
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
for (size_t k = 0; k < K; k++) {
|
|
811
|
+
cent_distances[b * K + k] =
|
|
812
|
+
distances_i[b] + cd_common[k] + 2 * dp[k];
|
|
813
|
+
}
|
|
814
|
+
}
|
|
815
|
+
|
|
816
|
+
using C = CMax<float, int>;
|
|
817
|
+
int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
|
|
818
|
+
float* new_distances_i = new_distances + i * new_beam_size;
|
|
819
|
+
|
|
820
|
+
const float* cent_distances_i = cent_distances.data();
|
|
821
|
+
|
|
822
|
+
// then we have to select the best results
|
|
823
|
+
for (int i = 0; i < new_beam_size; i++) {
|
|
824
|
+
new_distances_i[i] = C::neutral();
|
|
825
|
+
}
|
|
826
|
+
std::vector<int> perm(new_beam_size, -1);
|
|
827
|
+
heap_addn<C>(
|
|
828
|
+
new_beam_size,
|
|
829
|
+
new_distances_i,
|
|
830
|
+
perm.data(),
|
|
831
|
+
cent_distances_i,
|
|
832
|
+
nullptr,
|
|
833
|
+
beam_size * K);
|
|
834
|
+
heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
|
|
835
|
+
|
|
836
|
+
for (int j = 0; j < new_beam_size; j++) {
|
|
837
|
+
int js = perm[j] / K;
|
|
838
|
+
int ls = perm[j] % K;
|
|
839
|
+
if (m > 0) {
|
|
840
|
+
memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
|
|
841
|
+
}
|
|
842
|
+
new_codes_i[m] = ls;
|
|
843
|
+
new_codes_i += m + 1;
|
|
844
|
+
}
|
|
845
|
+
}
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
void ResidualQuantizer::refine_beam_LUT(
|
|
849
|
+
size_t n,
|
|
850
|
+
const float* query_norms, // size n
|
|
851
|
+
const float* query_cp, //
|
|
852
|
+
int out_beam_size,
|
|
853
|
+
int32_t* out_codes,
|
|
854
|
+
float* out_distances) const {
|
|
855
|
+
int beam_size = 1;
|
|
856
|
+
|
|
857
|
+
std::vector<int32_t> codes;
|
|
858
|
+
std::vector<float> distances(query_norms, query_norms + n);
|
|
859
|
+
double t0 = getmillisecs();
|
|
860
|
+
|
|
861
|
+
for (int m = 0; m < M; m++) {
|
|
862
|
+
int K = 1 << nbits[m];
|
|
863
|
+
|
|
864
|
+
int new_beam_size = std::min(beam_size * K, out_beam_size);
|
|
865
|
+
std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
|
|
866
|
+
std::vector<float> new_distances(n * new_beam_size);
|
|
867
|
+
|
|
868
|
+
beam_search_encode_step_tab(
|
|
869
|
+
K,
|
|
870
|
+
n,
|
|
871
|
+
beam_size,
|
|
872
|
+
codebook_cross_products.data() + codebook_offsets[m],
|
|
873
|
+
total_codebook_size,
|
|
874
|
+
codebook_offsets.data(),
|
|
875
|
+
query_cp + codebook_offsets[m],
|
|
876
|
+
total_codebook_size,
|
|
877
|
+
cent_norms.data() + codebook_offsets[m],
|
|
878
|
+
m,
|
|
879
|
+
codes.data(),
|
|
880
|
+
distances.data(),
|
|
881
|
+
new_beam_size,
|
|
882
|
+
new_codes.data(),
|
|
883
|
+
new_distances.data());
|
|
884
|
+
|
|
885
|
+
codes.swap(new_codes);
|
|
886
|
+
distances.swap(new_distances);
|
|
887
|
+
beam_size = new_beam_size;
|
|
888
|
+
|
|
889
|
+
if (verbose) {
|
|
890
|
+
float sum_distances = 0;
|
|
891
|
+
for (int j = 0; j < distances.size(); j++) {
|
|
892
|
+
sum_distances += distances[j];
|
|
893
|
+
}
|
|
894
|
+
printf("[%.3f s] encode stage %d, %d bits, "
|
|
895
|
+
"total error %g, beam_size %d\n",
|
|
896
|
+
(getmillisecs() - t0) / 1000,
|
|
897
|
+
m,
|
|
898
|
+
int(nbits[m]),
|
|
899
|
+
sum_distances,
|
|
900
|
+
beam_size);
|
|
901
|
+
}
|
|
902
|
+
}
|
|
903
|
+
|
|
904
|
+
if (out_codes) {
|
|
905
|
+
memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
|
|
906
|
+
}
|
|
907
|
+
if (out_distances) {
|
|
908
|
+
memcpy(out_distances,
|
|
909
|
+
distances.data(),
|
|
910
|
+
distances.size() * sizeof(distances[0]));
|
|
911
|
+
}
|
|
912
|
+
}
|
|
913
|
+
|
|
448
914
|
} // namespace faiss
|