faiss 0.2.3 → 0.2.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
- data/vendor/faiss/faiss/Index2Layer.h +2 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
- data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
- data/vendor/faiss/faiss/IndexFlat.h +9 -15
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
- data/vendor/faiss/faiss/IndexIVF.h +25 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
- data/vendor/faiss/faiss/IndexLSH.h +2 -15
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
- data/vendor/faiss/faiss/IndexPQ.h +2 -17
- data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
- data/vendor/faiss/faiss/IndexRefine.h +10 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
- data/vendor/faiss/faiss/VectorTransform.h +3 -0
- data/vendor/faiss/faiss/clone_index.cpp +3 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
- data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
- data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/index_factory.cpp +585 -414
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/utils/distances.cpp +4 -2
- data/vendor/faiss/faiss/utils/distances.h +36 -3
- data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +12 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -25,9 +25,11 @@
|
|
25
25
|
#include <faiss/invlists/InvertedListsIOHook.h>
|
26
26
|
|
27
27
|
#include <faiss/Index2Layer.h>
|
28
|
+
#include <faiss/IndexAdditiveQuantizer.h>
|
28
29
|
#include <faiss/IndexFlat.h>
|
29
30
|
#include <faiss/IndexHNSW.h>
|
30
31
|
#include <faiss/IndexIVF.h>
|
32
|
+
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
31
33
|
#include <faiss/IndexIVFFlat.h>
|
32
34
|
#include <faiss/IndexIVFPQ.h>
|
33
35
|
#include <faiss/IndexIVFPQFastScan.h>
|
@@ -40,7 +42,6 @@
|
|
40
42
|
#include <faiss/IndexPQFastScan.h>
|
41
43
|
#include <faiss/IndexPreTransform.h>
|
42
44
|
#include <faiss/IndexRefine.h>
|
43
|
-
#include <faiss/IndexResidual.h>
|
44
45
|
#include <faiss/IndexScalarQuantizer.h>
|
45
46
|
#include <faiss/MetaIndexes.h>
|
46
47
|
#include <faiss/VectorTransform.h>
|
@@ -77,16 +78,22 @@ VectorTransform* read_VectorTransform(IOReader* f) {
|
|
77
78
|
VectorTransform* vt = nullptr;
|
78
79
|
|
79
80
|
if (h == fourcc("rrot") || h == fourcc("PCAm") || h == fourcc("LTra") ||
|
80
|
-
h == fourcc("PcAm") || h == fourcc("Viqm")) {
|
81
|
+
h == fourcc("PcAm") || h == fourcc("Viqm") || h == fourcc("Pcam")) {
|
81
82
|
LinearTransform* lt = nullptr;
|
82
83
|
if (h == fourcc("rrot")) {
|
83
84
|
lt = new RandomRotationMatrix();
|
84
|
-
} else if (
|
85
|
+
} else if (
|
86
|
+
h == fourcc("PCAm") || h == fourcc("PcAm") ||
|
87
|
+
h == fourcc("Pcam")) {
|
85
88
|
PCAMatrix* pca = new PCAMatrix();
|
86
89
|
READ1(pca->eigen_power);
|
90
|
+
if (h == fourcc("Pcam")) {
|
91
|
+
READ1(pca->epsilon);
|
92
|
+
}
|
87
93
|
READ1(pca->random_rotation);
|
88
|
-
if (h
|
94
|
+
if (h != fourcc("PCAm")) {
|
89
95
|
READ1(pca->balanced_bins);
|
96
|
+
}
|
90
97
|
READVECTOR(pca->mean);
|
91
98
|
READVECTOR(pca->eigenvalues);
|
92
99
|
READVECTOR(pca->PCAMat);
|
@@ -139,9 +146,10 @@ VectorTransform* read_VectorTransform(IOReader* f) {
|
|
139
146
|
vt = itqt;
|
140
147
|
} else {
|
141
148
|
FAISS_THROW_FMT(
|
142
|
-
"fourcc %ud (\"%s\") not recognized",
|
149
|
+
"fourcc %ud (\"%s\") not recognized in %s",
|
143
150
|
h,
|
144
|
-
fourcc_inv_printable(h).c_str()
|
151
|
+
fourcc_inv_printable(h).c_str(),
|
152
|
+
f->name.c_str());
|
145
153
|
}
|
146
154
|
READ1(vt->d_in);
|
147
155
|
READ1(vt->d_out);
|
@@ -239,15 +247,58 @@ static void read_ProductQuantizer(ProductQuantizer* pq, IOReader* f) {
|
|
239
247
|
READVECTOR(pq->centroids);
|
240
248
|
}
|
241
249
|
|
242
|
-
static void
|
250
|
+
static void read_ResidualQuantizer_old(ResidualQuantizer* rq, IOReader* f) {
|
243
251
|
READ1(rq->d);
|
244
252
|
READ1(rq->M);
|
245
253
|
READVECTOR(rq->nbits);
|
246
|
-
rq->set_derived_values();
|
247
254
|
READ1(rq->is_trained);
|
248
255
|
READ1(rq->train_type);
|
249
256
|
READ1(rq->max_beam_size);
|
250
257
|
READVECTOR(rq->codebooks);
|
258
|
+
READ1(rq->search_type);
|
259
|
+
READ1(rq->norm_min);
|
260
|
+
READ1(rq->norm_max);
|
261
|
+
rq->set_derived_values();
|
262
|
+
}
|
263
|
+
|
264
|
+
static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) {
|
265
|
+
READ1(aq->d);
|
266
|
+
READ1(aq->M);
|
267
|
+
READVECTOR(aq->nbits);
|
268
|
+
READ1(aq->is_trained);
|
269
|
+
READVECTOR(aq->codebooks);
|
270
|
+
READ1(aq->search_type);
|
271
|
+
READ1(aq->norm_min);
|
272
|
+
READ1(aq->norm_max);
|
273
|
+
if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8 ||
|
274
|
+
aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
|
275
|
+
READXBVECTOR(aq->qnorm.codes);
|
276
|
+
}
|
277
|
+
aq->set_derived_values();
|
278
|
+
}
|
279
|
+
|
280
|
+
static void read_ResidualQuantizer(ResidualQuantizer* rq, IOReader* f) {
|
281
|
+
read_AdditiveQuantizer(rq, f);
|
282
|
+
READ1(rq->train_type);
|
283
|
+
READ1(rq->max_beam_size);
|
284
|
+
if (!(rq->train_type & ResidualQuantizer::Skip_codebook_tables)) {
|
285
|
+
rq->compute_codebook_tables();
|
286
|
+
}
|
287
|
+
}
|
288
|
+
|
289
|
+
static void read_LocalSearchQuantizer(LocalSearchQuantizer* lsq, IOReader* f) {
|
290
|
+
read_AdditiveQuantizer(lsq, f);
|
291
|
+
READ1(lsq->K);
|
292
|
+
READ1(lsq->train_iters);
|
293
|
+
READ1(lsq->encode_ils_iters);
|
294
|
+
READ1(lsq->train_ils_iters);
|
295
|
+
READ1(lsq->icm_iters);
|
296
|
+
READ1(lsq->p);
|
297
|
+
READ1(lsq->lambd);
|
298
|
+
READ1(lsq->chunk_size);
|
299
|
+
READ1(lsq->random_seed);
|
300
|
+
READ1(lsq->nperts);
|
301
|
+
READ1(lsq->update_codebooks_with_double);
|
251
302
|
}
|
252
303
|
|
253
304
|
static void read_ScalarQuantizer(ScalarQuantizer* ivsc, IOReader* f) {
|
@@ -422,8 +473,10 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
422
473
|
idxf = new IndexFlat();
|
423
474
|
}
|
424
475
|
read_index_header(idxf, f);
|
425
|
-
|
426
|
-
|
476
|
+
idxf->code_size = idxf->d * sizeof(float);
|
477
|
+
READXBVECTOR(idxf->codes);
|
478
|
+
FAISS_THROW_IF_NOT(
|
479
|
+
idxf->codes.size() == idxf->ntotal * idxf->code_size);
|
427
480
|
// leak!
|
428
481
|
idx = idxf;
|
429
482
|
} else if (h == fourcc("IxHE") || h == fourcc("IxHe")) {
|
@@ -433,7 +486,9 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
433
486
|
READ1(idxl->rotate_data);
|
434
487
|
READ1(idxl->train_thresholds);
|
435
488
|
READVECTOR(idxl->thresholds);
|
436
|
-
|
489
|
+
int code_size_i;
|
490
|
+
READ1(code_size_i);
|
491
|
+
idxl->code_size = code_size_i;
|
437
492
|
if (h == fourcc("IxHE")) {
|
438
493
|
FAISS_THROW_IF_NOT_FMT(
|
439
494
|
idxl->nbits % 64 == 0,
|
@@ -441,7 +496,7 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
441
496
|
"nbits multiple of 64 (got %d)",
|
442
497
|
(int)idxl->nbits);
|
443
498
|
// leak
|
444
|
-
idxl->
|
499
|
+
idxl->code_size *= 8;
|
445
500
|
}
|
446
501
|
{
|
447
502
|
RandomRotationMatrix* rrot = dynamic_cast<RandomRotationMatrix*>(
|
@@ -454,7 +509,7 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
454
509
|
FAISS_THROW_IF_NOT(
|
455
510
|
idxl->rrot.d_in == idxl->d && idxl->rrot.d_out == idxl->nbits);
|
456
511
|
FAISS_THROW_IF_NOT(
|
457
|
-
idxl->codes.size() == idxl->ntotal * idxl->
|
512
|
+
idxl->codes.size() == idxl->ntotal * idxl->code_size);
|
458
513
|
idx = idxl;
|
459
514
|
} else if (
|
460
515
|
h == fourcc("IxPQ") || h == fourcc("IxPo") || h == fourcc("IxPq")) {
|
@@ -462,6 +517,7 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
462
517
|
IndexPQ* idxp = new IndexPQ();
|
463
518
|
read_index_header(idxp, f);
|
464
519
|
read_ProductQuantizer(&idxp->pq, f);
|
520
|
+
idxp->code_size = idxp->pq.code_size;
|
465
521
|
READVECTOR(idxp->codes);
|
466
522
|
if (h == fourcc("IxPo") || h == fourcc("IxPq")) {
|
467
523
|
READ1(idxp->search_type);
|
@@ -475,13 +531,21 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
475
531
|
idxp->metric_type = METRIC_L2;
|
476
532
|
}
|
477
533
|
idx = idxp;
|
478
|
-
} else if (h == fourcc("IxRQ")) {
|
479
|
-
|
534
|
+
} else if (h == fourcc("IxRQ") || h == fourcc("IxRq")) {
|
535
|
+
IndexResidualQuantizer* idxr = new IndexResidualQuantizer();
|
480
536
|
read_index_header(idxr, f);
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
537
|
+
if (h == fourcc("IxRQ")) {
|
538
|
+
read_ResidualQuantizer_old(&idxr->rq, f);
|
539
|
+
} else {
|
540
|
+
read_ResidualQuantizer(&idxr->rq, f);
|
541
|
+
}
|
542
|
+
READ1(idxr->code_size);
|
543
|
+
READVECTOR(idxr->codes);
|
544
|
+
idx = idxr;
|
545
|
+
} else if (h == fourcc("IxLS")) {
|
546
|
+
auto idxr = new IndexLocalSearchQuantizer();
|
547
|
+
read_index_header(idxr, f);
|
548
|
+
read_LocalSearchQuantizer(&idxr->lsq, f);
|
485
549
|
READ1(idxr->code_size);
|
486
550
|
READVECTOR(idxr->codes);
|
487
551
|
idx = idxr;
|
@@ -571,6 +635,25 @@ Index* read_index(IOReader* f, int io_flags) {
|
|
571
635
|
}
|
572
636
|
read_InvertedLists(ivsc, f, io_flags);
|
573
637
|
idx = ivsc;
|
638
|
+
} else if (h == fourcc("IwLS") || h == fourcc("IwRQ")) {
|
639
|
+
bool is_LSQ = h == fourcc("IwLS");
|
640
|
+
IndexIVFAdditiveQuantizer* iva;
|
641
|
+
if (is_LSQ) {
|
642
|
+
iva = new IndexIVFLocalSearchQuantizer();
|
643
|
+
} else {
|
644
|
+
iva = new IndexIVFResidualQuantizer();
|
645
|
+
}
|
646
|
+
read_ivf_header(iva, f);
|
647
|
+
READ1(iva->code_size);
|
648
|
+
if (is_LSQ) {
|
649
|
+
read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
|
650
|
+
} else {
|
651
|
+
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
|
652
|
+
}
|
653
|
+
READ1(iva->by_residual);
|
654
|
+
READ1(iva->use_precomputed_table);
|
655
|
+
read_InvertedLists(iva, f, io_flags);
|
656
|
+
idx = iva;
|
574
657
|
} else if (h == fourcc("IwSh")) {
|
575
658
|
IndexIVFSpectralHash* ivsp = new IndexIVFSpectralHash();
|
576
659
|
read_ivf_header(ivsp, f);
|
@@ -26,9 +26,11 @@
|
|
26
26
|
#include <faiss/utils/hamming.h>
|
27
27
|
|
28
28
|
#include <faiss/Index2Layer.h>
|
29
|
+
#include <faiss/IndexAdditiveQuantizer.h>
|
29
30
|
#include <faiss/IndexFlat.h>
|
30
31
|
#include <faiss/IndexHNSW.h>
|
31
32
|
#include <faiss/IndexIVF.h>
|
33
|
+
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
32
34
|
#include <faiss/IndexIVFFlat.h>
|
33
35
|
#include <faiss/IndexIVFPQ.h>
|
34
36
|
#include <faiss/IndexIVFPQFastScan.h>
|
@@ -41,7 +43,6 @@
|
|
41
43
|
#include <faiss/IndexPQFastScan.h>
|
42
44
|
#include <faiss/IndexPreTransform.h>
|
43
45
|
#include <faiss/IndexRefine.h>
|
44
|
-
#include <faiss/IndexResidual.h>
|
45
46
|
#include <faiss/IndexScalarQuantizer.h>
|
46
47
|
#include <faiss/MetaIndexes.h>
|
47
48
|
#include <faiss/VectorTransform.h>
|
@@ -95,9 +96,10 @@ void write_VectorTransform(const VectorTransform* vt, IOWriter* f) {
|
|
95
96
|
uint32_t h = fourcc("rrot");
|
96
97
|
WRITE1(h);
|
97
98
|
} else if (const PCAMatrix* pca = dynamic_cast<const PCAMatrix*>(lt)) {
|
98
|
-
uint32_t h = fourcc("
|
99
|
+
uint32_t h = fourcc("Pcam");
|
99
100
|
WRITE1(h);
|
100
101
|
WRITE1(pca->eigen_power);
|
102
|
+
WRITE1(pca->epsilon);
|
101
103
|
WRITE1(pca->random_rotation);
|
102
104
|
WRITE1(pca->balanced_bins);
|
103
105
|
WRITEVECTOR(pca->mean);
|
@@ -158,14 +160,42 @@ void write_ProductQuantizer(const ProductQuantizer* pq, IOWriter* f) {
|
|
158
160
|
WRITEVECTOR(pq->centroids);
|
159
161
|
}
|
160
162
|
|
161
|
-
void
|
162
|
-
WRITE1(
|
163
|
-
WRITE1(
|
164
|
-
WRITEVECTOR(
|
165
|
-
WRITE1(
|
163
|
+
static void write_AdditiveQuantizer(const AdditiveQuantizer* aq, IOWriter* f) {
|
164
|
+
WRITE1(aq->d);
|
165
|
+
WRITE1(aq->M);
|
166
|
+
WRITEVECTOR(aq->nbits);
|
167
|
+
WRITE1(aq->is_trained);
|
168
|
+
WRITEVECTOR(aq->codebooks);
|
169
|
+
WRITE1(aq->search_type);
|
170
|
+
WRITE1(aq->norm_min);
|
171
|
+
WRITE1(aq->norm_max);
|
172
|
+
if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8 ||
|
173
|
+
aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
|
174
|
+
WRITEXBVECTOR(aq->qnorm.codes);
|
175
|
+
}
|
176
|
+
}
|
177
|
+
|
178
|
+
static void write_ResidualQuantizer(const ResidualQuantizer* rq, IOWriter* f) {
|
179
|
+
write_AdditiveQuantizer(rq, f);
|
166
180
|
WRITE1(rq->train_type);
|
167
181
|
WRITE1(rq->max_beam_size);
|
168
|
-
|
182
|
+
}
|
183
|
+
|
184
|
+
static void write_LocalSearchQuantizer(
|
185
|
+
const LocalSearchQuantizer* lsq,
|
186
|
+
IOWriter* f) {
|
187
|
+
write_AdditiveQuantizer(lsq, f);
|
188
|
+
WRITE1(lsq->K);
|
189
|
+
WRITE1(lsq->train_iters);
|
190
|
+
WRITE1(lsq->encode_ils_iters);
|
191
|
+
WRITE1(lsq->train_ils_iters);
|
192
|
+
WRITE1(lsq->icm_iters);
|
193
|
+
WRITE1(lsq->p);
|
194
|
+
WRITE1(lsq->lambd);
|
195
|
+
WRITE1(lsq->chunk_size);
|
196
|
+
WRITE1(lsq->random_seed);
|
197
|
+
WRITE1(lsq->nperts);
|
198
|
+
WRITE1(lsq->update_codebooks_with_double);
|
169
199
|
}
|
170
200
|
|
171
201
|
static void write_ScalarQuantizer(const ScalarQuantizer* ivsc, IOWriter* f) {
|
@@ -315,7 +345,7 @@ void write_index(const Index* idx, IOWriter* f) {
|
|
315
345
|
: "IxFl");
|
316
346
|
WRITE1(h);
|
317
347
|
write_index_header(idx, f);
|
318
|
-
|
348
|
+
WRITEXBVECTOR(idxf->codes);
|
319
349
|
} else if (const IndexLSH* idxl = dynamic_cast<const IndexLSH*>(idx)) {
|
320
350
|
uint32_t h = fourcc("IxHe");
|
321
351
|
WRITE1(h);
|
@@ -324,7 +354,8 @@ void write_index(const Index* idx, IOWriter* f) {
|
|
324
354
|
WRITE1(idxl->rotate_data);
|
325
355
|
WRITE1(idxl->train_thresholds);
|
326
356
|
WRITEVECTOR(idxl->thresholds);
|
327
|
-
|
357
|
+
int code_size_i = idxl->code_size;
|
358
|
+
WRITE1(code_size_i);
|
328
359
|
write_VectorTransform(&idxl->rrot, f);
|
329
360
|
WRITEVECTOR(idxl->codes);
|
330
361
|
} else if (const IndexPQ* idxp = dynamic_cast<const IndexPQ*>(idx)) {
|
@@ -338,15 +369,20 @@ void write_index(const Index* idx, IOWriter* f) {
|
|
338
369
|
WRITE1(idxp->encode_signs);
|
339
370
|
WRITE1(idxp->polysemous_ht);
|
340
371
|
} else if (
|
341
|
-
const
|
342
|
-
dynamic_cast<const
|
343
|
-
uint32_t h = fourcc("
|
372
|
+
const IndexResidualQuantizer* idxr =
|
373
|
+
dynamic_cast<const IndexResidualQuantizer*>(idx)) {
|
374
|
+
uint32_t h = fourcc("IxRq");
|
344
375
|
WRITE1(h);
|
345
376
|
write_index_header(idx, f);
|
346
377
|
write_ResidualQuantizer(&idxr->rq, f);
|
347
|
-
WRITE1(idxr->
|
348
|
-
|
349
|
-
|
378
|
+
WRITE1(idxr->code_size);
|
379
|
+
WRITEVECTOR(idxr->codes);
|
380
|
+
} else if (
|
381
|
+
auto* idxr = dynamic_cast<const IndexLocalSearchQuantizer*>(idx)) {
|
382
|
+
uint32_t h = fourcc("IxLS");
|
383
|
+
WRITE1(h);
|
384
|
+
write_index_header(idx, f);
|
385
|
+
write_LocalSearchQuantizer(&idxr->lsq, f);
|
350
386
|
WRITE1(idxr->code_size);
|
351
387
|
WRITEVECTOR(idxr->codes);
|
352
388
|
} else if (
|
@@ -421,6 +457,20 @@ void write_index(const Index* idx, IOWriter* f) {
|
|
421
457
|
WRITE1(ivsc->code_size);
|
422
458
|
WRITE1(ivsc->by_residual);
|
423
459
|
write_InvertedLists(ivsc->invlists, f);
|
460
|
+
} else if (auto iva = dynamic_cast<const IndexIVFAdditiveQuantizer*>(idx)) {
|
461
|
+
bool is_LSQ = dynamic_cast<const IndexIVFLocalSearchQuantizer*>(iva);
|
462
|
+
uint32_t h = fourcc(is_LSQ ? "IwLS" : "IwRQ");
|
463
|
+
WRITE1(h);
|
464
|
+
write_ivf_header(iva, f);
|
465
|
+
WRITE1(iva->code_size);
|
466
|
+
if (is_LSQ) {
|
467
|
+
write_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
|
468
|
+
} else {
|
469
|
+
write_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
|
470
|
+
}
|
471
|
+
WRITE1(iva->by_residual);
|
472
|
+
WRITE1(iva->use_precomputed_table);
|
473
|
+
write_InvertedLists(iva->invlists, f);
|
424
474
|
} else if (
|
425
475
|
const IndexIVFSpectralHash* ivsp =
|
426
476
|
dynamic_cast<const IndexIVFSpectralHash*>(idx)) {
|
@@ -66,3 +66,23 @@
|
|
66
66
|
WRITEANDCHECK(&size, 1); \
|
67
67
|
WRITEANDCHECK((vec).data(), size); \
|
68
68
|
}
|
69
|
+
|
70
|
+
// read/write xb vector for backwards compatibility of IndexFlat
|
71
|
+
|
72
|
+
#define WRITEXBVECTOR(vec) \
|
73
|
+
{ \
|
74
|
+
FAISS_THROW_IF_NOT((vec).size() % 4 == 0); \
|
75
|
+
size_t size = (vec).size() / 4; \
|
76
|
+
WRITEANDCHECK(&size, 1); \
|
77
|
+
WRITEANDCHECK((vec).data(), size * 4); \
|
78
|
+
}
|
79
|
+
|
80
|
+
#define READXBVECTOR(vec) \
|
81
|
+
{ \
|
82
|
+
size_t size; \
|
83
|
+
READANDCHECK(&size, 1); \
|
84
|
+
FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \
|
85
|
+
size *= 4; \
|
86
|
+
(vec).resize(size); \
|
87
|
+
READANDCHECK((vec).data(), size); \
|
88
|
+
}
|
@@ -0,0 +1,301 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <algorithm>
|
9
|
+
#include <cstdint>
|
10
|
+
#include <cstring>
|
11
|
+
#include <functional>
|
12
|
+
#include <numeric>
|
13
|
+
#include <string>
|
14
|
+
#include <unordered_map>
|
15
|
+
#include <vector>
|
16
|
+
|
17
|
+
#include <faiss/Index.h>
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
19
|
+
#include <faiss/impl/kmeans1d.h>
|
20
|
+
|
21
|
+
namespace faiss {
|
22
|
+
|
23
|
+
using idx_t = Index::idx_t;
|
24
|
+
using LookUpFunc = std::function<float(idx_t, idx_t)>;
|
25
|
+
|
26
|
+
void reduce(
|
27
|
+
const std::vector<idx_t>& rows,
|
28
|
+
const std::vector<idx_t>& input_cols,
|
29
|
+
const LookUpFunc& lookup,
|
30
|
+
std::vector<idx_t>& output_cols) {
|
31
|
+
for (idx_t col : input_cols) {
|
32
|
+
while (!output_cols.empty()) {
|
33
|
+
idx_t row = rows[output_cols.size() - 1];
|
34
|
+
float a = lookup(row, col);
|
35
|
+
float b = lookup(row, output_cols.back());
|
36
|
+
if (a >= b) { // defeated
|
37
|
+
break;
|
38
|
+
}
|
39
|
+
output_cols.pop_back();
|
40
|
+
}
|
41
|
+
if (output_cols.size() < rows.size()) {
|
42
|
+
output_cols.push_back(col);
|
43
|
+
}
|
44
|
+
}
|
45
|
+
}
|
46
|
+
|
47
|
+
void interpolate(
|
48
|
+
const std::vector<idx_t>& rows,
|
49
|
+
const std::vector<idx_t>& cols,
|
50
|
+
const LookUpFunc& lookup,
|
51
|
+
idx_t* argmins) {
|
52
|
+
std::unordered_map<idx_t, idx_t> idx_to_col;
|
53
|
+
for (idx_t idx = 0; idx < cols.size(); ++idx) {
|
54
|
+
idx_to_col[cols[idx]] = idx;
|
55
|
+
}
|
56
|
+
|
57
|
+
idx_t start = 0;
|
58
|
+
for (idx_t r = 0; r < rows.size(); r += 2) {
|
59
|
+
idx_t row = rows[r];
|
60
|
+
idx_t end = cols.size() - 1;
|
61
|
+
if (r < rows.size() - 1) {
|
62
|
+
idx_t idx = argmins[rows[r + 1]];
|
63
|
+
end = idx_to_col[idx];
|
64
|
+
}
|
65
|
+
idx_t argmin = cols[start];
|
66
|
+
float min = lookup(row, argmin);
|
67
|
+
for (idx_t c = start + 1; c <= end; c++) {
|
68
|
+
float value = lookup(row, cols[c]);
|
69
|
+
if (value < min) {
|
70
|
+
argmin = cols[c];
|
71
|
+
min = value;
|
72
|
+
}
|
73
|
+
}
|
74
|
+
argmins[row] = argmin;
|
75
|
+
start = end;
|
76
|
+
}
|
77
|
+
}
|
78
|
+
|
79
|
+
/** SMAWK algorithm. Find the row minima of a monotone matrix.
|
80
|
+
*
|
81
|
+
* References:
|
82
|
+
* 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf
|
83
|
+
* 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405
|
84
|
+
* 3. https://github.com/dstein64/kmeans1d
|
85
|
+
*/
|
86
|
+
void smawk_impl(
|
87
|
+
const std::vector<idx_t>& rows,
|
88
|
+
const std::vector<idx_t>& input_cols,
|
89
|
+
const LookUpFunc& lookup,
|
90
|
+
idx_t* argmins) {
|
91
|
+
if (rows.size() == 0) {
|
92
|
+
return;
|
93
|
+
}
|
94
|
+
|
95
|
+
/**********************************
|
96
|
+
* REDUCE
|
97
|
+
**********************************/
|
98
|
+
auto ptr = &input_cols;
|
99
|
+
std::vector<idx_t> survived_cols; // survived columns
|
100
|
+
if (rows.size() < input_cols.size()) {
|
101
|
+
reduce(rows, input_cols, lookup, survived_cols);
|
102
|
+
ptr = &survived_cols;
|
103
|
+
}
|
104
|
+
auto& cols = *ptr; // avoid memory copy
|
105
|
+
|
106
|
+
/**********************************
|
107
|
+
* INTERPOLATE
|
108
|
+
**********************************/
|
109
|
+
|
110
|
+
// call recursively on odd-indexed rows
|
111
|
+
std::vector<idx_t> odd_rows;
|
112
|
+
for (idx_t i = 1; i < rows.size(); i += 2) {
|
113
|
+
odd_rows.push_back(rows[i]);
|
114
|
+
}
|
115
|
+
smawk_impl(odd_rows, cols, lookup, argmins);
|
116
|
+
|
117
|
+
// interpolate the even-indexed rows
|
118
|
+
interpolate(rows, cols, lookup, argmins);
|
119
|
+
}
|
120
|
+
|
121
|
+
void smawk(
|
122
|
+
const idx_t nrows,
|
123
|
+
const idx_t ncols,
|
124
|
+
const LookUpFunc& lookup,
|
125
|
+
idx_t* argmins) {
|
126
|
+
std::vector<idx_t> rows(nrows);
|
127
|
+
std::vector<idx_t> cols(ncols);
|
128
|
+
std::iota(std::begin(rows), std::end(rows), 0);
|
129
|
+
std::iota(std::begin(cols), std::end(cols), 0);
|
130
|
+
|
131
|
+
smawk_impl(rows, cols, lookup, argmins);
|
132
|
+
}
|
133
|
+
|
134
|
+
void smawk(
|
135
|
+
const idx_t nrows,
|
136
|
+
const idx_t ncols,
|
137
|
+
const float* x,
|
138
|
+
idx_t* argmins) {
|
139
|
+
auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; };
|
140
|
+
smawk(nrows, ncols, lookup, argmins);
|
141
|
+
}
|
142
|
+
|
143
|
+
namespace {
|
144
|
+
|
145
|
+
class CostCalculator {
|
146
|
+
// The result would be inaccurate if we used float
|
147
|
+
std::vector<double> cumsum;
|
148
|
+
std::vector<double> cumsum2;
|
149
|
+
|
150
|
+
public:
|
151
|
+
CostCalculator(const std::vector<float>& vec, idx_t n) {
|
152
|
+
cumsum.push_back(0.0);
|
153
|
+
cumsum2.push_back(0.0);
|
154
|
+
for (idx_t i = 0; i < n; ++i) {
|
155
|
+
float x = vec[i];
|
156
|
+
cumsum.push_back(x + cumsum[i]);
|
157
|
+
cumsum2.push_back(x * x + cumsum2[i]);
|
158
|
+
}
|
159
|
+
}
|
160
|
+
|
161
|
+
float operator()(idx_t i, idx_t j) {
|
162
|
+
if (j < i) {
|
163
|
+
return 0.0f;
|
164
|
+
}
|
165
|
+
auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
|
166
|
+
auto result = cumsum2[j + 1] - cumsum2[i];
|
167
|
+
result += (j - i + 1) * (mu * mu);
|
168
|
+
result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
|
169
|
+
return float(result);
|
170
|
+
}
|
171
|
+
};
|
172
|
+
|
173
|
+
template <class T>
|
174
|
+
class Matrix {
|
175
|
+
std::vector<T> data;
|
176
|
+
idx_t nrows;
|
177
|
+
idx_t ncols;
|
178
|
+
|
179
|
+
public:
|
180
|
+
Matrix(idx_t nrows, idx_t ncols) {
|
181
|
+
this->nrows = nrows;
|
182
|
+
this->ncols = ncols;
|
183
|
+
data.resize(nrows * ncols);
|
184
|
+
}
|
185
|
+
|
186
|
+
inline T& at(idx_t i, idx_t j) {
|
187
|
+
return data[i * ncols + j];
|
188
|
+
}
|
189
|
+
};
|
190
|
+
|
191
|
+
} // anonymous namespace
|
192
|
+
|
193
|
+
double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
|
194
|
+
FAISS_THROW_IF_NOT(n >= nclusters);
|
195
|
+
|
196
|
+
// corner case
|
197
|
+
if (n == nclusters) {
|
198
|
+
memcpy(centroids, x, n * sizeof(*x));
|
199
|
+
return 0.0f;
|
200
|
+
}
|
201
|
+
|
202
|
+
/***************************************************
|
203
|
+
* sort in ascending order, O(NlogN) in time
|
204
|
+
***************************************************/
|
205
|
+
std::vector<float> arr(x, x + n);
|
206
|
+
std::sort(arr.begin(), arr.end());
|
207
|
+
|
208
|
+
/***************************************************
|
209
|
+
dynamic programming algorithm
|
210
|
+
|
211
|
+
Reference: https://arxiv.org/abs/1701.07204
|
212
|
+
-------------------------------
|
213
|
+
|
214
|
+
Assume x is already sorted in ascending order.
|
215
|
+
|
216
|
+
N: number of points
|
217
|
+
K: number of clusters
|
218
|
+
|
219
|
+
CC(i, j): the cost of grouping xi,...,xj into one cluster
|
220
|
+
D[k][m]: the cost of optimally clustering x1,...,xm into k clusters
|
221
|
+
T[k][m]: the start index of the k-th cluster
|
222
|
+
|
223
|
+
The DP process is as follows:
|
224
|
+
D[k][m] = min_i D[k − 1][i − 1] + CC(i, m)
|
225
|
+
T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m)
|
226
|
+
|
227
|
+
This could be solved in O(KN^2) time and O(KN) space.
|
228
|
+
|
229
|
+
To further reduce the time complexity, we use the SMAWK algorithm to
|
230
|
+
solve the argmin problem as follows:
|
231
|
+
|
232
|
+
For each k:
|
233
|
+
C[m][i] = D[k − 1][i − 1] + CC(i, m)
|
234
|
+
|
235
|
+
Here C is an n x n totally monotone matrix.
|
236
|
+
We could find the row minima by SMAWK in O(N) time.
|
237
|
+
|
238
|
+
Now the time complexity is reduced from O(KN^2) to O(KN).
|
239
|
+
****************************************************/
|
240
|
+
|
241
|
+
CostCalculator CC(arr, n);
|
242
|
+
Matrix<float> D(nclusters, n);
|
243
|
+
Matrix<idx_t> T(nclusters, n);
|
244
|
+
|
245
|
+
for (idx_t m = 0; m < n; m++) {
|
246
|
+
D.at(0, m) = CC(0, m);
|
247
|
+
T.at(0, m) = 0;
|
248
|
+
}
|
249
|
+
|
250
|
+
std::vector<idx_t> indices(nclusters, 0);
|
251
|
+
|
252
|
+
for (idx_t k = 1; k < nclusters; ++k) {
|
253
|
+
// we define C here
|
254
|
+
auto C = [&D, &CC, &k](idx_t m, idx_t i) {
|
255
|
+
if (i == 0) {
|
256
|
+
return CC(i, m);
|
257
|
+
}
|
258
|
+
idx_t col = std::min(m, i - 1);
|
259
|
+
return D.at(k - 1, col) + CC(i, m);
|
260
|
+
};
|
261
|
+
|
262
|
+
std::vector<idx_t> argmins(n); // argmin of each row
|
263
|
+
smawk(n, n, C, argmins.data());
|
264
|
+
for (idx_t m = 0; m < argmins.size(); m++) {
|
265
|
+
idx_t idx = argmins[m];
|
266
|
+
D.at(k, m) = C(m, idx);
|
267
|
+
T.at(k, m) = idx;
|
268
|
+
}
|
269
|
+
}
|
270
|
+
|
271
|
+
/***************************************************
|
272
|
+
compute centroids by backtracking
|
273
|
+
|
274
|
+
T[K - 1][T[K][N] - 1] T[K][N] N
|
275
|
+
--------------|------------------------|-----------|
|
276
|
+
| cluster K - 1 | cluster K |
|
277
|
+
|
278
|
+
****************************************************/
|
279
|
+
|
280
|
+
// for imbalance factor
|
281
|
+
double tot = 0.0, uf = 0.0;
|
282
|
+
|
283
|
+
idx_t end = n;
|
284
|
+
for (idx_t k = nclusters - 1; k >= 0; k--) {
|
285
|
+
idx_t start = T.at(k, end - 1);
|
286
|
+
float sum = std::accumulate(&arr[start], &arr[end], 0.0f);
|
287
|
+
idx_t size = end - start;
|
288
|
+
FAISS_THROW_IF_NOT_FMT(
|
289
|
+
size > 0, "Cluster %d: size %d", int(k), int(size));
|
290
|
+
centroids[k] = sum / size;
|
291
|
+
end = start;
|
292
|
+
|
293
|
+
tot += size;
|
294
|
+
uf += size * double(size);
|
295
|
+
}
|
296
|
+
|
297
|
+
uf = uf * nclusters / (tot * tot);
|
298
|
+
return uf;
|
299
|
+
}
|
300
|
+
|
301
|
+
} // namespace faiss
|