faiss 0.2.3 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
- data/vendor/faiss/faiss/Index2Layer.h +2 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
- data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
- data/vendor/faiss/faiss/IndexFlat.h +9 -15
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
- data/vendor/faiss/faiss/IndexIVF.h +25 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
- data/vendor/faiss/faiss/IndexLSH.h +2 -15
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
- data/vendor/faiss/faiss/IndexPQ.h +2 -17
- data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
- data/vendor/faiss/faiss/IndexRefine.h +10 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
- data/vendor/faiss/faiss/VectorTransform.h +3 -0
- data/vendor/faiss/faiss/clone_index.cpp +3 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
- data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
- data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/index_factory.cpp +585 -414
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/utils/distances.cpp +4 -2
- data/vendor/faiss/faiss/utils/distances.h +36 -3
- data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +12 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,9 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
|
-
#include <faiss/impl/FaissAssert.h>
|
11
8
|
#include <faiss/impl/LocalSearchQuantizer.h>
|
12
9
|
|
13
10
|
#include <cstddef>
|
@@ -18,6 +15,9 @@
|
|
18
15
|
|
19
16
|
#include <algorithm>
|
20
17
|
|
18
|
+
#include <faiss/Clustering.h>
|
19
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
20
|
+
#include <faiss/impl/FaissAssert.h>
|
21
21
|
#include <faiss/utils/distances.h>
|
22
22
|
#include <faiss/utils/hamming.h> // BitstringWriter
|
23
23
|
#include <faiss/utils/utils.h>
|
@@ -42,18 +42,6 @@ void sgetri_(
|
|
42
42
|
FINTEGER* lwork,
|
43
43
|
FINTEGER* info);
|
44
44
|
|
45
|
-
// solves a system of linear equations
|
46
|
-
void sgetrs_(
|
47
|
-
const char* trans,
|
48
|
-
FINTEGER* n,
|
49
|
-
FINTEGER* nrhs,
|
50
|
-
float* A,
|
51
|
-
FINTEGER* lda,
|
52
|
-
FINTEGER* ipiv,
|
53
|
-
float* b,
|
54
|
-
FINTEGER* ldb,
|
55
|
-
FINTEGER* info);
|
56
|
-
|
57
45
|
// general matrix multiplication
|
58
46
|
int sgemm_(
|
59
47
|
const char* transa,
|
@@ -69,26 +57,73 @@ int sgemm_(
|
|
69
57
|
float* beta,
|
70
58
|
float* c,
|
71
59
|
FINTEGER* ldc);
|
60
|
+
|
61
|
+
// LU decomoposition of a general matrix
|
62
|
+
void dgetrf_(
|
63
|
+
FINTEGER* m,
|
64
|
+
FINTEGER* n,
|
65
|
+
double* a,
|
66
|
+
FINTEGER* lda,
|
67
|
+
FINTEGER* ipiv,
|
68
|
+
FINTEGER* info);
|
69
|
+
|
70
|
+
// generate inverse of a matrix given its LU decomposition
|
71
|
+
void dgetri_(
|
72
|
+
FINTEGER* n,
|
73
|
+
double* a,
|
74
|
+
FINTEGER* lda,
|
75
|
+
FINTEGER* ipiv,
|
76
|
+
double* work,
|
77
|
+
FINTEGER* lwork,
|
78
|
+
FINTEGER* info);
|
79
|
+
|
80
|
+
// general matrix multiplication
|
81
|
+
int dgemm_(
|
82
|
+
const char* transa,
|
83
|
+
const char* transb,
|
84
|
+
FINTEGER* m,
|
85
|
+
FINTEGER* n,
|
86
|
+
FINTEGER* k,
|
87
|
+
const double* alpha,
|
88
|
+
const double* a,
|
89
|
+
FINTEGER* lda,
|
90
|
+
const double* b,
|
91
|
+
FINTEGER* ldb,
|
92
|
+
double* beta,
|
93
|
+
double* c,
|
94
|
+
FINTEGER* ldc);
|
72
95
|
}
|
73
96
|
|
74
97
|
namespace {
|
75
98
|
|
99
|
+
void fmat_inverse(float* a, int n) {
|
100
|
+
int info;
|
101
|
+
int lwork = n * n;
|
102
|
+
std::vector<int> ipiv(n);
|
103
|
+
std::vector<float> workspace(lwork);
|
104
|
+
|
105
|
+
sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
|
106
|
+
FAISS_THROW_IF_NOT(info == 0);
|
107
|
+
sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
|
108
|
+
FAISS_THROW_IF_NOT(info == 0);
|
109
|
+
}
|
110
|
+
|
76
111
|
// c and a and b can overlap
|
77
|
-
void
|
112
|
+
void dfvec_add(size_t d, const double* a, const float* b, double* c) {
|
78
113
|
for (size_t i = 0; i < d; i++) {
|
79
114
|
c[i] = a[i] + b[i];
|
80
115
|
}
|
81
116
|
}
|
82
117
|
|
83
|
-
void
|
118
|
+
void dmat_inverse(double* a, int n) {
|
84
119
|
int info;
|
85
120
|
int lwork = n * n;
|
86
121
|
std::vector<int> ipiv(n);
|
87
|
-
std::vector<
|
122
|
+
std::vector<double> workspace(lwork);
|
88
123
|
|
89
|
-
|
124
|
+
dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
|
90
125
|
FAISS_THROW_IF_NOT(info == 0);
|
91
|
-
|
126
|
+
dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
|
92
127
|
FAISS_THROW_IF_NOT(info == 0);
|
93
128
|
}
|
94
129
|
|
@@ -107,18 +142,15 @@ void random_int32(
|
|
107
142
|
|
108
143
|
namespace faiss {
|
109
144
|
|
110
|
-
LSQTimer lsq_timer;
|
111
|
-
|
112
|
-
LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
|
113
|
-
FAISS_THROW_IF_NOT((M * nbits) % 8 == 0);
|
114
|
-
|
115
|
-
this->d = d;
|
116
|
-
this->M = M;
|
117
|
-
this->nbits = std::vector<size_t>(M, nbits);
|
118
|
-
|
119
|
-
// set derived values
|
120
|
-
set_derived_values();
|
145
|
+
lsq::LSQTimer lsq_timer;
|
146
|
+
using lsq::LSQTimerScope;
|
121
147
|
|
148
|
+
LocalSearchQuantizer::LocalSearchQuantizer(
|
149
|
+
size_t d,
|
150
|
+
size_t M,
|
151
|
+
size_t nbits,
|
152
|
+
Search_type_t search_type)
|
153
|
+
: AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
|
122
154
|
is_trained = false;
|
123
155
|
verbose = false;
|
124
156
|
|
@@ -138,15 +170,23 @@ LocalSearchQuantizer::LocalSearchQuantizer(size_t d, size_t M, size_t nbits) {
|
|
138
170
|
|
139
171
|
random_seed = 0x12345;
|
140
172
|
std::srand(random_seed);
|
173
|
+
|
174
|
+
icm_encoder_factory = nullptr;
|
141
175
|
}
|
142
176
|
|
177
|
+
LocalSearchQuantizer::~LocalSearchQuantizer() {
|
178
|
+
delete icm_encoder_factory;
|
179
|
+
}
|
180
|
+
|
181
|
+
LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
|
182
|
+
|
143
183
|
void LocalSearchQuantizer::train(size_t n, const float* x) {
|
144
184
|
FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
|
145
185
|
FAISS_THROW_IF_NOT(nperts <= M);
|
146
186
|
|
147
187
|
lsq_timer.reset();
|
188
|
+
LSQTimerScope scope(&lsq_timer, "train");
|
148
189
|
if (verbose) {
|
149
|
-
lsq_timer.start("train");
|
150
190
|
printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
|
151
191
|
M,
|
152
192
|
n,
|
@@ -209,7 +249,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|
209
249
|
}
|
210
250
|
|
211
251
|
// refine codes
|
212
|
-
icm_encode(
|
252
|
+
icm_encode(codes.data(), x, n, train_ils_iters, gen);
|
213
253
|
|
214
254
|
if (verbose) {
|
215
255
|
float obj = evaluate(codes.data(), x, n);
|
@@ -217,25 +257,52 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|
217
257
|
}
|
218
258
|
}
|
219
259
|
|
260
|
+
is_trained = true;
|
261
|
+
{
|
262
|
+
std::vector<float> x_recons(n * d);
|
263
|
+
std::vector<float> norms(n);
|
264
|
+
decode_unpacked(codes.data(), x_recons.data(), n);
|
265
|
+
fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
|
266
|
+
|
267
|
+
norm_min = HUGE_VALF;
|
268
|
+
norm_max = -HUGE_VALF;
|
269
|
+
for (idx_t i = 0; i < n; i++) {
|
270
|
+
if (norms[i] < norm_min) {
|
271
|
+
norm_min = norms[i];
|
272
|
+
}
|
273
|
+
if (norms[i] > norm_max) {
|
274
|
+
norm_max = norms[i];
|
275
|
+
}
|
276
|
+
}
|
277
|
+
|
278
|
+
if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
|
279
|
+
size_t k = (1 << 8);
|
280
|
+
if (search_type == ST_norm_cqint4) {
|
281
|
+
k = (1 << 4);
|
282
|
+
}
|
283
|
+
Clustering1D clus(k);
|
284
|
+
clus.train_exact(n, norms.data());
|
285
|
+
qnorm.add(clus.k, clus.centroids.data());
|
286
|
+
}
|
287
|
+
}
|
288
|
+
|
220
289
|
if (verbose) {
|
221
|
-
lsq_timer.end("train");
|
222
290
|
float obj = evaluate(codes.data(), x, n);
|
291
|
+
scope.finish();
|
223
292
|
printf("After training: obj = %lf\n", obj);
|
224
293
|
|
225
294
|
printf("Time statistic:\n");
|
226
|
-
for (const auto& it : lsq_timer.
|
227
|
-
printf("\t%s time: %lf s\n", it.first.data(), it.second);
|
295
|
+
for (const auto& it : lsq_timer.t) {
|
296
|
+
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
228
297
|
}
|
229
298
|
}
|
230
|
-
|
231
|
-
is_trained = true;
|
232
299
|
}
|
233
300
|
|
234
301
|
void LocalSearchQuantizer::perturb_codebooks(
|
235
302
|
float T,
|
236
303
|
const std::vector<float>& stddev,
|
237
304
|
std::mt19937& gen) {
|
238
|
-
lsq_timer
|
305
|
+
LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
|
239
306
|
|
240
307
|
std::vector<std::normal_distribution<float>> distribs;
|
241
308
|
for (size_t i = 0; i < d; i++) {
|
@@ -249,8 +316,6 @@ void LocalSearchQuantizer::perturb_codebooks(
|
|
249
316
|
}
|
250
317
|
}
|
251
318
|
}
|
252
|
-
|
253
|
-
lsq_timer.end("perturb_codebooks");
|
254
319
|
}
|
255
320
|
|
256
321
|
void LocalSearchQuantizer::compute_codes(
|
@@ -258,23 +323,26 @@ void LocalSearchQuantizer::compute_codes(
|
|
258
323
|
uint8_t* codes_out,
|
259
324
|
size_t n) const {
|
260
325
|
FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
|
326
|
+
|
327
|
+
lsq_timer.reset();
|
328
|
+
LSQTimerScope scope(&lsq_timer, "encode");
|
261
329
|
if (verbose) {
|
262
|
-
lsq_timer.reset();
|
263
330
|
printf("Encoding %zd vectors...\n", n);
|
264
|
-
lsq_timer.start("encode");
|
265
331
|
}
|
266
332
|
|
267
333
|
std::vector<int32_t> codes(n * M);
|
268
334
|
std::mt19937 gen(random_seed);
|
269
335
|
random_int32(codes, 0, K - 1, gen);
|
270
336
|
|
271
|
-
icm_encode(
|
337
|
+
icm_encode(codes.data(), x, n, encode_ils_iters, gen);
|
272
338
|
pack_codes(n, codes.data(), codes_out);
|
273
339
|
|
274
340
|
if (verbose) {
|
275
|
-
|
276
|
-
|
277
|
-
|
341
|
+
scope.finish();
|
342
|
+
printf("Time statistic:\n");
|
343
|
+
for (const auto& it : lsq_timer.t) {
|
344
|
+
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
345
|
+
}
|
278
346
|
}
|
279
347
|
}
|
280
348
|
|
@@ -298,73 +366,144 @@ void LocalSearchQuantizer::update_codebooks(
|
|
298
366
|
const float* x,
|
299
367
|
const int32_t* codes,
|
300
368
|
size_t n) {
|
301
|
-
lsq_timer
|
369
|
+
LSQTimerScope scope(&lsq_timer, "update_codebooks");
|
370
|
+
|
371
|
+
if (!update_codebooks_with_double) {
|
372
|
+
// allocate memory
|
373
|
+
// bb = B'B, bx = BX
|
374
|
+
std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
|
375
|
+
std::vector<float> bx(M * K * d, 0.0f); // [M * K, d]
|
376
|
+
|
377
|
+
// compute B'B
|
378
|
+
for (size_t i = 0; i < n; i++) {
|
379
|
+
for (size_t m = 0; m < M; m++) {
|
380
|
+
int32_t code1 = codes[i * M + m];
|
381
|
+
int32_t idx1 = m * K + code1;
|
382
|
+
bb[idx1 * M * K + idx1] += 1;
|
383
|
+
|
384
|
+
for (size_t m2 = m + 1; m2 < M; m2++) {
|
385
|
+
int32_t code2 = codes[i * M + m2];
|
386
|
+
int32_t idx2 = m2 * K + code2;
|
387
|
+
bb[idx1 * M * K + idx2] += 1;
|
388
|
+
bb[idx2 * M * K + idx1] += 1;
|
389
|
+
}
|
390
|
+
}
|
391
|
+
}
|
302
392
|
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
393
|
+
// add a regularization term to B'B
|
394
|
+
for (int64_t i = 0; i < M * K; i++) {
|
395
|
+
bb[i * (M * K) + i] += lambd;
|
396
|
+
}
|
307
397
|
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
int32_t idx2 = m2 * K + code2;
|
318
|
-
bb[idx1 * M * K + idx2] += 1;
|
319
|
-
bb[idx2 * M * K + idx1] += 1;
|
398
|
+
// compute (B'B)^(-1)
|
399
|
+
fmat_inverse(bb.data(), M * K); // [M*K, M*K]
|
400
|
+
|
401
|
+
// compute BX
|
402
|
+
for (size_t i = 0; i < n; i++) {
|
403
|
+
for (size_t m = 0; m < M; m++) {
|
404
|
+
int32_t code = codes[i * M + m];
|
405
|
+
float* data = bx.data() + (m * K + code) * d;
|
406
|
+
fvec_add(d, data, x + i * d, data);
|
320
407
|
}
|
321
408
|
}
|
322
|
-
}
|
323
409
|
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
410
|
+
// compute C = (B'B)^(-1) @ BX
|
411
|
+
//
|
412
|
+
// NOTE: LAPACK use column major order
|
413
|
+
// out = alpha * op(A) * op(B) + beta * C
|
414
|
+
FINTEGER nrows_A = d;
|
415
|
+
FINTEGER ncols_A = M * K;
|
416
|
+
|
417
|
+
FINTEGER nrows_B = M * K;
|
418
|
+
FINTEGER ncols_B = M * K;
|
419
|
+
|
420
|
+
float alpha = 1.0f;
|
421
|
+
float beta = 0.0f;
|
422
|
+
sgemm_("Not Transposed",
|
423
|
+
"Not Transposed",
|
424
|
+
&nrows_A, // nrows of op(A)
|
425
|
+
&ncols_B, // ncols of op(B)
|
426
|
+
&ncols_A, // ncols of op(A)
|
427
|
+
&alpha,
|
428
|
+
bx.data(),
|
429
|
+
&nrows_A, // nrows of A
|
430
|
+
bb.data(),
|
431
|
+
&nrows_B, // nrows of B
|
432
|
+
&beta,
|
433
|
+
codebooks.data(),
|
434
|
+
&nrows_A); // nrows of output
|
435
|
+
|
436
|
+
} else {
|
437
|
+
// allocate memory
|
438
|
+
// bb = B'B, bx = BX
|
439
|
+
std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
|
440
|
+
std::vector<double> bx(M * K * d, 0.0f); // [M * K, d]
|
441
|
+
|
442
|
+
// compute B'B
|
443
|
+
for (size_t i = 0; i < n; i++) {
|
444
|
+
for (size_t m = 0; m < M; m++) {
|
445
|
+
int32_t code1 = codes[i * M + m];
|
446
|
+
int32_t idx1 = m * K + code1;
|
447
|
+
bb[idx1 * M * K + idx1] += 1;
|
448
|
+
|
449
|
+
for (size_t m2 = m + 1; m2 < M; m2++) {
|
450
|
+
int32_t code2 = codes[i * M + m2];
|
451
|
+
int32_t idx2 = m2 * K + code2;
|
452
|
+
bb[idx1 * M * K + idx2] += 1;
|
453
|
+
bb[idx2 * M * K + idx1] += 1;
|
454
|
+
}
|
455
|
+
}
|
456
|
+
}
|
328
457
|
|
329
|
-
|
330
|
-
|
458
|
+
// add a regularization term to B'B
|
459
|
+
for (int64_t i = 0; i < M * K; i++) {
|
460
|
+
bb[i * (M * K) + i] += lambd;
|
461
|
+
}
|
331
462
|
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
463
|
+
// compute (B'B)^(-1)
|
464
|
+
dmat_inverse(bb.data(), M * K); // [M*K, M*K]
|
465
|
+
|
466
|
+
// compute BX
|
467
|
+
for (size_t i = 0; i < n; i++) {
|
468
|
+
for (size_t m = 0; m < M; m++) {
|
469
|
+
int32_t code = codes[i * M + m];
|
470
|
+
double* data = bx.data() + (m * K + code) * d;
|
471
|
+
dfvec_add(d, data, x + i * d, data);
|
472
|
+
}
|
338
473
|
}
|
339
|
-
}
|
340
474
|
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
475
|
+
// compute C = (B'B)^(-1) @ BX
|
476
|
+
//
|
477
|
+
// NOTE: LAPACK use column major order
|
478
|
+
// out = alpha * op(A) * op(B) + beta * C
|
479
|
+
FINTEGER nrows_A = d;
|
480
|
+
FINTEGER ncols_A = M * K;
|
481
|
+
|
482
|
+
FINTEGER nrows_B = M * K;
|
483
|
+
FINTEGER ncols_B = M * K;
|
484
|
+
|
485
|
+
std::vector<double> d_codebooks(M * K * d);
|
486
|
+
|
487
|
+
double alpha = 1.0f;
|
488
|
+
double beta = 0.0f;
|
489
|
+
dgemm_("Not Transposed",
|
490
|
+
"Not Transposed",
|
491
|
+
&nrows_A, // nrows of op(A)
|
492
|
+
&ncols_B, // ncols of op(B)
|
493
|
+
&ncols_A, // ncols of op(A)
|
494
|
+
&alpha,
|
495
|
+
bx.data(),
|
496
|
+
&nrows_A, // nrows of A
|
497
|
+
bb.data(),
|
498
|
+
&nrows_B, // nrows of B
|
499
|
+
&beta,
|
500
|
+
d_codebooks.data(),
|
501
|
+
&nrows_A); // nrows of output
|
502
|
+
|
503
|
+
for (size_t i = 0; i < M * K * d; i++) {
|
504
|
+
codebooks[i] = (float)d_codebooks[i];
|
505
|
+
}
|
506
|
+
}
|
368
507
|
}
|
369
508
|
|
370
509
|
/** encode using iterative conditional mode
|
@@ -386,15 +525,23 @@ void LocalSearchQuantizer::update_codebooks(
|
|
386
525
|
* These two terms can be precomputed and store in a look up table.
|
387
526
|
*/
|
388
527
|
void LocalSearchQuantizer::icm_encode(
|
389
|
-
const float* x,
|
390
528
|
int32_t* codes,
|
529
|
+
const float* x,
|
391
530
|
size_t n,
|
392
531
|
size_t ils_iters,
|
393
532
|
std::mt19937& gen) const {
|
394
|
-
lsq_timer
|
533
|
+
LSQTimerScope scope(&lsq_timer, "icm_encode");
|
534
|
+
|
535
|
+
auto factory = icm_encoder_factory;
|
536
|
+
std::unique_ptr<lsq::IcmEncoder> icm_encoder;
|
537
|
+
if (factory == nullptr) {
|
538
|
+
icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
|
539
|
+
} else {
|
540
|
+
icm_encoder.reset(factory->get(this));
|
541
|
+
}
|
395
542
|
|
396
|
-
|
397
|
-
|
543
|
+
// precompute binary terms for all chunks
|
544
|
+
icm_encoder->set_binary_term();
|
398
545
|
|
399
546
|
const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
|
400
547
|
for (size_t i = 0; i < n_chunks; i++) {
|
@@ -410,21 +557,20 @@ void LocalSearchQuantizer::icm_encode(
|
|
410
557
|
|
411
558
|
const float* xi = x + i * chunk_size * d;
|
412
559
|
int32_t* codesi = codes + i * chunk_size * M;
|
413
|
-
|
560
|
+
icm_encoder->verbose = (verbose && i == 0);
|
561
|
+
icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
|
414
562
|
}
|
415
|
-
|
416
|
-
lsq_timer.end("icm_encode");
|
417
563
|
}
|
418
564
|
|
419
|
-
void LocalSearchQuantizer::
|
420
|
-
size_t index,
|
421
|
-
const float* x,
|
565
|
+
void LocalSearchQuantizer::icm_encode_impl(
|
422
566
|
int32_t* codes,
|
423
|
-
|
567
|
+
const float* x,
|
424
568
|
const float* binaries,
|
569
|
+
std::mt19937& gen,
|
570
|
+
size_t n,
|
425
571
|
size_t ils_iters,
|
426
|
-
|
427
|
-
std::vector<float> unaries(n * M * K); // [
|
572
|
+
bool verbose) const {
|
573
|
+
std::vector<float> unaries(n * M * K); // [M, n, K]
|
428
574
|
compute_unary_terms(x, unaries.data(), n);
|
429
575
|
|
430
576
|
std::vector<int32_t> best_codes;
|
@@ -438,9 +584,7 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|
438
584
|
// add perturbation to codes
|
439
585
|
perturb_codes(codes, n, gen);
|
440
586
|
|
441
|
-
|
442
|
-
icm_encode_step(unaries.data(), binaries, codes, n);
|
443
|
-
}
|
587
|
+
icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
|
444
588
|
|
445
589
|
std::vector<float> icm_objs(n, 0.0f);
|
446
590
|
evaluate(codes, x, n, icm_objs.data());
|
@@ -463,7 +607,7 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|
463
607
|
|
464
608
|
memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
|
465
609
|
|
466
|
-
if (verbose
|
610
|
+
if (verbose) {
|
467
611
|
printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
|
468
612
|
iter1,
|
469
613
|
mean_obj,
|
@@ -474,61 +618,67 @@ void LocalSearchQuantizer::icm_encode_partial(
|
|
474
618
|
}
|
475
619
|
|
476
620
|
void LocalSearchQuantizer::icm_encode_step(
|
621
|
+
int32_t* codes,
|
477
622
|
const float* unaries,
|
478
623
|
const float* binaries,
|
479
|
-
|
480
|
-
size_t
|
481
|
-
|
482
|
-
|
483
|
-
std::vector<float> objs(n * K);
|
484
|
-
#pragma omp parallel for
|
485
|
-
for (int64_t i = 0; i < n; i++) {
|
486
|
-
auto u = unaries + i * (M * K) + m * K;
|
487
|
-
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
488
|
-
}
|
624
|
+
size_t n,
|
625
|
+
size_t n_iters) const {
|
626
|
+
FAISS_THROW_IF_NOT(M != 0 && K != 0);
|
627
|
+
FAISS_THROW_IF_NOT(binaries != nullptr);
|
489
628
|
|
490
|
-
|
491
|
-
//
|
492
|
-
for (size_t
|
493
|
-
|
494
|
-
|
629
|
+
for (size_t iter = 0; iter < n_iters; iter++) {
|
630
|
+
// condition on the m-th subcode
|
631
|
+
for (size_t m = 0; m < M; m++) {
|
632
|
+
std::vector<float> objs(n * K);
|
633
|
+
#pragma omp parallel for
|
634
|
+
for (int64_t i = 0; i < n; i++) {
|
635
|
+
auto u = unaries + m * n * K + i * K;
|
636
|
+
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
495
637
|
}
|
496
638
|
|
639
|
+
// compute objective function by adding unary
|
640
|
+
// and binary terms together
|
641
|
+
for (size_t other_m = 0; other_m < M; other_m++) {
|
642
|
+
if (other_m == m) {
|
643
|
+
continue;
|
644
|
+
}
|
645
|
+
|
497
646
|
#pragma omp parallel for
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
647
|
+
for (int64_t i = 0; i < n; i++) {
|
648
|
+
for (int32_t code = 0; code < K; code++) {
|
649
|
+
int32_t code2 = codes[i * M + other_m];
|
650
|
+
size_t binary_idx = m * M * K * K + other_m * K * K +
|
651
|
+
code * K + code2;
|
652
|
+
// binaries[m, other_m, code, code2]
|
653
|
+
objs[i * K + code] += binaries[binary_idx];
|
654
|
+
}
|
505
655
|
}
|
506
656
|
}
|
507
|
-
}
|
508
657
|
|
509
|
-
|
658
|
+
// find the optimal value of the m-th subcode
|
510
659
|
#pragma omp parallel for
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
660
|
+
for (int64_t i = 0; i < n; i++) {
|
661
|
+
float best_obj = HUGE_VALF;
|
662
|
+
int32_t best_code = 0;
|
663
|
+
for (size_t code = 0; code < K; code++) {
|
664
|
+
float obj = objs[i * K + code];
|
665
|
+
if (obj < best_obj) {
|
666
|
+
best_obj = obj;
|
667
|
+
best_code = code;
|
668
|
+
}
|
519
669
|
}
|
670
|
+
codes[i * M + m] = best_code;
|
520
671
|
}
|
521
|
-
codes[i * M + m] = best_code;
|
522
|
-
}
|
523
672
|
|
524
|
-
|
673
|
+
} // loop M
|
674
|
+
}
|
525
675
|
}
|
526
676
|
|
527
677
|
void LocalSearchQuantizer::perturb_codes(
|
528
678
|
int32_t* codes,
|
529
679
|
size_t n,
|
530
680
|
std::mt19937& gen) const {
|
531
|
-
lsq_timer
|
681
|
+
LSQTimerScope scope(&lsq_timer, "perturb_codes");
|
532
682
|
|
533
683
|
std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
|
534
684
|
std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
|
@@ -539,12 +689,10 @@ void LocalSearchQuantizer::perturb_codes(
|
|
539
689
|
codes[i * M + m] = k_distrib(gen);
|
540
690
|
}
|
541
691
|
}
|
542
|
-
|
543
|
-
lsq_timer.end("perturb_codes");
|
544
692
|
}
|
545
693
|
|
546
694
|
void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
547
|
-
lsq_timer
|
695
|
+
LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
|
548
696
|
|
549
697
|
#pragma omp parallel for
|
550
698
|
for (int64_t m12 = 0; m12 < M * M; m12++) {
|
@@ -562,52 +710,53 @@ void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
|
562
710
|
}
|
563
711
|
}
|
564
712
|
}
|
565
|
-
|
566
|
-
lsq_timer.end("compute_binary_terms");
|
567
713
|
}
|
568
714
|
|
569
715
|
void LocalSearchQuantizer::compute_unary_terms(
|
570
716
|
const float* x,
|
571
|
-
float* unaries,
|
717
|
+
float* unaries, // [M, n, K]
|
572
718
|
size_t n) const {
|
573
|
-
lsq_timer
|
719
|
+
LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
|
574
720
|
|
575
|
-
// compute x *
|
721
|
+
// compute x * codebook^T for each codebook
|
576
722
|
//
|
577
723
|
// NOTE: LAPACK use column major order
|
578
724
|
// out = alpha * op(A) * op(B) + beta * C
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
725
|
+
|
726
|
+
for (size_t m = 0; m < M; m++) {
|
727
|
+
FINTEGER nrows_A = K;
|
728
|
+
FINTEGER ncols_A = d;
|
729
|
+
|
730
|
+
FINTEGER nrows_B = d;
|
731
|
+
FINTEGER ncols_B = n;
|
732
|
+
|
733
|
+
float alpha = -2.0f;
|
734
|
+
float beta = 0.0f;
|
735
|
+
sgemm_("Transposed",
|
736
|
+
"Not Transposed",
|
737
|
+
&nrows_A, // nrows of op(A)
|
738
|
+
&ncols_B, // ncols of op(B)
|
739
|
+
&ncols_A, // ncols of op(A)
|
740
|
+
&alpha,
|
741
|
+
codebooks.data() + m * K * d,
|
742
|
+
&ncols_A, // nrows of A
|
743
|
+
x,
|
744
|
+
&nrows_B, // nrows of B
|
745
|
+
&beta,
|
746
|
+
unaries + m * n * K,
|
747
|
+
&nrows_A); // nrows of output
|
748
|
+
}
|
600
749
|
|
601
750
|
std::vector<float> norms(M * K);
|
602
751
|
fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
|
603
752
|
|
604
753
|
#pragma omp parallel for
|
605
754
|
for (int64_t i = 0; i < n; i++) {
|
606
|
-
|
607
|
-
|
755
|
+
for (size_t m = 0; m < M; m++) {
|
756
|
+
float* u = unaries + m * n * K + i * K;
|
757
|
+
fvec_add(K, u, norms.data() + m * K, u);
|
758
|
+
}
|
608
759
|
}
|
609
|
-
|
610
|
-
lsq_timer.end("compute_unary_terms");
|
611
760
|
}
|
612
761
|
|
613
762
|
float LocalSearchQuantizer::evaluate(
|
@@ -615,7 +764,7 @@ float LocalSearchQuantizer::evaluate(
|
|
615
764
|
const float* x,
|
616
765
|
size_t n,
|
617
766
|
float* objs) const {
|
618
|
-
lsq_timer
|
767
|
+
LSQTimerScope scope(&lsq_timer, "evaluate");
|
619
768
|
|
620
769
|
// decode
|
621
770
|
std::vector<float> decoded_x(n * d, 0.0f);
|
@@ -631,7 +780,7 @@ float LocalSearchQuantizer::evaluate(
|
|
631
780
|
fvec_add(d, decoded_i, c, decoded_i);
|
632
781
|
}
|
633
782
|
|
634
|
-
float err = fvec_L2sqr(x + i * d, decoded_i, d);
|
783
|
+
float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
|
635
784
|
obj += err;
|
636
785
|
|
637
786
|
if (objs) {
|
@@ -639,34 +788,68 @@ float LocalSearchQuantizer::evaluate(
|
|
639
788
|
}
|
640
789
|
}
|
641
790
|
|
642
|
-
lsq_timer.end("evaluate");
|
643
|
-
|
644
791
|
obj = obj / n;
|
645
792
|
return obj;
|
646
793
|
}
|
647
794
|
|
648
|
-
|
649
|
-
|
795
|
+
namespace lsq {
|
796
|
+
|
797
|
+
IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
|
798
|
+
: verbose(false), lsq(lsq) {}
|
799
|
+
|
800
|
+
void IcmEncoder::set_binary_term() {
|
801
|
+
auto M = lsq->M;
|
802
|
+
auto K = lsq->K;
|
803
|
+
binaries.resize(M * M * K * K);
|
804
|
+
lsq->compute_binary_terms(binaries.data());
|
650
805
|
}
|
651
806
|
|
652
|
-
void
|
653
|
-
|
654
|
-
|
655
|
-
|
807
|
+
void IcmEncoder::encode(
|
808
|
+
int32_t* codes,
|
809
|
+
const float* x,
|
810
|
+
std::mt19937& gen,
|
811
|
+
size_t n,
|
812
|
+
size_t ils_iters) const {
|
813
|
+
lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
|
656
814
|
}
|
657
815
|
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
816
|
+
double LSQTimer::get(const std::string& name) {
|
817
|
+
if (t.count(name) == 0) {
|
818
|
+
return 0.0;
|
819
|
+
} else {
|
820
|
+
return t[name];
|
821
|
+
}
|
822
|
+
}
|
823
|
+
|
824
|
+
void LSQTimer::add(const std::string& name, double delta) {
|
825
|
+
if (t.count(name) == 0) {
|
826
|
+
t[name] = delta;
|
827
|
+
} else {
|
828
|
+
t[name] += delta;
|
829
|
+
}
|
664
830
|
}
|
665
831
|
|
666
832
|
void LSQTimer::reset() {
|
667
|
-
|
668
|
-
|
669
|
-
|
833
|
+
t.clear();
|
834
|
+
}
|
835
|
+
|
836
|
+
LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
|
837
|
+
: timer(timer), name(name), finished(false) {
|
838
|
+
t0 = getmillisecs();
|
670
839
|
}
|
671
840
|
|
841
|
+
void LSQTimerScope::finish() {
|
842
|
+
if (!finished) {
|
843
|
+
auto delta = getmillisecs() - t0;
|
844
|
+
timer->add(name, delta);
|
845
|
+
finished = true;
|
846
|
+
}
|
847
|
+
}
|
848
|
+
|
849
|
+
LSQTimerScope::~LSQTimerScope() {
|
850
|
+
finish();
|
851
|
+
}
|
852
|
+
|
853
|
+
} // namespace lsq
|
854
|
+
|
672
855
|
} // namespace faiss
|