faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
|
@@ -5,9 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
|
-
#include <faiss/impl/FaissAssert.h>
|
|
11
8
|
#include <faiss/impl/LocalSearchQuantizer.h>
|
|
12
9
|
|
|
13
10
|
#include <cstddef>
|
|
@@ -18,6 +15,8 @@
|
|
|
18
15
|
|
|
19
16
|
#include <algorithm>
|
|
20
17
|
|
|
18
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
19
|
+
#include <faiss/impl/FaissAssert.h>
|
|
21
20
|
#include <faiss/utils/distances.h>
|
|
22
21
|
#include <faiss/utils/hamming.h> // BitstringWriter
|
|
23
22
|
#include <faiss/utils/utils.h>
|
|
@@ -42,18 +41,6 @@ void sgetri_(
|
|
|
42
41
|
FINTEGER* lwork,
|
|
43
42
|
FINTEGER* info);
|
|
44
43
|
|
|
45
|
-
// solves a system of linear equations
|
|
46
|
-
void sgetrs_(
|
|
47
|
-
const char* trans,
|
|
48
|
-
FINTEGER* n,
|
|
49
|
-
FINTEGER* nrhs,
|
|
50
|
-
float* A,
|
|
51
|
-
FINTEGER* lda,
|
|
52
|
-
FINTEGER* ipiv,
|
|
53
|
-
float* b,
|
|
54
|
-
FINTEGER* ldb,
|
|
55
|
-
FINTEGER* info);
|
|
56
|
-
|
|
57
44
|
// general matrix multiplication
|
|
58
45
|
int sgemm_(
|
|
59
46
|
const char* transa,
|
|
@@ -69,26 +56,73 @@ int sgemm_(
|
|
|
69
56
|
float* beta,
|
|
70
57
|
float* c,
|
|
71
58
|
FINTEGER* ldc);
|
|
59
|
+
|
|
60
|
+
// LU decomoposition of a general matrix
|
|
61
|
+
void dgetrf_(
|
|
62
|
+
FINTEGER* m,
|
|
63
|
+
FINTEGER* n,
|
|
64
|
+
double* a,
|
|
65
|
+
FINTEGER* lda,
|
|
66
|
+
FINTEGER* ipiv,
|
|
67
|
+
FINTEGER* info);
|
|
68
|
+
|
|
69
|
+
// generate inverse of a matrix given its LU decomposition
|
|
70
|
+
void dgetri_(
|
|
71
|
+
FINTEGER* n,
|
|
72
|
+
double* a,
|
|
73
|
+
FINTEGER* lda,
|
|
74
|
+
FINTEGER* ipiv,
|
|
75
|
+
double* work,
|
|
76
|
+
FINTEGER* lwork,
|
|
77
|
+
FINTEGER* info);
|
|
78
|
+
|
|
79
|
+
// general matrix multiplication
|
|
80
|
+
int dgemm_(
|
|
81
|
+
const char* transa,
|
|
82
|
+
const char* transb,
|
|
83
|
+
FINTEGER* m,
|
|
84
|
+
FINTEGER* n,
|
|
85
|
+
FINTEGER* k,
|
|
86
|
+
const double* alpha,
|
|
87
|
+
const double* a,
|
|
88
|
+
FINTEGER* lda,
|
|
89
|
+
const double* b,
|
|
90
|
+
FINTEGER* ldb,
|
|
91
|
+
double* beta,
|
|
92
|
+
double* c,
|
|
93
|
+
FINTEGER* ldc);
|
|
72
94
|
}
|
|
73
95
|
|
|
74
96
|
namespace {
|
|
75
97
|
|
|
98
|
+
void fmat_inverse(float* a, int n) {
|
|
99
|
+
int info;
|
|
100
|
+
int lwork = n * n;
|
|
101
|
+
std::vector<int> ipiv(n);
|
|
102
|
+
std::vector<float> workspace(lwork);
|
|
103
|
+
|
|
104
|
+
sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
|
|
105
|
+
FAISS_THROW_IF_NOT(info == 0);
|
|
106
|
+
sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
|
|
107
|
+
FAISS_THROW_IF_NOT(info == 0);
|
|
108
|
+
}
|
|
109
|
+
|
|
76
110
|
// c and a and b can overlap
|
|
77
|
-
void
|
|
111
|
+
void dfvec_add(size_t d, const double* a, const float* b, double* c) {
|
|
78
112
|
for (size_t i = 0; i < d; i++) {
|
|
79
113
|
c[i] = a[i] + b[i];
|
|
80
114
|
}
|
|
81
115
|
}
|
|
82
116
|
|
|
83
|
-
void
|
|
117
|
+
void dmat_inverse(double* a, int n) {
|
|
84
118
|
int info;
|
|
85
119
|
int lwork = n * n;
|
|
86
120
|
std::vector<int> ipiv(n);
|
|
87
|
-
std::vector<
|
|
121
|
+
std::vector<double> workspace(lwork);
|
|
88
122
|
|
|
89
|
-
|
|
123
|
+
dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
|
|
90
124
|
FAISS_THROW_IF_NOT(info == 0);
|
|
91
|
-
|
|
125
|
+
dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
|
|
92
126
|
FAISS_THROW_IF_NOT(info == 0);
|
|
93
127
|
}
|
|
94
128
|
|
|
@@ -107,21 +141,15 @@ void random_int32(
|
|
|
107
141
|
|
|
108
142
|
namespace faiss {
|
|
109
143
|
|
|
110
|
-
LSQTimer lsq_timer;
|
|
111
|
-
|
|
112
|
-
LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
|
|
113
|
-
FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
|
|
114
|
-
|
|
115
|
-
this->d = d;
|
|
116
|
-
this->M = M;
|
|
117
|
-
this->nbits = std::vector<size_t>(M, nbits);
|
|
118
|
-
|
|
119
|
-
// set derived values
|
|
120
|
-
set_derived_values();
|
|
121
|
-
|
|
122
|
-
is_trained = false;
|
|
123
|
-
verbose = false;
|
|
144
|
+
lsq::LSQTimer lsq_timer;
|
|
145
|
+
using lsq::LSQTimerScope;
|
|
124
146
|
|
|
147
|
+
LocalSearchQuantizer::LocalSearchQuantizer(
|
|
148
|
+
size_t d,
|
|
149
|
+
size_t M,
|
|
150
|
+
size_t nbits,
|
|
151
|
+
Search_type_t search_type)
|
|
152
|
+
: AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
|
|
125
153
|
K = (1 << nbits);
|
|
126
154
|
|
|
127
155
|
train_iters = 25;
|
|
@@ -138,15 +166,23 @@ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
|
|
|
138
166
|
|
|
139
167
|
random_seed = 0x12345;
|
|
140
168
|
std::srand(random_seed);
|
|
169
|
+
|
|
170
|
+
icm_encoder_factory = nullptr;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
LocalSearchQuantizer::~LocalSearchQuantizer() {
|
|
174
|
+
delete icm_encoder_factory;
|
|
141
175
|
}
|
|
142
176
|
|
|
177
|
+
LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
|
|
178
|
+
|
|
143
179
|
void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|
144
180
|
FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
|
|
145
|
-
|
|
181
|
+
nperts = std::min(nperts, M);
|
|
146
182
|
|
|
147
183
|
lsq_timer.reset();
|
|
184
|
+
LSQTimerScope scope(&lsq_timer, "train");
|
|
148
185
|
if (verbose) {
|
|
149
|
-
lsq_timer.start("train");
|
|
150
186
|
printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
|
|
151
187
|
M,
|
|
152
188
|
n,
|
|
@@ -209,7 +245,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|
|
209
245
|
}
|
|
210
246
|
|
|
211
247
|
// refine codes
|
|
212
|
-
icm_encode(
|
|
248
|
+
icm_encode(codes.data(), x, n, train_ils_iters, gen);
|
|
213
249
|
|
|
214
250
|
if (verbose) {
|
|
215
251
|
float obj = evaluate(codes.data(), x, n);
|
|
@@ -217,25 +253,33 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|
|
217
253
|
}
|
|
218
254
|
}
|
|
219
255
|
|
|
256
|
+
is_trained = true;
|
|
257
|
+
{
|
|
258
|
+
std::vector<float> x_recons(n * d);
|
|
259
|
+
std::vector<float> norms(n);
|
|
260
|
+
decode_unpacked(codes.data(), x_recons.data(), n);
|
|
261
|
+
fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
|
|
262
|
+
|
|
263
|
+
train_norm(n, norms.data());
|
|
264
|
+
}
|
|
265
|
+
|
|
220
266
|
if (verbose) {
|
|
221
|
-
lsq_timer.end("train");
|
|
222
267
|
float obj = evaluate(codes.data(), x, n);
|
|
268
|
+
scope.finish();
|
|
223
269
|
printf("After training: obj = %lf\n", obj);
|
|
224
270
|
|
|
225
271
|
printf("Time statistic:\n");
|
|
226
|
-
for (const auto& it : lsq_timer.
|
|
227
|
-
printf("\t%s time: %lf s\n", it.first.data(), it.second);
|
|
272
|
+
for (const auto& it : lsq_timer.t) {
|
|
273
|
+
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
|
228
274
|
}
|
|
229
275
|
}
|
|
230
|
-
|
|
231
|
-
is_trained = true;
|
|
232
276
|
}
|
|
233
277
|
|
|
234
278
|
void LocalSearchQuantizer::perturb_codebooks(
|
|
235
279
|
float T,
|
|
236
280
|
const std::vector<float>& stddev,
|
|
237
281
|
std::mt19937& gen) {
|
|
238
|
-
lsq_timer
|
|
282
|
+
LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
|
|
239
283
|
|
|
240
284
|
std::vector<std::normal_distribution<float>> distribs;
|
|
241
285
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -249,32 +293,34 @@ void LocalSearchQuantizer::perturb_codebooks(
|
|
|
249
293
|
}
|
|
250
294
|
}
|
|
251
295
|
}
|
|
252
|
-
|
|
253
|
-
lsq_timer.end("perturb_codebooks");
|
|
254
296
|
}
|
|
255
297
|
|
|
256
|
-
void LocalSearchQuantizer::
|
|
298
|
+
void LocalSearchQuantizer::compute_codes_add_centroids(
|
|
257
299
|
const float* x,
|
|
258
300
|
uint8_t* codes_out,
|
|
259
|
-
size_t n
|
|
301
|
+
size_t n,
|
|
302
|
+
const float* centroids) const {
|
|
260
303
|
FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
|
|
304
|
+
|
|
305
|
+
lsq_timer.reset();
|
|
306
|
+
LSQTimerScope scope(&lsq_timer, "encode");
|
|
261
307
|
if (verbose) {
|
|
262
|
-
lsq_timer.reset();
|
|
263
308
|
printf("Encoding %zd vectors...\n", n);
|
|
264
|
-
lsq_timer.start("encode");
|
|
265
309
|
}
|
|
266
310
|
|
|
267
311
|
std::vector<int32_t> codes(n * M);
|
|
268
312
|
std::mt19937 gen(random_seed);
|
|
269
313
|
random_int32(codes, 0, K - 1, gen);
|
|
270
314
|
|
|
271
|
-
icm_encode(
|
|
272
|
-
pack_codes(n, codes.data(), codes_out);
|
|
315
|
+
icm_encode(codes.data(), x, n, encode_ils_iters, gen);
|
|
316
|
+
pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids);
|
|
273
317
|
|
|
274
318
|
if (verbose) {
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
319
|
+
scope.finish();
|
|
320
|
+
printf("Time statistic:\n");
|
|
321
|
+
for (const auto& it : lsq_timer.t) {
|
|
322
|
+
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
|
323
|
+
}
|
|
278
324
|
}
|
|
279
325
|
}
|
|
280
326
|
|
|
@@ -298,73 +344,144 @@ void LocalSearchQuantizer::update_codebooks(
|
|
|
298
344
|
const float* x,
|
|
299
345
|
const int32_t* codes,
|
|
300
346
|
size_t n) {
|
|
301
|
-
lsq_timer
|
|
347
|
+
LSQTimerScope scope(&lsq_timer, "update_codebooks");
|
|
348
|
+
|
|
349
|
+
if (!update_codebooks_with_double) {
|
|
350
|
+
// allocate memory
|
|
351
|
+
// bb = B'B, bx = BX
|
|
352
|
+
std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
|
|
353
|
+
std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
|
|
354
|
+
|
|
355
|
+
// compute B'B
|
|
356
|
+
for (size_t i = 0; i < n; i++) {
|
|
357
|
+
for (size_t m = 0; m < M; m++) {
|
|
358
|
+
int32_t code1 = codes[i * M + m];
|
|
359
|
+
int32_t idx1 = m * K + code1;
|
|
360
|
+
bb[idx1 * M * K + idx1] += 1;
|
|
361
|
+
|
|
362
|
+
for (size_t m2 = m + 1; m2 < M; m2++) {
|
|
363
|
+
int32_t code2 = codes[i * M + m2];
|
|
364
|
+
int32_t idx2 = m2 * K + code2;
|
|
365
|
+
bb[idx1 * M * K + idx2] += 1;
|
|
366
|
+
bb[idx2 * M * K + idx1] += 1;
|
|
367
|
+
}
|
|
368
|
+
}
|
|
369
|
+
}
|
|
302
370
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
371
|
+
// add a regularization term to B'B
|
|
372
|
+
for (int64_t i = 0; i < M * K; i++) {
|
|
373
|
+
bb[i * (M * K) + i] += lambd;
|
|
374
|
+
}
|
|
307
375
|
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
int32_t idx2 = m2 * K + code2;
|
|
318
|
-
bb[idx1 * M * K + idx2] += 1;
|
|
319
|
-
bb[idx2 * M * K + idx1] += 1;
|
|
376
|
+
// compute (B'B)^(-1)
|
|
377
|
+
fmat_inverse(bb.data(), M * K); // [M*K, M*K]
|
|
378
|
+
|
|
379
|
+
// compute BX
|
|
380
|
+
for (size_t i = 0; i < n; i++) {
|
|
381
|
+
for (size_t m = 0; m < M; m++) {
|
|
382
|
+
int32_t code = codes[i * M + m];
|
|
383
|
+
float* data = bx.data() + (m * K + code) * d;
|
|
384
|
+
fvec_add(d, data, x + i * d, data);
|
|
320
385
|
}
|
|
321
386
|
}
|
|
322
|
-
}
|
|
323
387
|
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
388
|
+
// compute C = (B'B)^(-1) @ BX
|
|
389
|
+
//
|
|
390
|
+
// NOTE: LAPACK use column major order
|
|
391
|
+
// out = alpha * op(A) * op(B) + beta * C
|
|
392
|
+
FINTEGER nrows_A = d;
|
|
393
|
+
FINTEGER ncols_A = M * K;
|
|
394
|
+
|
|
395
|
+
FINTEGER nrows_B = M * K;
|
|
396
|
+
FINTEGER ncols_B = M * K;
|
|
397
|
+
|
|
398
|
+
float alpha = 1.0f;
|
|
399
|
+
float beta = 0.0f;
|
|
400
|
+
sgemm_("Not Transposed",
|
|
401
|
+
"Not Transposed",
|
|
402
|
+
&nrows_A, // nrows of op(A)
|
|
403
|
+
&ncols_B, // ncols of op(B)
|
|
404
|
+
&ncols_A, // ncols of op(A)
|
|
405
|
+
&alpha,
|
|
406
|
+
bx.data(),
|
|
407
|
+
&nrows_A, // nrows of A
|
|
408
|
+
bb.data(),
|
|
409
|
+
&nrows_B, // nrows of B
|
|
410
|
+
&beta,
|
|
411
|
+
codebooks.data(),
|
|
412
|
+
&nrows_A); // nrows of output
|
|
413
|
+
|
|
414
|
+
} else {
|
|
415
|
+
// allocate memory
|
|
416
|
+
// bb = B'B, bx = BX
|
|
417
|
+
std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
|
|
418
|
+
std::vector<double> bx(M * K * d, 0.0f); // [M * K, d]
|
|
419
|
+
|
|
420
|
+
// compute B'B
|
|
421
|
+
for (size_t i = 0; i < n; i++) {
|
|
422
|
+
for (size_t m = 0; m < M; m++) {
|
|
423
|
+
int32_t code1 = codes[i * M + m];
|
|
424
|
+
int32_t idx1 = m * K + code1;
|
|
425
|
+
bb[idx1 * M * K + idx1] += 1;
|
|
426
|
+
|
|
427
|
+
for (size_t m2 = m + 1; m2 < M; m2++) {
|
|
428
|
+
int32_t code2 = codes[i * M + m2];
|
|
429
|
+
int32_t idx2 = m2 * K + code2;
|
|
430
|
+
bb[idx1 * M * K + idx2] += 1;
|
|
431
|
+
bb[idx2 * M * K + idx1] += 1;
|
|
432
|
+
}
|
|
433
|
+
}
|
|
434
|
+
}
|
|
328
435
|
|
|
329
|
-
|
|
330
|
-
|
|
436
|
+
// add a regularization term to B'B
|
|
437
|
+
for (int64_t i = 0; i < M * K; i++) {
|
|
438
|
+
bb[i * (M * K) + i] += lambd;
|
|
439
|
+
}
|
|
331
440
|
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
441
|
+
// compute (B'B)^(-1)
|
|
442
|
+
dmat_inverse(bb.data(), M * K); // [M*K, M*K]
|
|
443
|
+
|
|
444
|
+
// compute BX
|
|
445
|
+
for (size_t i = 0; i < n; i++) {
|
|
446
|
+
for (size_t m = 0; m < M; m++) {
|
|
447
|
+
int32_t code = codes[i * M + m];
|
|
448
|
+
double* data = bx.data() + (m * K + code) * d;
|
|
449
|
+
dfvec_add(d, data, x + i * d, data);
|
|
450
|
+
}
|
|
338
451
|
}
|
|
339
|
-
}
|
|
340
452
|
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
453
|
+
// compute C = (B'B)^(-1) @ BX
|
|
454
|
+
//
|
|
455
|
+
// NOTE: LAPACK use column major order
|
|
456
|
+
// out = alpha * op(A) * op(B) + beta * C
|
|
457
|
+
FINTEGER nrows_A = d;
|
|
458
|
+
FINTEGER ncols_A = M * K;
|
|
459
|
+
|
|
460
|
+
FINTEGER nrows_B = M * K;
|
|
461
|
+
FINTEGER ncols_B = M * K;
|
|
462
|
+
|
|
463
|
+
std::vector<double> d_codebooks(M * K * d);
|
|
464
|
+
|
|
465
|
+
double alpha = 1.0f;
|
|
466
|
+
double beta = 0.0f;
|
|
467
|
+
dgemm_("Not Transposed",
|
|
468
|
+
"Not Transposed",
|
|
469
|
+
&nrows_A, // nrows of op(A)
|
|
470
|
+
&ncols_B, // ncols of op(B)
|
|
471
|
+
&ncols_A, // ncols of op(A)
|
|
472
|
+
&alpha,
|
|
473
|
+
bx.data(),
|
|
474
|
+
&nrows_A, // nrows of A
|
|
475
|
+
bb.data(),
|
|
476
|
+
&nrows_B, // nrows of B
|
|
477
|
+
&beta,
|
|
478
|
+
d_codebooks.data(),
|
|
479
|
+
&nrows_A); // nrows of output
|
|
480
|
+
|
|
481
|
+
for (size_t i = 0; i < M * K * d; i++) {
|
|
482
|
+
codebooks[i] = (float)d_codebooks[i];
|
|
483
|
+
}
|
|
484
|
+
}
|
|
368
485
|
}
|
|
369
486
|
|
|
370
487
|
/** encode using iterative conditional mode
|
|
@@ -386,15 +503,23 @@ void LocalSearchQuantizer::update_codebooks(
|
|
|
386
503
|
* These two terms can be precomputed and store in a look up table.
|
|
387
504
|
*/
|
|
388
505
|
void LocalSearchQuantizer::icm_encode(
|
|
389
|
-
const float* x,
|
|
390
506
|
int32_t* codes,
|
|
507
|
+
const float* x,
|
|
391
508
|
size_t n,
|
|
392
509
|
size_t ils_iters,
|
|
393
510
|
std::mt19937& gen) const {
|
|
394
|
-
lsq_timer
|
|
511
|
+
LSQTimerScope scope(&lsq_timer, "icm_encode");
|
|
512
|
+
|
|
513
|
+
auto factory = icm_encoder_factory;
|
|
514
|
+
std::unique_ptr<lsq::IcmEncoder> icm_encoder;
|
|
515
|
+
if (factory == nullptr) {
|
|
516
|
+
icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
|
|
517
|
+
} else {
|
|
518
|
+
icm_encoder.reset(factory->get(this));
|
|
519
|
+
}
|
|
395
520
|
|
|
396
|
-
|
|
397
|
-
|
|
521
|
+
// precompute binary terms for all chunks
|
|
522
|
+
icm_encoder->set_binary_term();
|
|
398
523
|
|
|
399
524
|
const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
|
|
400
525
|
for (size_t i = 0; i < n_chunks; i++) {
|
|
@@ -410,21 +535,20 @@ void LocalSearchQuantizer::icm_encode(
|
|
|
410
535
|
|
|
411
536
|
const float* xi = x + i * chunk_size * d;
|
|
412
537
|
int32_t* codesi = codes + i * chunk_size * M;
|
|
413
|
-
|
|
538
|
+
icm_encoder->verbose = (verbose && i == 0);
|
|
539
|
+
icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
|
|
414
540
|
}
|
|
415
|
-
|
|
416
|
-
lsq_timer.end("icm_encode");
|
|
417
541
|
}
|
|
418
542
|
|
|
419
|
-
void LocalSearchQuantizer::
|
|
420
|
-
size_t index,
|
|
421
|
-
const float* x,
|
|
543
|
+
void LocalSearchQuantizer::icm_encode_impl(
|
|
422
544
|
int32_t* codes,
|
|
423
|
-
|
|
545
|
+
const float* x,
|
|
424
546
|
const float* binaries,
|
|
547
|
+
std::mt19937& gen,
|
|
548
|
+
size_t n,
|
|
425
549
|
size_t ils_iters,
|
|
426
|
-
|
|
427
|
-
std::vector<float> unaries(n * M * K); // [
|
|
550
|
+
bool verbose) const {
|
|
551
|
+
std::vector<float> unaries(n * M * K); // [M, n, K]
|
|
428
552
|
compute_unary_terms(x, unaries.data(), n);
|
|
429
553
|
|
|
430
554
|
std::vector<int32_t> best_codes;
|
|
@@ -438,9 +562,7 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|
|
438
562
|
// add perturbation to codes
|
|
439
563
|
perturb_codes(codes, n, gen);
|
|
440
564
|
|
|
441
|
-
|
|
442
|
-
icm_encode_step(unaries.data(), binaries, codes, n);
|
|
443
|
-
}
|
|
565
|
+
icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
|
|
444
566
|
|
|
445
567
|
std::vector<float> icm_objs(n, 0.0f);
|
|
446
568
|
evaluate(codes, x, n, icm_objs.data());
|
|
@@ -463,7 +585,7 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|
|
463
585
|
|
|
464
586
|
memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
|
|
465
587
|
|
|
466
|
-
if (verbose
|
|
588
|
+
if (verbose) {
|
|
467
589
|
printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
|
|
468
590
|
iter1,
|
|
469
591
|
mean_obj,
|
|
@@ -474,61 +596,67 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|
|
474
596
|
}
|
|
475
597
|
|
|
476
598
|
void LocalSearchQuantizer::icm_encode_step(
|
|
599
|
+
int32_t* codes,
|
|
477
600
|
const float* unaries,
|
|
478
601
|
const float* binaries,
|
|
479
|
-
|
|
480
|
-
size_t
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
std::vector<float> objs(n * K);
|
|
484
|
-
#pragma omp parallel for
|
|
485
|
-
for (int64_t i = 0; i < n; i++) {
|
|
486
|
-
auto u = unaries + i * (M * K) + m * K;
|
|
487
|
-
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
|
488
|
-
}
|
|
602
|
+
size_t n,
|
|
603
|
+
size_t n_iters) const {
|
|
604
|
+
FAISS_THROW_IF_NOT(M != 0 && K != 0);
|
|
605
|
+
FAISS_THROW_IF_NOT(binaries != nullptr);
|
|
489
606
|
|
|
490
|
-
|
|
491
|
-
//
|
|
492
|
-
for (size_t
|
|
493
|
-
|
|
494
|
-
|
|
607
|
+
for (size_t iter = 0; iter < n_iters; iter++) {
|
|
608
|
+
// condition on the m-th subcode
|
|
609
|
+
for (size_t m = 0; m < M; m++) {
|
|
610
|
+
std::vector<float> objs(n * K);
|
|
611
|
+
#pragma omp parallel for
|
|
612
|
+
for (int64_t i = 0; i < n; i++) {
|
|
613
|
+
auto u = unaries + m * n * K + i * K;
|
|
614
|
+
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
|
495
615
|
}
|
|
496
616
|
|
|
617
|
+
// compute objective function by adding unary
|
|
618
|
+
// and binary terms together
|
|
619
|
+
for (size_t other_m = 0; other_m < M; other_m++) {
|
|
620
|
+
if (other_m == m) {
|
|
621
|
+
continue;
|
|
622
|
+
}
|
|
623
|
+
|
|
497
624
|
#pragma omp parallel for
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
625
|
+
for (int64_t i = 0; i < n; i++) {
|
|
626
|
+
for (int32_t code = 0; code < K; code++) {
|
|
627
|
+
int32_t code2 = codes[i * M + other_m];
|
|
628
|
+
size_t binary_idx = m * M * K * K + other_m * K * K +
|
|
629
|
+
code * K + code2;
|
|
630
|
+
// binaries[m, other_m, code, code2]
|
|
631
|
+
objs[i * K + code] += binaries[binary_idx];
|
|
632
|
+
}
|
|
505
633
|
}
|
|
506
634
|
}
|
|
507
|
-
}
|
|
508
635
|
|
|
509
|
-
|
|
636
|
+
// find the optimal value of the m-th subcode
|
|
510
637
|
#pragma omp parallel for
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
638
|
+
for (int64_t i = 0; i < n; i++) {
|
|
639
|
+
float best_obj = HUGE_VALF;
|
|
640
|
+
int32_t best_code = 0;
|
|
641
|
+
for (size_t code = 0; code < K; code++) {
|
|
642
|
+
float obj = objs[i * K + code];
|
|
643
|
+
if (obj < best_obj) {
|
|
644
|
+
best_obj = obj;
|
|
645
|
+
best_code = code;
|
|
646
|
+
}
|
|
519
647
|
}
|
|
648
|
+
codes[i * M + m] = best_code;
|
|
520
649
|
}
|
|
521
|
-
codes[i * M + m] = best_code;
|
|
522
|
-
}
|
|
523
650
|
|
|
524
|
-
|
|
651
|
+
} // loop M
|
|
652
|
+
}
|
|
525
653
|
}
|
|
526
654
|
|
|
527
655
|
void LocalSearchQuantizer::perturb_codes(
|
|
528
656
|
int32_t* codes,
|
|
529
657
|
size_t n,
|
|
530
658
|
std::mt19937& gen) const {
|
|
531
|
-
lsq_timer
|
|
659
|
+
LSQTimerScope scope(&lsq_timer, "perturb_codes");
|
|
532
660
|
|
|
533
661
|
std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
|
|
534
662
|
std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
|
|
@@ -539,12 +667,10 @@ void LocalSearchQuantizer::perturb_codes(
|
|
|
539
667
|
codes[i * M + m] = k_distrib(gen);
|
|
540
668
|
}
|
|
541
669
|
}
|
|
542
|
-
|
|
543
|
-
lsq_timer.end("perturb_codes");
|
|
544
670
|
}
|
|
545
671
|
|
|
546
672
|
void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
|
547
|
-
lsq_timer
|
|
673
|
+
LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
|
|
548
674
|
|
|
549
675
|
#pragma omp parallel for
|
|
550
676
|
for (int64_t m12 = 0; m12 < M * M; m12++) {
|
|
@@ -562,52 +688,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
|
|
562
688
|
}
|
|
563
689
|
}
|
|
564
690
|
}
|
|
565
|
-
|
|
566
|
-
lsq_timer.end("compute_binary_terms");
|
|
567
691
|
}
|
|
568
692
|
|
|
569
693
|
void LocalSearchQuantizer::compute_unary_terms(
|
|
570
694
|
const float* x,
|
|
571
|
-
float* unaries,
|
|
695
|
+
float* unaries, // [M, n, K]
|
|
572
696
|
size_t n) const {
|
|
573
|
-
lsq_timer
|
|
697
|
+
LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
|
|
574
698
|
|
|
575
|
-
// compute x *
|
|
699
|
+
// compute x * codebook^T for each codebook
|
|
576
700
|
//
|
|
577
701
|
// NOTE: LAPACK use column major order
|
|
578
702
|
// out = alpha * op(A) * op(B) + beta * C
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
703
|
+
|
|
704
|
+
for (size_t m = 0; m < M; m++) {
|
|
705
|
+
FINTEGER nrows_A = K;
|
|
706
|
+
FINTEGER ncols_A = d;
|
|
707
|
+
|
|
708
|
+
FINTEGER nrows_B = d;
|
|
709
|
+
FINTEGER ncols_B = n;
|
|
710
|
+
|
|
711
|
+
float alpha = -2.0f;
|
|
712
|
+
float beta = 0.0f;
|
|
713
|
+
sgemm_("Transposed",
|
|
714
|
+
"Not Transposed",
|
|
715
|
+
&nrows_A, // nrows of op(A)
|
|
716
|
+
&ncols_B, // ncols of op(B)
|
|
717
|
+
&ncols_A, // ncols of op(A)
|
|
718
|
+
&alpha,
|
|
719
|
+
codebooks.data() + m * K * d,
|
|
720
|
+
&ncols_A, // nrows of A
|
|
721
|
+
x,
|
|
722
|
+
&nrows_B, // nrows of B
|
|
723
|
+
&beta,
|
|
724
|
+
unaries + m * n * K,
|
|
725
|
+
&nrows_A); // nrows of output
|
|
726
|
+
}
|
|
600
727
|
|
|
601
728
|
std::vector<float> norms(M * K);
|
|
602
729
|
fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
|
|
603
730
|
|
|
604
731
|
#pragma omp parallel for
|
|
605
732
|
for (int64_t i = 0; i < n; i++) {
|
|
606
|
-
|
|
607
|
-
|
|
733
|
+
for (size_t m = 0; m < M; m++) {
|
|
734
|
+
float* u = unaries + m * n * K + i * K;
|
|
735
|
+
fvec_add(K, u, norms.data() + m * K, u);
|
|
736
|
+
}
|
|
608
737
|
}
|
|
609
|
-
|
|
610
|
-
lsq_timer.end("compute_unary_terms");
|
|
611
738
|
}
|
|
612
739
|
|
|
613
740
|
float LocalSearchQuantizer::evaluate(
|
|
@@ -615,7 +742,7 @@ float LocalSearchQuantizer::evaluate(
|
|
|
615
742
|
const float* x,
|
|
616
743
|
size_t n,
|
|
617
744
|
float* objs) const {
|
|
618
|
-
lsq_timer
|
|
745
|
+
LSQTimerScope scope(&lsq_timer, "evaluate");
|
|
619
746
|
|
|
620
747
|
// decode
|
|
621
748
|
std::vector<float> decoded_x(n * d, 0.0f);
|
|
@@ -631,7 +758,7 @@ float LocalSearchQuantizer::evaluate(
|
|
|
631
758
|
fvec_add(d, decoded_i, c, decoded_i);
|
|
632
759
|
}
|
|
633
760
|
|
|
634
|
-
float err = fvec_L2sqr(x + i * d, decoded_i, d);
|
|
761
|
+
float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
|
|
635
762
|
obj += err;
|
|
636
763
|
|
|
637
764
|
if (objs) {
|
|
@@ -639,34 +766,68 @@ float LocalSearchQuantizer::evaluate(
|
|
|
639
766
|
}
|
|
640
767
|
}
|
|
641
768
|
|
|
642
|
-
lsq_timer.end("evaluate");
|
|
643
|
-
|
|
644
769
|
obj = obj / n;
|
|
645
770
|
return obj;
|
|
646
771
|
}
|
|
647
772
|
|
|
648
|
-
|
|
649
|
-
|
|
773
|
+
namespace lsq {
|
|
774
|
+
|
|
775
|
+
IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
|
|
776
|
+
: verbose(false), lsq(lsq) {}
|
|
777
|
+
|
|
778
|
+
void IcmEncoder::set_binary_term() {
|
|
779
|
+
auto M = lsq->M;
|
|
780
|
+
auto K = lsq->K;
|
|
781
|
+
binaries.resize(M * M * K * K);
|
|
782
|
+
lsq->compute_binary_terms(binaries.data());
|
|
650
783
|
}
|
|
651
784
|
|
|
652
|
-
void
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
785
|
+
void IcmEncoder::encode(
|
|
786
|
+
int32_t* codes,
|
|
787
|
+
const float* x,
|
|
788
|
+
std::mt19937& gen,
|
|
789
|
+
size_t n,
|
|
790
|
+
size_t ils_iters) const {
|
|
791
|
+
lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
|
|
656
792
|
}
|
|
657
793
|
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
794
|
+
double LSQTimer::get(const std::string& name) {
|
|
795
|
+
if (t.count(name) == 0) {
|
|
796
|
+
return 0.0;
|
|
797
|
+
} else {
|
|
798
|
+
return t[name];
|
|
799
|
+
}
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
void LSQTimer::add(const std::string& name, double delta) {
|
|
803
|
+
if (t.count(name) == 0) {
|
|
804
|
+
t[name] = delta;
|
|
805
|
+
} else {
|
|
806
|
+
t[name] += delta;
|
|
807
|
+
}
|
|
664
808
|
}
|
|
665
809
|
|
|
666
810
|
void LSQTimer::reset() {
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
811
|
+
t.clear();
|
|
812
|
+
}
|
|
813
|
+
|
|
814
|
+
LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
|
|
815
|
+
: timer(timer), name(name), finished(false) {
|
|
816
|
+
t0 = getmillisecs();
|
|
670
817
|
}
|
|
671
818
|
|
|
819
|
+
void LSQTimerScope::finish() {
|
|
820
|
+
if (!finished) {
|
|
821
|
+
auto delta = getmillisecs() - t0;
|
|
822
|
+
timer->add(name, delta);
|
|
823
|
+
finished = true;
|
|
824
|
+
}
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
LSQTimerScope::~LSQTimerScope() {
|
|
828
|
+
finish();
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
} // namespace lsq
|
|
832
|
+
|
|
672
833
|
} // namespace faiss
|