faiss 0.3.1 → 0.3.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +35 -4
- data/vendor/faiss/faiss/Clustering.h +10 -1
- data/vendor/faiss/faiss/IVFlib.cpp +4 -1
- data/vendor/faiss/faiss/Index.h +21 -6
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
- data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
- data/vendor/faiss/faiss/IndexHNSW.h +52 -3
- data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVF.h +9 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
- data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.h +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
- data/vendor/faiss/faiss/impl/HNSW.h +43 -22
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
- data/vendor/faiss/faiss/impl/io.cpp +13 -5
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
- data/vendor/faiss/faiss/index_factory.cpp +31 -13
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +58 -88
- data/vendor/faiss/faiss/utils/distances.h +5 -5
- data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +10 -3
- data/vendor/faiss/faiss/utils/utils.h +3 -0
- metadata +16 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -15,6 +15,7 @@
|
|
15
15
|
#include <faiss/utils/simdlib.h>
|
16
16
|
|
17
17
|
#include <faiss/impl/FaissAssert.h>
|
18
|
+
#include <faiss/impl/IDSelector.h>
|
18
19
|
#include <faiss/impl/ResultHandler.h>
|
19
20
|
#include <faiss/impl/platform_macros.h>
|
20
21
|
#include <faiss/utils/AlignedTable.h>
|
@@ -137,6 +138,7 @@ struct FixedStorageHandler : SIMDResultHandler {
|
|
137
138
|
}
|
138
139
|
}
|
139
140
|
}
|
141
|
+
|
140
142
|
virtual ~FixedStorageHandler() {}
|
141
143
|
};
|
142
144
|
|
@@ -150,8 +152,10 @@ struct ResultHandlerCompare : SIMDResultHandlerToFloat {
|
|
150
152
|
int64_t i0 = 0; // query origin
|
151
153
|
int64_t j0 = 0; // db origin
|
152
154
|
|
153
|
-
|
154
|
-
|
155
|
+
const IDSelector* sel;
|
156
|
+
|
157
|
+
ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in)
|
158
|
+
: SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} {
|
155
159
|
this->is_CMax = C::is_max;
|
156
160
|
this->sizeof_ids = sizeof(typename C::TI);
|
157
161
|
this->with_fields = with_id_map;
|
@@ -232,9 +236,14 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
|
|
232
236
|
float* dis;
|
233
237
|
int64_t* ids;
|
234
238
|
|
235
|
-
SingleResultHandler(
|
236
|
-
|
237
|
-
|
239
|
+
SingleResultHandler(
|
240
|
+
size_t nq,
|
241
|
+
size_t ntotal,
|
242
|
+
float* dis,
|
243
|
+
int64_t* ids,
|
244
|
+
const IDSelector* sel_in)
|
245
|
+
: RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) {
|
246
|
+
for (size_t i = 0; i < nq; i++) {
|
238
247
|
ids[i] = -1;
|
239
248
|
idis[i] = C::neutral();
|
240
249
|
}
|
@@ -256,20 +265,36 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
|
|
256
265
|
d0.store(d32tab);
|
257
266
|
d1.store(d32tab + 16);
|
258
267
|
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
268
|
+
if (this->sel != nullptr) {
|
269
|
+
while (lt_mask) {
|
270
|
+
// find first non-zero
|
271
|
+
int j = __builtin_ctz(lt_mask);
|
272
|
+
auto real_idx = this->adjust_id(b, j);
|
273
|
+
lt_mask -= 1 << j;
|
274
|
+
if (this->sel->is_member(real_idx)) {
|
275
|
+
T d = d32tab[j];
|
276
|
+
if (C::cmp(idis[q], d)) {
|
277
|
+
idis[q] = d;
|
278
|
+
ids[q] = real_idx;
|
279
|
+
}
|
280
|
+
}
|
281
|
+
}
|
282
|
+
} else {
|
283
|
+
while (lt_mask) {
|
284
|
+
// find first non-zero
|
285
|
+
int j = __builtin_ctz(lt_mask);
|
286
|
+
lt_mask -= 1 << j;
|
287
|
+
T d = d32tab[j];
|
288
|
+
if (C::cmp(idis[q], d)) {
|
289
|
+
idis[q] = d;
|
290
|
+
ids[q] = this->adjust_id(b, j);
|
291
|
+
}
|
267
292
|
}
|
268
293
|
}
|
269
294
|
}
|
270
295
|
|
271
296
|
void end() {
|
272
|
-
for (
|
297
|
+
for (size_t q = 0; q < this->nq; q++) {
|
273
298
|
if (!normalizers) {
|
274
299
|
dis[q] = idis[q];
|
275
300
|
} else {
|
@@ -296,8 +321,14 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
|
296
321
|
|
297
322
|
int64_t k; // number of results to keep
|
298
323
|
|
299
|
-
HeapHandler(
|
300
|
-
|
324
|
+
HeapHandler(
|
325
|
+
size_t nq,
|
326
|
+
size_t ntotal,
|
327
|
+
int64_t k,
|
328
|
+
float* dis,
|
329
|
+
int64_t* ids,
|
330
|
+
const IDSelector* sel_in)
|
331
|
+
: RHC(nq, ntotal, sel_in),
|
301
332
|
idis(nq * k),
|
302
333
|
iids(nq * k),
|
303
334
|
dis(dis),
|
@@ -330,21 +361,36 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
|
330
361
|
d0.store(d32tab);
|
331
362
|
d1.store(d32tab + 16);
|
332
363
|
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
364
|
+
if (this->sel != nullptr) {
|
365
|
+
while (lt_mask) {
|
366
|
+
// find first non-zero
|
367
|
+
int j = __builtin_ctz(lt_mask);
|
368
|
+
auto real_idx = this->adjust_id(b, j);
|
369
|
+
lt_mask -= 1 << j;
|
370
|
+
if (this->sel->is_member(real_idx)) {
|
371
|
+
T dis = d32tab[j];
|
372
|
+
if (C::cmp(heap_dis[0], dis)) {
|
373
|
+
heap_replace_top<C>(
|
374
|
+
k, heap_dis, heap_ids, dis, real_idx);
|
375
|
+
}
|
376
|
+
}
|
377
|
+
}
|
378
|
+
} else {
|
379
|
+
while (lt_mask) {
|
380
|
+
// find first non-zero
|
381
|
+
int j = __builtin_ctz(lt_mask);
|
382
|
+
lt_mask -= 1 << j;
|
383
|
+
T dis = d32tab[j];
|
384
|
+
if (C::cmp(heap_dis[0], dis)) {
|
385
|
+
int64_t idx = this->adjust_id(b, j);
|
386
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
|
387
|
+
}
|
342
388
|
}
|
343
389
|
}
|
344
390
|
}
|
345
391
|
|
346
392
|
void end() override {
|
347
|
-
for (
|
393
|
+
for (size_t q = 0; q < this->nq; q++) {
|
348
394
|
T* heap_dis_in = idis.data() + q * k;
|
349
395
|
TI* heap_ids_in = iids.data() + q * k;
|
350
396
|
heap_reorder<C>(k, heap_dis_in, heap_ids_in);
|
@@ -393,8 +439,12 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
393
439
|
size_t k,
|
394
440
|
size_t cap,
|
395
441
|
float* dis,
|
396
|
-
int64_t* ids
|
397
|
-
|
442
|
+
int64_t* ids,
|
443
|
+
const IDSelector* sel_in)
|
444
|
+
: RHC(nq, ntotal, sel_in),
|
445
|
+
capacity((cap + 15) & ~15),
|
446
|
+
dis(dis),
|
447
|
+
ids(ids) {
|
398
448
|
assert(capacity % 16 == 0);
|
399
449
|
all_ids.resize(nq * capacity);
|
400
450
|
all_vals.resize(nq * capacity);
|
@@ -423,12 +473,25 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
423
473
|
d0.store(d32tab);
|
424
474
|
d1.store(d32tab + 16);
|
425
475
|
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
476
|
+
if (this->sel != nullptr) {
|
477
|
+
while (lt_mask) {
|
478
|
+
// find first non-zero
|
479
|
+
int j = __builtin_ctz(lt_mask);
|
480
|
+
auto real_idx = this->adjust_id(b, j);
|
481
|
+
lt_mask -= 1 << j;
|
482
|
+
if (this->sel->is_member(real_idx)) {
|
483
|
+
T dis = d32tab[j];
|
484
|
+
res.add(dis, real_idx);
|
485
|
+
}
|
486
|
+
}
|
487
|
+
} else {
|
488
|
+
while (lt_mask) {
|
489
|
+
// find first non-zero
|
490
|
+
int j = __builtin_ctz(lt_mask);
|
491
|
+
lt_mask -= 1 << j;
|
492
|
+
T dis = d32tab[j];
|
493
|
+
res.add(dis, this->adjust_id(b, j));
|
494
|
+
}
|
432
495
|
}
|
433
496
|
}
|
434
497
|
|
@@ -439,7 +502,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
439
502
|
CMin<float, int64_t>>::type;
|
440
503
|
|
441
504
|
std::vector<int> perm(reservoirs[0].n);
|
442
|
-
for (
|
505
|
+
for (size_t q = 0; q < reservoirs.size(); q++) {
|
443
506
|
ReservoirTopN<C>& res = reservoirs[q];
|
444
507
|
size_t n = res.n;
|
445
508
|
|
@@ -454,14 +517,14 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
454
517
|
one_a = 1 / normalizers[2 * q];
|
455
518
|
b = normalizers[2 * q + 1];
|
456
519
|
}
|
457
|
-
for (
|
520
|
+
for (size_t i = 0; i < res.i; i++) {
|
458
521
|
perm[i] = i;
|
459
522
|
}
|
460
523
|
// indirect sort of result arrays
|
461
524
|
std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
|
462
525
|
return C::cmp(res.vals[j], res.vals[i]);
|
463
526
|
});
|
464
|
-
for (
|
527
|
+
for (size_t i = 0; i < res.i; i++) {
|
465
528
|
heap_dis[i] = res.vals[perm[i]] * one_a + b;
|
466
529
|
heap_ids[i] = res.ids[perm[i]];
|
467
530
|
}
|
@@ -472,7 +535,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
472
535
|
}
|
473
536
|
};
|
474
537
|
|
475
|
-
/** Result
|
538
|
+
/** Result handler for range search. The difficulty is that the range distances
|
476
539
|
* have to be scaled using the scaler.
|
477
540
|
*/
|
478
541
|
|
@@ -499,13 +562,17 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
|
|
499
562
|
};
|
500
563
|
std::vector<Triplet> triplets;
|
501
564
|
|
502
|
-
RangeHandler(
|
503
|
-
|
565
|
+
RangeHandler(
|
566
|
+
RangeSearchResult& rres,
|
567
|
+
float radius,
|
568
|
+
size_t ntotal,
|
569
|
+
const IDSelector* sel_in)
|
570
|
+
: RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) {
|
504
571
|
thresholds.resize(nq);
|
505
572
|
n_per_query.resize(nq + 1);
|
506
573
|
}
|
507
574
|
|
508
|
-
virtual void begin(const float* norms) {
|
575
|
+
virtual void begin(const float* norms) override {
|
509
576
|
normalizers = norms;
|
510
577
|
for (int q = 0; q < nq; ++q) {
|
511
578
|
thresholds[q] =
|
@@ -528,13 +595,28 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
|
|
528
595
|
d0.store(d32tab);
|
529
596
|
d1.store(d32tab + 16);
|
530
597
|
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
598
|
+
if (this->sel != nullptr) {
|
599
|
+
while (lt_mask) {
|
600
|
+
// find first non-zero
|
601
|
+
int j = __builtin_ctz(lt_mask);
|
602
|
+
lt_mask -= 1 << j;
|
603
|
+
|
604
|
+
auto real_idx = this->adjust_id(b, j);
|
605
|
+
if (this->sel->is_member(real_idx)) {
|
606
|
+
T dis = d32tab[j];
|
607
|
+
n_per_query[q]++;
|
608
|
+
triplets.push_back({idx_t(q + q0), real_idx, dis});
|
609
|
+
}
|
610
|
+
}
|
611
|
+
} else {
|
612
|
+
while (lt_mask) {
|
613
|
+
// find first non-zero
|
614
|
+
int j = __builtin_ctz(lt_mask);
|
615
|
+
lt_mask -= 1 << j;
|
616
|
+
T dis = d32tab[j];
|
617
|
+
n_per_query[q]++;
|
618
|
+
triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
|
619
|
+
}
|
538
620
|
}
|
539
621
|
}
|
540
622
|
|
@@ -578,8 +660,9 @@ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
|
|
578
660
|
float radius,
|
579
661
|
size_t ntotal,
|
580
662
|
size_t q0,
|
581
|
-
size_t q1
|
582
|
-
|
663
|
+
size_t q1,
|
664
|
+
const IDSelector* sel_in)
|
665
|
+
: RangeHandler<C, with_id_map>(*pres.res, radius, ntotal, sel_in),
|
583
666
|
pres(pres) {
|
584
667
|
nq = q1 - q0;
|
585
668
|
this->q0 = q0;
|
@@ -630,7 +713,7 @@ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
|
|
630
713
|
*/
|
631
714
|
|
632
715
|
template <class C, bool W, class Consumer, class... Types>
|
633
|
-
void
|
716
|
+
void dispatch_SIMDResultHandler_fixedCW(
|
634
717
|
SIMDResultHandler& res,
|
635
718
|
Consumer& consumer,
|
636
719
|
Types... args) {
|
@@ -650,19 +733,19 @@ void dispatch_SIMDResultHanlder_fixedCW(
|
|
650
733
|
}
|
651
734
|
|
652
735
|
template <class C, class Consumer, class... Types>
|
653
|
-
void
|
736
|
+
void dispatch_SIMDResultHandler_fixedC(
|
654
737
|
SIMDResultHandler& res,
|
655
738
|
Consumer& consumer,
|
656
739
|
Types... args) {
|
657
740
|
if (res.with_fields) {
|
658
|
-
|
741
|
+
dispatch_SIMDResultHandler_fixedCW<C, true>(res, consumer, args...);
|
659
742
|
} else {
|
660
|
-
|
743
|
+
dispatch_SIMDResultHandler_fixedCW<C, false>(res, consumer, args...);
|
661
744
|
}
|
662
745
|
}
|
663
746
|
|
664
747
|
template <class Consumer, class... Types>
|
665
|
-
void
|
748
|
+
void dispatch_SIMDResultHandler(
|
666
749
|
SIMDResultHandler& res,
|
667
750
|
Consumer& consumer,
|
668
751
|
Types... args) {
|
@@ -680,24 +763,25 @@ void dispatch_SIMDResultHanlder(
|
|
680
763
|
}
|
681
764
|
} else if (res.sizeof_ids == sizeof(int)) {
|
682
765
|
if (res.is_CMax) {
|
683
|
-
|
766
|
+
dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int>>(
|
684
767
|
res, consumer, args...);
|
685
768
|
} else {
|
686
|
-
|
769
|
+
dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int>>(
|
687
770
|
res, consumer, args...);
|
688
771
|
}
|
689
772
|
} else if (res.sizeof_ids == sizeof(int64_t)) {
|
690
773
|
if (res.is_CMax) {
|
691
|
-
|
774
|
+
dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int64_t>>(
|
692
775
|
res, consumer, args...);
|
693
776
|
} else {
|
694
|
-
|
777
|
+
dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int64_t>>(
|
695
778
|
res, consumer, args...);
|
696
779
|
}
|
697
780
|
} else {
|
698
781
|
FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
|
699
782
|
}
|
700
783
|
}
|
784
|
+
|
701
785
|
} // namespace simd_result_handlers
|
702
786
|
|
703
787
|
} // namespace faiss
|
@@ -140,8 +140,12 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
|
|
140
140
|
{"SQ4", ScalarQuantizer::QT_4bit},
|
141
141
|
{"SQ6", ScalarQuantizer::QT_6bit},
|
142
142
|
{"SQfp16", ScalarQuantizer::QT_fp16},
|
143
|
+
{"SQbf16", ScalarQuantizer::QT_bf16},
|
144
|
+
{"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
|
145
|
+
{"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
|
143
146
|
};
|
144
|
-
const std::string sq_pattern =
|
147
|
+
const std::string sq_pattern =
|
148
|
+
"(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
|
145
149
|
|
146
150
|
std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
|
147
151
|
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
|
@@ -222,6 +226,19 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
|
|
222
226
|
* Parse IndexIVF
|
223
227
|
*/
|
224
228
|
|
229
|
+
/** Parse the list-count token of an IVF factory string.
 *
 * Accepts a decimal number optionally followed by exactly one binary
 * suffix: 'k' (x1024) or 'M' (x1024*1024), e.g. "4096", "4k", "1M".
 *
 * @param s  the token to parse (taken by value because the suffix is
 *           stripped in place)
 * @return   the numeric value multiplied by the suffix factor
 * @throws std::invalid_argument if there are no digits to parse
 *         (including an empty string or a bare suffix)
 * @throws std::out_of_range if the numeric part does not fit in an
 *         unsigned long long
 */
size_t parse_nlist(std::string s) {
    size_t multiplier = 1;
    // Guard the back() calls: calling back() on an empty string is
    // undefined behaviour. An empty string falls through to stoull, which
    // reports it as std::invalid_argument.
    if (!s.empty()) {
        // At most one suffix is stripped; the suffixes are mutually exclusive.
        if (s.back() == 'k') {
            s.pop_back();
            multiplier = 1024;
        } else if (s.back() == 'M') {
            s.pop_back();
            multiplier = 1024 * 1024;
        }
    }
    // Parse as unsigned 64-bit so values above INT_MAX are not rejected
    // before being widened to size_t.
    return static_cast<size_t>(std::stoull(s)) * multiplier;
}
|
241
|
+
|
225
242
|
// parsing guard + function
|
226
243
|
Index* parse_coarse_quantizer(
|
227
244
|
const std::string& description,
|
@@ -236,8 +253,8 @@ Index* parse_coarse_quantizer(
|
|
236
253
|
};
|
237
254
|
use_2layer = false;
|
238
255
|
|
239
|
-
if (match("IVF([0-9]+)")) {
|
240
|
-
nlist =
|
256
|
+
if (match("IVF([0-9]+[kM]?)")) {
|
257
|
+
nlist = parse_nlist(sm[1].str());
|
241
258
|
return new IndexFlat(d, mt);
|
242
259
|
}
|
243
260
|
if (match("IMI2x([0-9]+)")) {
|
@@ -248,18 +265,18 @@ Index* parse_coarse_quantizer(
|
|
248
265
|
nlist = (size_t)1 << (2 * nbit);
|
249
266
|
return new MultiIndexQuantizer(d, 2, nbit);
|
250
267
|
}
|
251
|
-
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
|
252
|
-
nlist =
|
268
|
+
if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
|
269
|
+
nlist = parse_nlist(sm[1].str());
|
253
270
|
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
|
254
271
|
return new IndexHNSWFlat(d, hnsw_M, mt);
|
255
272
|
}
|
256
|
-
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
|
257
|
-
nlist =
|
273
|
+
if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
|
274
|
+
nlist = parse_nlist(sm[1].str());
|
258
275
|
int R = std::stoi(sm[2]);
|
259
276
|
return new IndexNSGFlat(d, R, mt);
|
260
277
|
}
|
261
|
-
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
|
262
|
-
nlist =
|
278
|
+
if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
|
279
|
+
nlist = parse_nlist(sm[1].str());
|
263
280
|
int no = std::stoi(sm[2].str());
|
264
281
|
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
|
265
282
|
return parenthesis_indexes[no].release();
|
@@ -526,11 +543,12 @@ Index* parse_other_indexes(
|
|
526
543
|
}
|
527
544
|
|
528
545
|
// IndexLSH
|
529
|
-
if (match("LSH(r?)(t?)")) {
|
530
|
-
|
531
|
-
bool
|
546
|
+
if (match("LSH([0-9]*)(r?)(t?)")) {
|
547
|
+
int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
|
548
|
+
bool rotate_data = sm[2].length() > 0;
|
549
|
+
bool train_thresholds = sm[3].length() > 0;
|
532
550
|
FAISS_THROW_IF_NOT(metric == METRIC_L2);
|
533
|
-
return new IndexLSH(d,
|
551
|
+
return new IndexLSH(d, nbits, rotate_data, train_thresholds);
|
534
552
|
}
|
535
553
|
|
536
554
|
// IndexLattice
|
@@ -5,8 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
// I/O code for indexes
|
11
9
|
|
12
10
|
#ifndef FAISS_INDEX_IO_H
|
@@ -35,9 +33,12 @@ struct IOReader;
|
|
35
33
|
struct IOWriter;
|
36
34
|
struct InvertedLists;
|
37
35
|
|
38
|
-
|
39
|
-
|
40
|
-
|
36
|
+
/// skip the storage for graph-based indexes
|
37
|
+
const int IO_FLAG_SKIP_STORAGE = 1;
|
38
|
+
|
39
|
+
void write_index(const Index* idx, const char* fname, int io_flags = 0);
|
40
|
+
void write_index(const Index* idx, FILE* f, int io_flags = 0);
|
41
|
+
void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);
|
41
42
|
|
42
43
|
void write_index_binary(const IndexBinary* idx, const char* fname);
|
43
44
|
void write_index_binary(const IndexBinary* idx, FILE* f);
|
@@ -52,6 +53,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4;
|
|
52
53
|
const int IO_FLAG_SKIP_IVF_DATA = 8;
|
53
54
|
// don't initialize precomputed table after loading
|
54
55
|
const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16;
|
56
|
+
// don't compute the sdc table for PQ-based indices
|
57
|
+
// this will prevent distances from being computed
|
58
|
+
// between elements in the index. For indices like HNSWPQ,
|
59
|
+
// this will prevent graph building because sdc
|
60
|
+
// computations are required to construct the graph
|
61
|
+
const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
|
55
62
|
// try to memmap data (useful to load an ArrayInvertedLists as an
|
56
63
|
// OnDiskInvertedLists)
|
57
64
|
const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
|
@@ -9,6 +9,7 @@
|
|
9
9
|
|
10
10
|
#include <faiss/impl/CodePacker.h>
|
11
11
|
#include <faiss/impl/FaissAssert.h>
|
12
|
+
#include <faiss/impl/IDSelector.h>
|
12
13
|
|
13
14
|
#include <faiss/impl/io.h>
|
14
15
|
#include <faiss/impl/io_macros.h>
|
@@ -54,7 +55,9 @@ size_t BlockInvertedLists::add_entries(
|
|
54
55
|
codes[list_no].resize(n_block * block_size);
|
55
56
|
if (o % block_size == 0) {
|
56
57
|
// copy whole blocks
|
57
|
-
memcpy(&codes[list_no][o * code_size],
|
58
|
+
memcpy(&codes[list_no][o * packer->code_size],
|
59
|
+
code,
|
60
|
+
n_block * block_size);
|
58
61
|
} else {
|
59
62
|
FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
|
60
63
|
std::vector<uint8_t> buffer(packer->code_size);
|
@@ -76,6 +79,29 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
|
|
76
79
|
return codes[list_no].get();
|
77
80
|
}
|
78
81
|
|
82
|
+
size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
|
83
|
+
idx_t nremove = 0;
|
84
|
+
#pragma omp parallel for
|
85
|
+
for (idx_t i = 0; i < nlist; i++) {
|
86
|
+
std::vector<uint8_t> buffer(packer->code_size);
|
87
|
+
idx_t l = ids[i].size(), j = 0;
|
88
|
+
while (j < l) {
|
89
|
+
if (sel.is_member(ids[i][j])) {
|
90
|
+
l--;
|
91
|
+
ids[i][j] = ids[i][l];
|
92
|
+
packer->unpack_1(codes[i].data(), l, buffer.data());
|
93
|
+
packer->pack_1(buffer.data(), j, codes[i].data());
|
94
|
+
} else {
|
95
|
+
j++;
|
96
|
+
}
|
97
|
+
}
|
98
|
+
resize(i, l);
|
99
|
+
nremove += ids[i].size() - l;
|
100
|
+
}
|
101
|
+
|
102
|
+
return nremove;
|
103
|
+
}
|
104
|
+
|
79
105
|
const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
|
80
106
|
assert(list_no < nlist);
|
81
107
|
return ids[list_no].data();
|
@@ -101,13 +127,7 @@ void BlockInvertedLists::update_entries(
|
|
101
127
|
size_t,
|
102
128
|
const idx_t*,
|
103
129
|
const uint8_t*) {
|
104
|
-
FAISS_THROW_MSG("not
|
105
|
-
/*
|
106
|
-
assert (list_no < nlist);
|
107
|
-
assert (n_entry + offset <= ids[list_no].size());
|
108
|
-
memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
|
109
|
-
memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
|
110
|
-
*/
|
130
|
+
FAISS_THROW_MSG("not implemented");
|
111
131
|
}
|
112
132
|
|
113
133
|
BlockInvertedLists::~BlockInvertedLists() {
|
@@ -15,6 +15,7 @@
|
|
15
15
|
namespace faiss {
|
16
16
|
|
17
17
|
struct CodePacker;
|
18
|
+
struct IDSelector;
|
18
19
|
|
19
20
|
/** Inverted Lists that are organized by blocks.
|
20
21
|
*
|
@@ -47,6 +48,8 @@ struct BlockInvertedLists : InvertedLists {
|
|
47
48
|
size_t list_size(size_t list_no) const override;
|
48
49
|
const uint8_t* get_codes(size_t list_no) const override;
|
49
50
|
const idx_t* get_ids(size_t list_no) const override;
|
51
|
+
/// remove ids from the InvertedLists
|
52
|
+
size_t remove_ids(const IDSelector& sel);
|
50
53
|
|
51
54
|
// works only on empty BlockInvertedLists
|
52
55
|
// the codes should be of size ceil(n_entry / n_per_block) * block_size
|
@@ -15,6 +15,7 @@
|
|
15
15
|
#include <faiss/impl/AuxIndexStructures.h>
|
16
16
|
#include <faiss/impl/FaissAssert.h>
|
17
17
|
#include <faiss/impl/IDSelector.h>
|
18
|
+
#include <faiss/invlists/BlockInvertedLists.h>
|
18
19
|
|
19
20
|
namespace faiss {
|
20
21
|
|
@@ -148,8 +149,12 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
|
|
148
149
|
std::vector<idx_t> toremove(nlist);
|
149
150
|
|
150
151
|
size_t nremove = 0;
|
151
|
-
|
152
|
+
BlockInvertedLists* block_invlists =
|
153
|
+
dynamic_cast<BlockInvertedLists*>(invlists);
|
152
154
|
if (type == NoMap) {
|
155
|
+
if (block_invlists != nullptr) {
|
156
|
+
return block_invlists->remove_ids(sel);
|
157
|
+
}
|
153
158
|
// exhaustive scan of IVF
|
154
159
|
#pragma omp parallel for
|
155
160
|
for (idx_t i = 0; i < nlist; i++) {
|
@@ -178,6 +183,9 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
|
|
178
183
|
}
|
179
184
|
}
|
180
185
|
} else if (type == Hashtable) {
|
186
|
+
FAISS_THROW_IF_MSG(
|
187
|
+
block_invlists,
|
188
|
+
"remove with hashtable is not supported with BlockInvertedLists");
|
181
189
|
const IDSelectorArray* sela =
|
182
190
|
dynamic_cast<const IDSelectorArray*>(&sel);
|
183
191
|
FAISS_THROW_IF_NOT_MSG(
|