faiss 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
@@ -5,8 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
#include <faiss/impl/HNSW.h>
|
11
9
|
|
12
10
|
#include <string>
|
@@ -14,6 +12,17 @@
|
|
14
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
15
13
|
#include <faiss/impl/DistanceComputer.h>
|
16
14
|
#include <faiss/impl/IDSelector.h>
|
15
|
+
#include <faiss/impl/ResultHandler.h>
|
16
|
+
#include <faiss/utils/prefetch.h>
|
17
|
+
|
18
|
+
#include <faiss/impl/platform_macros.h>
|
19
|
+
|
20
|
+
#ifdef __AVX2__
|
21
|
+
#include <immintrin.h>
|
22
|
+
|
23
|
+
#include <limits>
|
24
|
+
#include <type_traits>
|
25
|
+
#endif
|
17
26
|
|
18
27
|
namespace faiss {
|
19
28
|
|
@@ -212,7 +221,7 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
212
221
|
return max_level;
|
213
222
|
}
|
214
223
|
|
215
|
-
/** Enumerate vertices from
|
224
|
+
/** Enumerate vertices from nearest to farthest from query, keep a
|
216
225
|
* neighbor only if there is no previous neighbor that is closer to
|
217
226
|
* that vertex than the query.
|
218
227
|
*/
|
@@ -503,17 +512,15 @@ void HNSW::add_with_locks(
|
|
503
512
|
**************************************************************/
|
504
513
|
|
505
514
|
namespace {
|
506
|
-
|
507
515
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
508
516
|
using Node = HNSW::Node;
|
517
|
+
using C = HNSW::C;
|
509
518
|
/** Do a BFS on the candidates list */
|
510
519
|
|
511
520
|
int search_from_candidates(
|
512
521
|
const HNSW& hnsw,
|
513
522
|
DistanceComputer& qdis,
|
514
|
-
|
515
|
-
idx_t* I,
|
516
|
-
float* D,
|
523
|
+
ResultHandler<C>& res,
|
517
524
|
MinimaxHeap& candidates,
|
518
525
|
VisitedTable& vt,
|
519
526
|
HNSWStats& stats,
|
@@ -529,15 +536,16 @@ int search_from_candidates(
|
|
529
536
|
int efSearch = params ? params->efSearch : hnsw.efSearch;
|
530
537
|
const IDSelector* sel = params ? params->sel : nullptr;
|
531
538
|
|
539
|
+
C::T threshold = res.threshold;
|
532
540
|
for (int i = 0; i < candidates.size(); i++) {
|
533
541
|
idx_t v1 = candidates.ids[i];
|
534
542
|
float d = candidates.dis[i];
|
535
543
|
FAISS_ASSERT(v1 >= 0);
|
536
544
|
if (!sel || sel->is_member(v1)) {
|
537
|
-
if (
|
538
|
-
|
539
|
-
|
540
|
-
|
545
|
+
if (d < threshold) {
|
546
|
+
if (res.add_result(d, v1)) {
|
547
|
+
threshold = res.threshold;
|
548
|
+
}
|
541
549
|
}
|
542
550
|
}
|
543
551
|
vt.set(v1);
|
@@ -563,24 +571,86 @@ int search_from_candidates(
|
|
563
571
|
size_t begin, end;
|
564
572
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
565
573
|
|
574
|
+
// // baseline version
|
575
|
+
// for (size_t j = begin; j < end; j++) {
|
576
|
+
// int v1 = hnsw.neighbors[j];
|
577
|
+
// if (v1 < 0)
|
578
|
+
// break;
|
579
|
+
// if (vt.get(v1)) {
|
580
|
+
// continue;
|
581
|
+
// }
|
582
|
+
// vt.set(v1);
|
583
|
+
// ndis++;
|
584
|
+
// float d = qdis(v1);
|
585
|
+
// if (!sel || sel->is_member(v1)) {
|
586
|
+
// if (nres < k) {
|
587
|
+
// faiss::maxheap_push(++nres, D, I, d, v1);
|
588
|
+
// } else if (d < D[0]) {
|
589
|
+
// faiss::maxheap_replace_top(nres, D, I, d, v1);
|
590
|
+
// }
|
591
|
+
// }
|
592
|
+
// candidates.push(v1, d);
|
593
|
+
// }
|
594
|
+
|
595
|
+
// the following version processes 4 neighbors at a time
|
596
|
+
size_t jmax = begin;
|
566
597
|
for (size_t j = begin; j < end; j++) {
|
567
598
|
int v1 = hnsw.neighbors[j];
|
568
599
|
if (v1 < 0)
|
569
600
|
break;
|
570
|
-
|
571
|
-
|
601
|
+
|
602
|
+
prefetch_L2(vt.visited.data() + v1);
|
603
|
+
jmax += 1;
|
604
|
+
}
|
605
|
+
|
606
|
+
int counter = 0;
|
607
|
+
size_t saved_j[4];
|
608
|
+
|
609
|
+
ndis += jmax - begin;
|
610
|
+
threshold = res.threshold;
|
611
|
+
|
612
|
+
auto add_to_heap = [&](const size_t idx, const float dis) {
|
613
|
+
if (!sel || sel->is_member(idx)) {
|
614
|
+
if (dis < threshold) {
|
615
|
+
if (res.add_result(dis, idx)) {
|
616
|
+
threshold = res.threshold;
|
617
|
+
}
|
618
|
+
}
|
572
619
|
}
|
620
|
+
candidates.push(idx, dis);
|
621
|
+
};
|
622
|
+
|
623
|
+
for (size_t j = begin; j < jmax; j++) {
|
624
|
+
int v1 = hnsw.neighbors[j];
|
625
|
+
|
626
|
+
bool vget = vt.get(v1);
|
573
627
|
vt.set(v1);
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
628
|
+
saved_j[counter] = v1;
|
629
|
+
counter += vget ? 0 : 1;
|
630
|
+
|
631
|
+
if (counter == 4) {
|
632
|
+
float dis[4];
|
633
|
+
qdis.distances_batch_4(
|
634
|
+
saved_j[0],
|
635
|
+
saved_j[1],
|
636
|
+
saved_j[2],
|
637
|
+
saved_j[3],
|
638
|
+
dis[0],
|
639
|
+
dis[1],
|
640
|
+
dis[2],
|
641
|
+
dis[3]);
|
642
|
+
|
643
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
644
|
+
add_to_heap(saved_j[id4], dis[id4]);
|
581
645
|
}
|
646
|
+
|
647
|
+
counter = 0;
|
582
648
|
}
|
583
|
-
|
649
|
+
}
|
650
|
+
|
651
|
+
for (size_t icnt = 0; icnt < counter; icnt++) {
|
652
|
+
float dis = qdis(saved_j[icnt]);
|
653
|
+
add_to_heap(saved_j[icnt], dis);
|
584
654
|
}
|
585
655
|
|
586
656
|
nstep++;
|
@@ -630,29 +700,92 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
630
700
|
size_t begin, end;
|
631
701
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
632
702
|
|
633
|
-
|
703
|
+
// // baseline version
|
704
|
+
// for (size_t j = begin; j < end; ++j) {
|
705
|
+
// int v1 = hnsw.neighbors[j];
|
706
|
+
//
|
707
|
+
// if (v1 < 0) {
|
708
|
+
// break;
|
709
|
+
// }
|
710
|
+
// if (vt->get(v1)) {
|
711
|
+
// continue;
|
712
|
+
// }
|
713
|
+
//
|
714
|
+
// vt->set(v1);
|
715
|
+
//
|
716
|
+
// float d1 = qdis(v1);
|
717
|
+
// ++ndis;
|
718
|
+
//
|
719
|
+
// if (top_candidates.top().first > d1 ||
|
720
|
+
// top_candidates.size() < ef) {
|
721
|
+
// candidates.emplace(d1, v1);
|
722
|
+
// top_candidates.emplace(d1, v1);
|
723
|
+
//
|
724
|
+
// if (top_candidates.size() > ef) {
|
725
|
+
// top_candidates.pop();
|
726
|
+
// }
|
727
|
+
// }
|
728
|
+
// }
|
729
|
+
|
730
|
+
// the following version processes 4 neighbors at a time
|
731
|
+
size_t jmax = begin;
|
732
|
+
for (size_t j = begin; j < end; j++) {
|
634
733
|
int v1 = hnsw.neighbors[j];
|
635
|
-
|
636
|
-
if (v1 < 0) {
|
734
|
+
if (v1 < 0)
|
637
735
|
break;
|
638
|
-
}
|
639
|
-
if (vt->get(v1)) {
|
640
|
-
continue;
|
641
|
-
}
|
642
736
|
|
643
|
-
vt->
|
737
|
+
prefetch_L2(vt->visited.data() + v1);
|
738
|
+
jmax += 1;
|
739
|
+
}
|
740
|
+
|
741
|
+
int counter = 0;
|
742
|
+
size_t saved_j[4];
|
644
743
|
|
645
|
-
|
646
|
-
++ndis;
|
744
|
+
ndis += jmax - begin;
|
647
745
|
|
648
|
-
|
649
|
-
|
650
|
-
top_candidates.
|
746
|
+
auto add_to_heap = [&](const size_t idx, const float dis) {
|
747
|
+
if (top_candidates.top().first > dis ||
|
748
|
+
top_candidates.size() < ef) {
|
749
|
+
candidates.emplace(dis, idx);
|
750
|
+
top_candidates.emplace(dis, idx);
|
651
751
|
|
652
752
|
if (top_candidates.size() > ef) {
|
653
753
|
top_candidates.pop();
|
654
754
|
}
|
655
755
|
}
|
756
|
+
};
|
757
|
+
|
758
|
+
for (size_t j = begin; j < jmax; j++) {
|
759
|
+
int v1 = hnsw.neighbors[j];
|
760
|
+
|
761
|
+
bool vget = vt->get(v1);
|
762
|
+
vt->set(v1);
|
763
|
+
saved_j[counter] = v1;
|
764
|
+
counter += vget ? 0 : 1;
|
765
|
+
|
766
|
+
if (counter == 4) {
|
767
|
+
float dis[4];
|
768
|
+
qdis.distances_batch_4(
|
769
|
+
saved_j[0],
|
770
|
+
saved_j[1],
|
771
|
+
saved_j[2],
|
772
|
+
saved_j[3],
|
773
|
+
dis[0],
|
774
|
+
dis[1],
|
775
|
+
dis[2],
|
776
|
+
dis[3]);
|
777
|
+
|
778
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
779
|
+
add_to_heap(saved_j[id4], dis[id4]);
|
780
|
+
}
|
781
|
+
|
782
|
+
counter = 0;
|
783
|
+
}
|
784
|
+
}
|
785
|
+
|
786
|
+
for (size_t icnt = 0; icnt < counter; icnt++) {
|
787
|
+
float dis = qdis(saved_j[icnt]);
|
788
|
+
add_to_heap(saved_j[icnt], dis);
|
656
789
|
}
|
657
790
|
}
|
658
791
|
|
@@ -665,19 +798,28 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
665
798
|
return top_candidates;
|
666
799
|
}
|
667
800
|
|
801
|
+
// just used as a lower bound for the minmaxheap, but it is set for heap search
|
802
|
+
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
803
|
+
using RH = HeapBlockResultHandler<C>;
|
804
|
+
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
|
805
|
+
return hres->k;
|
806
|
+
}
|
807
|
+
return 1;
|
808
|
+
}
|
809
|
+
|
668
810
|
} // anonymous namespace
|
669
811
|
|
670
812
|
HNSWStats HNSW::search(
|
671
813
|
DistanceComputer& qdis,
|
672
|
-
|
673
|
-
idx_t* I,
|
674
|
-
float* D,
|
814
|
+
ResultHandler<C>& res,
|
675
815
|
VisitedTable& vt,
|
676
816
|
const SearchParametersHNSW* params) const {
|
677
817
|
HNSWStats stats;
|
678
818
|
if (entry_point == -1) {
|
679
819
|
return stats;
|
680
820
|
}
|
821
|
+
int k = extract_k_from_ResultHandler(res);
|
822
|
+
|
681
823
|
if (upper_beam == 1) {
|
682
824
|
// greedy search on upper levels
|
683
825
|
storage_idx_t nearest = entry_point;
|
@@ -687,14 +829,14 @@ HNSWStats HNSW::search(
|
|
687
829
|
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
688
830
|
}
|
689
831
|
|
690
|
-
int ef = std::max(efSearch, k);
|
832
|
+
int ef = std::max(params ? params->efSearch : efSearch, k);
|
691
833
|
if (search_bounded_queue) { // this is the most common branch
|
692
834
|
MinimaxHeap candidates(ef);
|
693
835
|
|
694
836
|
candidates.push(nearest, d_nearest);
|
695
837
|
|
696
838
|
search_from_candidates(
|
697
|
-
*this, qdis,
|
839
|
+
*this, qdis, res, candidates, vt, stats, 0, 0, params);
|
698
840
|
} else {
|
699
841
|
std::priority_queue<Node> top_candidates =
|
700
842
|
search_from_candidate_unbounded(
|
@@ -709,12 +851,11 @@ HNSWStats HNSW::search(
|
|
709
851
|
top_candidates.pop();
|
710
852
|
}
|
711
853
|
|
712
|
-
int nres = 0;
|
713
854
|
while (!top_candidates.empty()) {
|
714
855
|
float d;
|
715
856
|
storage_idx_t label;
|
716
857
|
std::tie(d, label) = top_candidates.top();
|
717
|
-
|
858
|
+
res.add_result(d, label);
|
718
859
|
top_candidates.pop();
|
719
860
|
}
|
720
861
|
}
|
@@ -728,6 +869,10 @@ HNSWStats HNSW::search(
|
|
728
869
|
std::vector<idx_t> I_to_next(candidates_size);
|
729
870
|
std::vector<float> D_to_next(candidates_size);
|
730
871
|
|
872
|
+
HeapBlockResultHandler<C> block_resh(
|
873
|
+
1, D_to_next.data(), I_to_next.data(), candidates_size);
|
874
|
+
HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
|
875
|
+
|
731
876
|
int nres = 1;
|
732
877
|
I_to_next[0] = entry_point;
|
733
878
|
D_to_next[0] = qdis(entry_point);
|
@@ -743,18 +888,12 @@ HNSWStats HNSW::search(
|
|
743
888
|
|
744
889
|
if (level == 0) {
|
745
890
|
nres = search_from_candidates(
|
746
|
-
*this, qdis,
|
891
|
+
*this, qdis, res, candidates, vt, stats, 0);
|
747
892
|
} else {
|
893
|
+
resh.begin(0);
|
748
894
|
nres = search_from_candidates(
|
749
|
-
*this,
|
750
|
-
|
751
|
-
candidates_size,
|
752
|
-
I_to_next.data(),
|
753
|
-
D_to_next.data(),
|
754
|
-
candidates,
|
755
|
-
vt,
|
756
|
-
stats,
|
757
|
-
level);
|
895
|
+
*this, qdis, resh, candidates, vt, stats, level);
|
896
|
+
resh.end();
|
758
897
|
}
|
759
898
|
vt.advance();
|
760
899
|
}
|
@@ -765,9 +904,7 @@ HNSWStats HNSW::search(
|
|
765
904
|
|
766
905
|
void HNSW::search_level_0(
|
767
906
|
DistanceComputer& qdis,
|
768
|
-
|
769
|
-
idx_t* idxi,
|
770
|
-
float* simi,
|
907
|
+
ResultHandler<C>& res,
|
771
908
|
idx_t nprobe,
|
772
909
|
const storage_idx_t* nearest_i,
|
773
910
|
const float* nearest_d,
|
@@ -775,7 +912,7 @@ void HNSW::search_level_0(
|
|
775
912
|
HNSWStats& search_stats,
|
776
913
|
VisitedTable& vt) const {
|
777
914
|
const HNSW& hnsw = *this;
|
778
|
-
|
915
|
+
int k = extract_k_from_ResultHandler(res);
|
779
916
|
if (search_type == 1) {
|
780
917
|
int nres = 0;
|
781
918
|
|
@@ -788,22 +925,13 @@ void HNSW::search_level_0(
|
|
788
925
|
if (vt.get(cj))
|
789
926
|
continue;
|
790
927
|
|
791
|
-
int candidates_size = std::max(hnsw.efSearch,
|
928
|
+
int candidates_size = std::max(hnsw.efSearch, k);
|
792
929
|
MinimaxHeap candidates(candidates_size);
|
793
930
|
|
794
931
|
candidates.push(cj, nearest_d[j]);
|
795
932
|
|
796
933
|
nres = search_from_candidates(
|
797
|
-
hnsw,
|
798
|
-
qdis,
|
799
|
-
k,
|
800
|
-
idxi,
|
801
|
-
simi,
|
802
|
-
candidates,
|
803
|
-
vt,
|
804
|
-
search_stats,
|
805
|
-
0,
|
806
|
-
nres);
|
934
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
|
807
935
|
}
|
808
936
|
} else if (search_type == 2) {
|
809
937
|
int candidates_size = std::max(hnsw.efSearch, int(k));
|
@@ -819,10 +947,43 @@ void HNSW::search_level_0(
|
|
819
947
|
}
|
820
948
|
|
821
949
|
search_from_candidates(
|
822
|
-
hnsw, qdis,
|
950
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0);
|
823
951
|
}
|
824
952
|
}
|
825
953
|
|
954
|
+
void HNSW::permute_entries(const idx_t* map) {
|
955
|
+
// remap levels
|
956
|
+
storage_idx_t ntotal = levels.size();
|
957
|
+
std::vector<storage_idx_t> imap(ntotal); // inverse mapping
|
958
|
+
// map: new index -> old index
|
959
|
+
// imap: old index -> new index
|
960
|
+
for (int i = 0; i < ntotal; i++) {
|
961
|
+
assert(map[i] >= 0 && map[i] < ntotal);
|
962
|
+
imap[map[i]] = i;
|
963
|
+
}
|
964
|
+
if (entry_point != -1) {
|
965
|
+
entry_point = imap[entry_point];
|
966
|
+
}
|
967
|
+
std::vector<int> new_levels(ntotal);
|
968
|
+
std::vector<size_t> new_offsets(ntotal + 1);
|
969
|
+
std::vector<storage_idx_t> new_neighbors(neighbors.size());
|
970
|
+
size_t no = 0;
|
971
|
+
for (int i = 0; i < ntotal; i++) {
|
972
|
+
storage_idx_t o = map[i]; // corresponding "old" index
|
973
|
+
new_levels[i] = levels[o];
|
974
|
+
for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
|
975
|
+
storage_idx_t neigh = neighbors[j];
|
976
|
+
new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
|
977
|
+
}
|
978
|
+
new_offsets[i + 1] = no;
|
979
|
+
}
|
980
|
+
assert(new_offsets[ntotal] == offsets[ntotal]);
|
981
|
+
// swap everyone
|
982
|
+
std::swap(levels, new_levels);
|
983
|
+
std::swap(offsets, new_offsets);
|
984
|
+
std::swap(neighbors, new_neighbors);
|
985
|
+
}
|
986
|
+
|
826
987
|
/**************************************************************
|
827
988
|
* MinimaxHeap
|
828
989
|
**************************************************************/
|
@@ -852,17 +1013,105 @@ void HNSW::MinimaxHeap::clear() {
|
|
852
1013
|
nvalid = k = 0;
|
853
1014
|
}
|
854
1015
|
|
1016
|
+
#ifdef __AVX2__
|
1017
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1018
|
+
assert(k > 0);
|
1019
|
+
static_assert(
|
1020
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
1021
|
+
"This code expects storage_idx_t to be int32_t");
|
1022
|
+
|
1023
|
+
int32_t min_idx = -1;
|
1024
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
1025
|
+
|
1026
|
+
size_t iii = 0;
|
1027
|
+
|
1028
|
+
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
|
1029
|
+
__m256 min_distances =
|
1030
|
+
_mm256_set1_ps(std::numeric_limits<float>::infinity());
|
1031
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
1032
|
+
__m256i offset = _mm256_set1_epi32(8);
|
1033
|
+
|
1034
|
+
// The baseline version is available in non-AVX2 branch.
|
1035
|
+
|
1036
|
+
// The following loop tracks the rightmost index with the min distance.
|
1037
|
+
// -1 index values are ignored.
|
1038
|
+
const int k8 = (k / 8) * 8;
|
1039
|
+
for (; iii < k8; iii += 8) {
|
1040
|
+
__m256i indices =
|
1041
|
+
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
|
1042
|
+
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
|
1043
|
+
|
1044
|
+
// This mask filters out -1 values among indices.
|
1045
|
+
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
|
1046
|
+
|
1047
|
+
__m256i dmask = _mm256_castps_si256(
|
1048
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
|
1049
|
+
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
|
1050
|
+
|
1051
|
+
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
|
1052
|
+
_mm256_castsi256_ps(current_indices),
|
1053
|
+
_mm256_castsi256_ps(min_indices),
|
1054
|
+
finalmask));
|
1055
|
+
|
1056
|
+
const __m256 min_distances_new =
|
1057
|
+
_mm256_blendv_ps(distances, min_distances, finalmask);
|
1058
|
+
|
1059
|
+
min_indices = min_indices_new;
|
1060
|
+
min_distances = min_distances_new;
|
1061
|
+
|
1062
|
+
current_indices = _mm256_add_epi32(current_indices, offset);
|
1063
|
+
}
|
1064
|
+
|
1065
|
+
// Vectorizing is doable, but is not practical
|
1066
|
+
int32_t vidx8[8];
|
1067
|
+
float vdis8[8];
|
1068
|
+
_mm256_storeu_ps(vdis8, min_distances);
|
1069
|
+
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
|
1070
|
+
|
1071
|
+
for (size_t j = 0; j < 8; j++) {
|
1072
|
+
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
|
1073
|
+
min_idx = vidx8[j];
|
1074
|
+
min_dis = vdis8[j];
|
1075
|
+
}
|
1076
|
+
}
|
1077
|
+
|
1078
|
+
// process last values. Vectorizing is doable, but is not practical
|
1079
|
+
for (; iii < k; iii++) {
|
1080
|
+
if (ids[iii] != -1 && dis[iii] <= min_dis) {
|
1081
|
+
min_dis = dis[iii];
|
1082
|
+
min_idx = iii;
|
1083
|
+
}
|
1084
|
+
}
|
1085
|
+
|
1086
|
+
if (min_idx == -1) {
|
1087
|
+
return -1;
|
1088
|
+
}
|
1089
|
+
|
1090
|
+
if (vmin_out) {
|
1091
|
+
*vmin_out = min_dis;
|
1092
|
+
}
|
1093
|
+
int ret = ids[min_idx];
|
1094
|
+
ids[min_idx] = -1;
|
1095
|
+
--nvalid;
|
1096
|
+
return ret;
|
1097
|
+
}
|
1098
|
+
|
1099
|
+
#else
|
1100
|
+
|
1101
|
+
// baseline non-vectorized version
|
855
1102
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
856
1103
|
assert(k > 0);
|
857
1104
|
// returns min. This is an O(n) operation
|
858
1105
|
int i = k - 1;
|
859
1106
|
while (i >= 0) {
|
860
|
-
if (ids[i] != -1)
|
1107
|
+
if (ids[i] != -1) {
|
861
1108
|
break;
|
1109
|
+
}
|
862
1110
|
i--;
|
863
1111
|
}
|
864
|
-
if (i == -1)
|
1112
|
+
if (i == -1) {
|
865
1113
|
return -1;
|
1114
|
+
}
|
866
1115
|
int imin = i;
|
867
1116
|
float vmin = dis[i];
|
868
1117
|
i--;
|
@@ -873,14 +1122,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
873
1122
|
}
|
874
1123
|
i--;
|
875
1124
|
}
|
876
|
-
if (vmin_out)
|
1125
|
+
if (vmin_out) {
|
877
1126
|
*vmin_out = vmin;
|
1127
|
+
}
|
878
1128
|
int ret = ids[imin];
|
879
1129
|
ids[imin] = -1;
|
880
1130
|
--nvalid;
|
881
1131
|
|
882
1132
|
return ret;
|
883
1133
|
}
|
1134
|
+
#endif
|
884
1135
|
|
885
1136
|
int HNSW::MinimaxHeap::count_below(float thresh) {
|
886
1137
|
int n_below = 0;
|
@@ -5,8 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
#pragma once
|
11
9
|
|
12
10
|
#include <queue>
|
@@ -42,6 +40,8 @@ namespace faiss {
|
|
42
40
|
struct VisitedTable;
|
43
41
|
struct DistanceComputer; // from AuxIndexStructures
|
44
42
|
struct HNSWStats;
|
43
|
+
template <class C>
|
44
|
+
struct ResultHandler;
|
45
45
|
|
46
46
|
struct SearchParametersHNSW : SearchParameters {
|
47
47
|
int efSearch = 16;
|
@@ -54,6 +54,9 @@ struct HNSW {
|
|
54
54
|
/// internal storage of vectors (32 bits: this is expensive)
|
55
55
|
using storage_idx_t = int32_t;
|
56
56
|
|
57
|
+
// for now we do only these distances
|
58
|
+
using C = CMax<float, int64_t>;
|
59
|
+
|
57
60
|
typedef std::pair<float, storage_idx_t> Node;
|
58
61
|
|
59
62
|
/** Heap structure that allows fast
|
@@ -195,18 +198,14 @@ struct HNSW {
|
|
195
198
|
/// search interface for 1 point, single thread
|
196
199
|
HNSWStats search(
|
197
200
|
DistanceComputer& qdis,
|
198
|
-
|
199
|
-
idx_t* I,
|
200
|
-
float* D,
|
201
|
+
ResultHandler<C>& res,
|
201
202
|
VisitedTable& vt,
|
202
203
|
const SearchParametersHNSW* params = nullptr) const;
|
203
204
|
|
204
205
|
/// search only in level 0 from a given vertex
|
205
206
|
void search_level_0(
|
206
207
|
DistanceComputer& qdis,
|
207
|
-
|
208
|
-
idx_t* idxi,
|
209
|
-
float* simi,
|
208
|
+
ResultHandler<C>& res,
|
210
209
|
idx_t nprobe,
|
211
210
|
const storage_idx_t* nearest_i,
|
212
211
|
const float* nearest_d,
|
@@ -226,6 +225,8 @@ struct HNSW {
|
|
226
225
|
std::priority_queue<NodeDistFarther>& input,
|
227
226
|
std::vector<NodeDistFarther>& output,
|
228
227
|
int max_size);
|
228
|
+
|
229
|
+
void permute_entries(const idx_t* map);
|
229
230
|
};
|
230
231
|
|
231
232
|
struct HNSWStats {
|
@@ -10,7 +10,7 @@
|
|
10
10
|
#include <unordered_set>
|
11
11
|
#include <vector>
|
12
12
|
|
13
|
-
#include <faiss/
|
13
|
+
#include <faiss/MetricType.h>
|
14
14
|
|
15
15
|
/** IDSelector is intended to define a subset of vectors to handle (for removal
|
16
16
|
* or as subset to search) */
|
@@ -140,7 +140,7 @@ struct IDSelectorAnd : IDSelector {
|
|
140
140
|
: lhs(lhs), rhs(rhs) {}
|
141
141
|
bool is_member(idx_t id) const final {
|
142
142
|
return lhs->is_member(id) && rhs->is_member(id);
|
143
|
-
}
|
143
|
+
}
|
144
144
|
virtual ~IDSelectorAnd() {}
|
145
145
|
};
|
146
146
|
|
@@ -153,7 +153,7 @@ struct IDSelectorOr : IDSelector {
|
|
153
153
|
: lhs(lhs), rhs(rhs) {}
|
154
154
|
bool is_member(idx_t id) const final {
|
155
155
|
return lhs->is_member(id) || rhs->is_member(id);
|
156
|
-
}
|
156
|
+
}
|
157
157
|
virtual ~IDSelectorOr() {}
|
158
158
|
};
|
159
159
|
|
@@ -166,7 +166,7 @@ struct IDSelectorXOr : IDSelector {
|
|
166
166
|
: lhs(lhs), rhs(rhs) {}
|
167
167
|
bool is_member(idx_t id) const final {
|
168
168
|
return lhs->is_member(id) ^ rhs->is_member(id);
|
169
|
-
}
|
169
|
+
}
|
170
170
|
virtual ~IDSelectorXOr() {}
|
171
171
|
};
|
172
172
|
|
@@ -628,7 +628,9 @@ void LocalSearchQuantizer::icm_encode_step(
|
|
628
628
|
{
|
629
629
|
size_t binary_idx = (other_m + 1) * M * K * K +
|
630
630
|
m * K * K + code2 * K + code;
|
631
|
-
_mm_prefetch(
|
631
|
+
_mm_prefetch(
|
632
|
+
(const char*)(binaries + binary_idx),
|
633
|
+
_MM_HINT_T0);
|
632
634
|
}
|
633
635
|
}
|
634
636
|
#endif
|