faiss 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
|
@@ -5,8 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/impl/HNSW.h>
|
|
11
9
|
|
|
12
10
|
#include <string>
|
|
@@ -14,6 +12,17 @@
|
|
|
14
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
15
13
|
#include <faiss/impl/DistanceComputer.h>
|
|
16
14
|
#include <faiss/impl/IDSelector.h>
|
|
15
|
+
#include <faiss/impl/ResultHandler.h>
|
|
16
|
+
#include <faiss/utils/prefetch.h>
|
|
17
|
+
|
|
18
|
+
#include <faiss/impl/platform_macros.h>
|
|
19
|
+
|
|
20
|
+
#ifdef __AVX2__
|
|
21
|
+
#include <immintrin.h>
|
|
22
|
+
|
|
23
|
+
#include <limits>
|
|
24
|
+
#include <type_traits>
|
|
25
|
+
#endif
|
|
17
26
|
|
|
18
27
|
namespace faiss {
|
|
19
28
|
|
|
@@ -212,7 +221,7 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
|
212
221
|
return max_level;
|
|
213
222
|
}
|
|
214
223
|
|
|
215
|
-
/** Enumerate vertices from
|
|
224
|
+
/** Enumerate vertices from nearest to farthest from query, keep a
|
|
216
225
|
* neighbor only if there is no previous neighbor that is closer to
|
|
217
226
|
* that vertex than the query.
|
|
218
227
|
*/
|
|
@@ -503,17 +512,15 @@ void HNSW::add_with_locks(
|
|
|
503
512
|
**************************************************************/
|
|
504
513
|
|
|
505
514
|
namespace {
|
|
506
|
-
|
|
507
515
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
|
508
516
|
using Node = HNSW::Node;
|
|
517
|
+
using C = HNSW::C;
|
|
509
518
|
/** Do a BFS on the candidates list */
|
|
510
519
|
|
|
511
520
|
int search_from_candidates(
|
|
512
521
|
const HNSW& hnsw,
|
|
513
522
|
DistanceComputer& qdis,
|
|
514
|
-
|
|
515
|
-
idx_t* I,
|
|
516
|
-
float* D,
|
|
523
|
+
ResultHandler<C>& res,
|
|
517
524
|
MinimaxHeap& candidates,
|
|
518
525
|
VisitedTable& vt,
|
|
519
526
|
HNSWStats& stats,
|
|
@@ -529,15 +536,16 @@ int search_from_candidates(
|
|
|
529
536
|
int efSearch = params ? params->efSearch : hnsw.efSearch;
|
|
530
537
|
const IDSelector* sel = params ? params->sel : nullptr;
|
|
531
538
|
|
|
539
|
+
C::T threshold = res.threshold;
|
|
532
540
|
for (int i = 0; i < candidates.size(); i++) {
|
|
533
541
|
idx_t v1 = candidates.ids[i];
|
|
534
542
|
float d = candidates.dis[i];
|
|
535
543
|
FAISS_ASSERT(v1 >= 0);
|
|
536
544
|
if (!sel || sel->is_member(v1)) {
|
|
537
|
-
if (
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
545
|
+
if (d < threshold) {
|
|
546
|
+
if (res.add_result(d, v1)) {
|
|
547
|
+
threshold = res.threshold;
|
|
548
|
+
}
|
|
541
549
|
}
|
|
542
550
|
}
|
|
543
551
|
vt.set(v1);
|
|
@@ -563,24 +571,86 @@ int search_from_candidates(
|
|
|
563
571
|
size_t begin, end;
|
|
564
572
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
|
565
573
|
|
|
574
|
+
// // baseline version
|
|
575
|
+
// for (size_t j = begin; j < end; j++) {
|
|
576
|
+
// int v1 = hnsw.neighbors[j];
|
|
577
|
+
// if (v1 < 0)
|
|
578
|
+
// break;
|
|
579
|
+
// if (vt.get(v1)) {
|
|
580
|
+
// continue;
|
|
581
|
+
// }
|
|
582
|
+
// vt.set(v1);
|
|
583
|
+
// ndis++;
|
|
584
|
+
// float d = qdis(v1);
|
|
585
|
+
// if (!sel || sel->is_member(v1)) {
|
|
586
|
+
// if (nres < k) {
|
|
587
|
+
// faiss::maxheap_push(++nres, D, I, d, v1);
|
|
588
|
+
// } else if (d < D[0]) {
|
|
589
|
+
// faiss::maxheap_replace_top(nres, D, I, d, v1);
|
|
590
|
+
// }
|
|
591
|
+
// }
|
|
592
|
+
// candidates.push(v1, d);
|
|
593
|
+
// }
|
|
594
|
+
|
|
595
|
+
// the following version processes 4 neighbors at a time
|
|
596
|
+
size_t jmax = begin;
|
|
566
597
|
for (size_t j = begin; j < end; j++) {
|
|
567
598
|
int v1 = hnsw.neighbors[j];
|
|
568
599
|
if (v1 < 0)
|
|
569
600
|
break;
|
|
570
|
-
|
|
571
|
-
|
|
601
|
+
|
|
602
|
+
prefetch_L2(vt.visited.data() + v1);
|
|
603
|
+
jmax += 1;
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
int counter = 0;
|
|
607
|
+
size_t saved_j[4];
|
|
608
|
+
|
|
609
|
+
ndis += jmax - begin;
|
|
610
|
+
threshold = res.threshold;
|
|
611
|
+
|
|
612
|
+
auto add_to_heap = [&](const size_t idx, const float dis) {
|
|
613
|
+
if (!sel || sel->is_member(idx)) {
|
|
614
|
+
if (dis < threshold) {
|
|
615
|
+
if (res.add_result(dis, idx)) {
|
|
616
|
+
threshold = res.threshold;
|
|
617
|
+
}
|
|
618
|
+
}
|
|
572
619
|
}
|
|
620
|
+
candidates.push(idx, dis);
|
|
621
|
+
};
|
|
622
|
+
|
|
623
|
+
for (size_t j = begin; j < jmax; j++) {
|
|
624
|
+
int v1 = hnsw.neighbors[j];
|
|
625
|
+
|
|
626
|
+
bool vget = vt.get(v1);
|
|
573
627
|
vt.set(v1);
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
628
|
+
saved_j[counter] = v1;
|
|
629
|
+
counter += vget ? 0 : 1;
|
|
630
|
+
|
|
631
|
+
if (counter == 4) {
|
|
632
|
+
float dis[4];
|
|
633
|
+
qdis.distances_batch_4(
|
|
634
|
+
saved_j[0],
|
|
635
|
+
saved_j[1],
|
|
636
|
+
saved_j[2],
|
|
637
|
+
saved_j[3],
|
|
638
|
+
dis[0],
|
|
639
|
+
dis[1],
|
|
640
|
+
dis[2],
|
|
641
|
+
dis[3]);
|
|
642
|
+
|
|
643
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
|
644
|
+
add_to_heap(saved_j[id4], dis[id4]);
|
|
581
645
|
}
|
|
646
|
+
|
|
647
|
+
counter = 0;
|
|
582
648
|
}
|
|
583
|
-
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
for (size_t icnt = 0; icnt < counter; icnt++) {
|
|
652
|
+
float dis = qdis(saved_j[icnt]);
|
|
653
|
+
add_to_heap(saved_j[icnt], dis);
|
|
584
654
|
}
|
|
585
655
|
|
|
586
656
|
nstep++;
|
|
@@ -630,29 +700,92 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
|
630
700
|
size_t begin, end;
|
|
631
701
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
|
632
702
|
|
|
633
|
-
|
|
703
|
+
// // baseline version
|
|
704
|
+
// for (size_t j = begin; j < end; ++j) {
|
|
705
|
+
// int v1 = hnsw.neighbors[j];
|
|
706
|
+
//
|
|
707
|
+
// if (v1 < 0) {
|
|
708
|
+
// break;
|
|
709
|
+
// }
|
|
710
|
+
// if (vt->get(v1)) {
|
|
711
|
+
// continue;
|
|
712
|
+
// }
|
|
713
|
+
//
|
|
714
|
+
// vt->set(v1);
|
|
715
|
+
//
|
|
716
|
+
// float d1 = qdis(v1);
|
|
717
|
+
// ++ndis;
|
|
718
|
+
//
|
|
719
|
+
// if (top_candidates.top().first > d1 ||
|
|
720
|
+
// top_candidates.size() < ef) {
|
|
721
|
+
// candidates.emplace(d1, v1);
|
|
722
|
+
// top_candidates.emplace(d1, v1);
|
|
723
|
+
//
|
|
724
|
+
// if (top_candidates.size() > ef) {
|
|
725
|
+
// top_candidates.pop();
|
|
726
|
+
// }
|
|
727
|
+
// }
|
|
728
|
+
// }
|
|
729
|
+
|
|
730
|
+
// the following version processes 4 neighbors at a time
|
|
731
|
+
size_t jmax = begin;
|
|
732
|
+
for (size_t j = begin; j < end; j++) {
|
|
634
733
|
int v1 = hnsw.neighbors[j];
|
|
635
|
-
|
|
636
|
-
if (v1 < 0) {
|
|
734
|
+
if (v1 < 0)
|
|
637
735
|
break;
|
|
638
|
-
}
|
|
639
|
-
if (vt->get(v1)) {
|
|
640
|
-
continue;
|
|
641
|
-
}
|
|
642
736
|
|
|
643
|
-
vt->
|
|
737
|
+
prefetch_L2(vt->visited.data() + v1);
|
|
738
|
+
jmax += 1;
|
|
739
|
+
}
|
|
740
|
+
|
|
741
|
+
int counter = 0;
|
|
742
|
+
size_t saved_j[4];
|
|
644
743
|
|
|
645
|
-
|
|
646
|
-
++ndis;
|
|
744
|
+
ndis += jmax - begin;
|
|
647
745
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
top_candidates.
|
|
746
|
+
auto add_to_heap = [&](const size_t idx, const float dis) {
|
|
747
|
+
if (top_candidates.top().first > dis ||
|
|
748
|
+
top_candidates.size() < ef) {
|
|
749
|
+
candidates.emplace(dis, idx);
|
|
750
|
+
top_candidates.emplace(dis, idx);
|
|
651
751
|
|
|
652
752
|
if (top_candidates.size() > ef) {
|
|
653
753
|
top_candidates.pop();
|
|
654
754
|
}
|
|
655
755
|
}
|
|
756
|
+
};
|
|
757
|
+
|
|
758
|
+
for (size_t j = begin; j < jmax; j++) {
|
|
759
|
+
int v1 = hnsw.neighbors[j];
|
|
760
|
+
|
|
761
|
+
bool vget = vt->get(v1);
|
|
762
|
+
vt->set(v1);
|
|
763
|
+
saved_j[counter] = v1;
|
|
764
|
+
counter += vget ? 0 : 1;
|
|
765
|
+
|
|
766
|
+
if (counter == 4) {
|
|
767
|
+
float dis[4];
|
|
768
|
+
qdis.distances_batch_4(
|
|
769
|
+
saved_j[0],
|
|
770
|
+
saved_j[1],
|
|
771
|
+
saved_j[2],
|
|
772
|
+
saved_j[3],
|
|
773
|
+
dis[0],
|
|
774
|
+
dis[1],
|
|
775
|
+
dis[2],
|
|
776
|
+
dis[3]);
|
|
777
|
+
|
|
778
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
|
779
|
+
add_to_heap(saved_j[id4], dis[id4]);
|
|
780
|
+
}
|
|
781
|
+
|
|
782
|
+
counter = 0;
|
|
783
|
+
}
|
|
784
|
+
}
|
|
785
|
+
|
|
786
|
+
for (size_t icnt = 0; icnt < counter; icnt++) {
|
|
787
|
+
float dis = qdis(saved_j[icnt]);
|
|
788
|
+
add_to_heap(saved_j[icnt], dis);
|
|
656
789
|
}
|
|
657
790
|
}
|
|
658
791
|
|
|
@@ -665,19 +798,28 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
|
665
798
|
return top_candidates;
|
|
666
799
|
}
|
|
667
800
|
|
|
801
|
+
// just used as a lower bound for the minmaxheap, but it is set for heap search
|
|
802
|
+
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
|
803
|
+
using RH = HeapBlockResultHandler<C>;
|
|
804
|
+
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
|
|
805
|
+
return hres->k;
|
|
806
|
+
}
|
|
807
|
+
return 1;
|
|
808
|
+
}
|
|
809
|
+
|
|
668
810
|
} // anonymous namespace
|
|
669
811
|
|
|
670
812
|
HNSWStats HNSW::search(
|
|
671
813
|
DistanceComputer& qdis,
|
|
672
|
-
|
|
673
|
-
idx_t* I,
|
|
674
|
-
float* D,
|
|
814
|
+
ResultHandler<C>& res,
|
|
675
815
|
VisitedTable& vt,
|
|
676
816
|
const SearchParametersHNSW* params) const {
|
|
677
817
|
HNSWStats stats;
|
|
678
818
|
if (entry_point == -1) {
|
|
679
819
|
return stats;
|
|
680
820
|
}
|
|
821
|
+
int k = extract_k_from_ResultHandler(res);
|
|
822
|
+
|
|
681
823
|
if (upper_beam == 1) {
|
|
682
824
|
// greedy search on upper levels
|
|
683
825
|
storage_idx_t nearest = entry_point;
|
|
@@ -687,14 +829,14 @@ HNSWStats HNSW::search(
|
|
|
687
829
|
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
|
688
830
|
}
|
|
689
831
|
|
|
690
|
-
int ef = std::max(efSearch, k);
|
|
832
|
+
int ef = std::max(params ? params->efSearch : efSearch, k);
|
|
691
833
|
if (search_bounded_queue) { // this is the most common branch
|
|
692
834
|
MinimaxHeap candidates(ef);
|
|
693
835
|
|
|
694
836
|
candidates.push(nearest, d_nearest);
|
|
695
837
|
|
|
696
838
|
search_from_candidates(
|
|
697
|
-
*this, qdis,
|
|
839
|
+
*this, qdis, res, candidates, vt, stats, 0, 0, params);
|
|
698
840
|
} else {
|
|
699
841
|
std::priority_queue<Node> top_candidates =
|
|
700
842
|
search_from_candidate_unbounded(
|
|
@@ -709,12 +851,11 @@ HNSWStats HNSW::search(
|
|
|
709
851
|
top_candidates.pop();
|
|
710
852
|
}
|
|
711
853
|
|
|
712
|
-
int nres = 0;
|
|
713
854
|
while (!top_candidates.empty()) {
|
|
714
855
|
float d;
|
|
715
856
|
storage_idx_t label;
|
|
716
857
|
std::tie(d, label) = top_candidates.top();
|
|
717
|
-
|
|
858
|
+
res.add_result(d, label);
|
|
718
859
|
top_candidates.pop();
|
|
719
860
|
}
|
|
720
861
|
}
|
|
@@ -728,6 +869,10 @@ HNSWStats HNSW::search(
|
|
|
728
869
|
std::vector<idx_t> I_to_next(candidates_size);
|
|
729
870
|
std::vector<float> D_to_next(candidates_size);
|
|
730
871
|
|
|
872
|
+
HeapBlockResultHandler<C> block_resh(
|
|
873
|
+
1, D_to_next.data(), I_to_next.data(), candidates_size);
|
|
874
|
+
HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
|
|
875
|
+
|
|
731
876
|
int nres = 1;
|
|
732
877
|
I_to_next[0] = entry_point;
|
|
733
878
|
D_to_next[0] = qdis(entry_point);
|
|
@@ -743,18 +888,12 @@ HNSWStats HNSW::search(
|
|
|
743
888
|
|
|
744
889
|
if (level == 0) {
|
|
745
890
|
nres = search_from_candidates(
|
|
746
|
-
*this, qdis,
|
|
891
|
+
*this, qdis, res, candidates, vt, stats, 0);
|
|
747
892
|
} else {
|
|
893
|
+
resh.begin(0);
|
|
748
894
|
nres = search_from_candidates(
|
|
749
|
-
*this,
|
|
750
|
-
|
|
751
|
-
candidates_size,
|
|
752
|
-
I_to_next.data(),
|
|
753
|
-
D_to_next.data(),
|
|
754
|
-
candidates,
|
|
755
|
-
vt,
|
|
756
|
-
stats,
|
|
757
|
-
level);
|
|
895
|
+
*this, qdis, resh, candidates, vt, stats, level);
|
|
896
|
+
resh.end();
|
|
758
897
|
}
|
|
759
898
|
vt.advance();
|
|
760
899
|
}
|
|
@@ -765,9 +904,7 @@ HNSWStats HNSW::search(
|
|
|
765
904
|
|
|
766
905
|
void HNSW::search_level_0(
|
|
767
906
|
DistanceComputer& qdis,
|
|
768
|
-
|
|
769
|
-
idx_t* idxi,
|
|
770
|
-
float* simi,
|
|
907
|
+
ResultHandler<C>& res,
|
|
771
908
|
idx_t nprobe,
|
|
772
909
|
const storage_idx_t* nearest_i,
|
|
773
910
|
const float* nearest_d,
|
|
@@ -775,7 +912,7 @@ void HNSW::search_level_0(
|
|
|
775
912
|
HNSWStats& search_stats,
|
|
776
913
|
VisitedTable& vt) const {
|
|
777
914
|
const HNSW& hnsw = *this;
|
|
778
|
-
|
|
915
|
+
int k = extract_k_from_ResultHandler(res);
|
|
779
916
|
if (search_type == 1) {
|
|
780
917
|
int nres = 0;
|
|
781
918
|
|
|
@@ -788,22 +925,13 @@ void HNSW::search_level_0(
|
|
|
788
925
|
if (vt.get(cj))
|
|
789
926
|
continue;
|
|
790
927
|
|
|
791
|
-
int candidates_size = std::max(hnsw.efSearch,
|
|
928
|
+
int candidates_size = std::max(hnsw.efSearch, k);
|
|
792
929
|
MinimaxHeap candidates(candidates_size);
|
|
793
930
|
|
|
794
931
|
candidates.push(cj, nearest_d[j]);
|
|
795
932
|
|
|
796
933
|
nres = search_from_candidates(
|
|
797
|
-
hnsw,
|
|
798
|
-
qdis,
|
|
799
|
-
k,
|
|
800
|
-
idxi,
|
|
801
|
-
simi,
|
|
802
|
-
candidates,
|
|
803
|
-
vt,
|
|
804
|
-
search_stats,
|
|
805
|
-
0,
|
|
806
|
-
nres);
|
|
934
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
|
|
807
935
|
}
|
|
808
936
|
} else if (search_type == 2) {
|
|
809
937
|
int candidates_size = std::max(hnsw.efSearch, int(k));
|
|
@@ -819,10 +947,43 @@ void HNSW::search_level_0(
|
|
|
819
947
|
}
|
|
820
948
|
|
|
821
949
|
search_from_candidates(
|
|
822
|
-
hnsw, qdis,
|
|
950
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0);
|
|
823
951
|
}
|
|
824
952
|
}
|
|
825
953
|
|
|
954
|
+
void HNSW::permute_entries(const idx_t* map) {
|
|
955
|
+
// remap levels
|
|
956
|
+
storage_idx_t ntotal = levels.size();
|
|
957
|
+
std::vector<storage_idx_t> imap(ntotal); // inverse mapping
|
|
958
|
+
// map: new index -> old index
|
|
959
|
+
// imap: old index -> new index
|
|
960
|
+
for (int i = 0; i < ntotal; i++) {
|
|
961
|
+
assert(map[i] >= 0 && map[i] < ntotal);
|
|
962
|
+
imap[map[i]] = i;
|
|
963
|
+
}
|
|
964
|
+
if (entry_point != -1) {
|
|
965
|
+
entry_point = imap[entry_point];
|
|
966
|
+
}
|
|
967
|
+
std::vector<int> new_levels(ntotal);
|
|
968
|
+
std::vector<size_t> new_offsets(ntotal + 1);
|
|
969
|
+
std::vector<storage_idx_t> new_neighbors(neighbors.size());
|
|
970
|
+
size_t no = 0;
|
|
971
|
+
for (int i = 0; i < ntotal; i++) {
|
|
972
|
+
storage_idx_t o = map[i]; // corresponding "old" index
|
|
973
|
+
new_levels[i] = levels[o];
|
|
974
|
+
for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
|
|
975
|
+
storage_idx_t neigh = neighbors[j];
|
|
976
|
+
new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
|
|
977
|
+
}
|
|
978
|
+
new_offsets[i + 1] = no;
|
|
979
|
+
}
|
|
980
|
+
assert(new_offsets[ntotal] == offsets[ntotal]);
|
|
981
|
+
// swap everyone
|
|
982
|
+
std::swap(levels, new_levels);
|
|
983
|
+
std::swap(offsets, new_offsets);
|
|
984
|
+
std::swap(neighbors, new_neighbors);
|
|
985
|
+
}
|
|
986
|
+
|
|
826
987
|
/**************************************************************
|
|
827
988
|
* MinimaxHeap
|
|
828
989
|
**************************************************************/
|
|
@@ -852,17 +1013,105 @@ void HNSW::MinimaxHeap::clear() {
|
|
|
852
1013
|
nvalid = k = 0;
|
|
853
1014
|
}
|
|
854
1015
|
|
|
1016
|
+
#ifdef __AVX2__
|
|
1017
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
1018
|
+
assert(k > 0);
|
|
1019
|
+
static_assert(
|
|
1020
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
|
1021
|
+
"This code expects storage_idx_t to be int32_t");
|
|
1022
|
+
|
|
1023
|
+
int32_t min_idx = -1;
|
|
1024
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
|
1025
|
+
|
|
1026
|
+
size_t iii = 0;
|
|
1027
|
+
|
|
1028
|
+
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
|
|
1029
|
+
__m256 min_distances =
|
|
1030
|
+
_mm256_set1_ps(std::numeric_limits<float>::infinity());
|
|
1031
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
1032
|
+
__m256i offset = _mm256_set1_epi32(8);
|
|
1033
|
+
|
|
1034
|
+
// The baseline version is available in non-AVX2 branch.
|
|
1035
|
+
|
|
1036
|
+
// The following loop tracks the rightmost index with the min distance.
|
|
1037
|
+
// -1 index values are ignored.
|
|
1038
|
+
const int k8 = (k / 8) * 8;
|
|
1039
|
+
for (; iii < k8; iii += 8) {
|
|
1040
|
+
__m256i indices =
|
|
1041
|
+
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
|
|
1042
|
+
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
|
|
1043
|
+
|
|
1044
|
+
// This mask filters out -1 values among indices.
|
|
1045
|
+
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
|
|
1046
|
+
|
|
1047
|
+
__m256i dmask = _mm256_castps_si256(
|
|
1048
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
|
|
1049
|
+
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
|
|
1050
|
+
|
|
1051
|
+
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
|
|
1052
|
+
_mm256_castsi256_ps(current_indices),
|
|
1053
|
+
_mm256_castsi256_ps(min_indices),
|
|
1054
|
+
finalmask));
|
|
1055
|
+
|
|
1056
|
+
const __m256 min_distances_new =
|
|
1057
|
+
_mm256_blendv_ps(distances, min_distances, finalmask);
|
|
1058
|
+
|
|
1059
|
+
min_indices = min_indices_new;
|
|
1060
|
+
min_distances = min_distances_new;
|
|
1061
|
+
|
|
1062
|
+
current_indices = _mm256_add_epi32(current_indices, offset);
|
|
1063
|
+
}
|
|
1064
|
+
|
|
1065
|
+
// Vectorizing is doable, but is not practical
|
|
1066
|
+
int32_t vidx8[8];
|
|
1067
|
+
float vdis8[8];
|
|
1068
|
+
_mm256_storeu_ps(vdis8, min_distances);
|
|
1069
|
+
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
|
|
1070
|
+
|
|
1071
|
+
for (size_t j = 0; j < 8; j++) {
|
|
1072
|
+
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
|
|
1073
|
+
min_idx = vidx8[j];
|
|
1074
|
+
min_dis = vdis8[j];
|
|
1075
|
+
}
|
|
1076
|
+
}
|
|
1077
|
+
|
|
1078
|
+
// process last values. Vectorizing is doable, but is not practical
|
|
1079
|
+
for (; iii < k; iii++) {
|
|
1080
|
+
if (ids[iii] != -1 && dis[iii] <= min_dis) {
|
|
1081
|
+
min_dis = dis[iii];
|
|
1082
|
+
min_idx = iii;
|
|
1083
|
+
}
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
if (min_idx == -1) {
|
|
1087
|
+
return -1;
|
|
1088
|
+
}
|
|
1089
|
+
|
|
1090
|
+
if (vmin_out) {
|
|
1091
|
+
*vmin_out = min_dis;
|
|
1092
|
+
}
|
|
1093
|
+
int ret = ids[min_idx];
|
|
1094
|
+
ids[min_idx] = -1;
|
|
1095
|
+
--nvalid;
|
|
1096
|
+
return ret;
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
#else
|
|
1100
|
+
|
|
1101
|
+
// baseline non-vectorized version
|
|
855
1102
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
856
1103
|
assert(k > 0);
|
|
857
1104
|
// returns min. This is an O(n) operation
|
|
858
1105
|
int i = k - 1;
|
|
859
1106
|
while (i >= 0) {
|
|
860
|
-
if (ids[i] != -1)
|
|
1107
|
+
if (ids[i] != -1) {
|
|
861
1108
|
break;
|
|
1109
|
+
}
|
|
862
1110
|
i--;
|
|
863
1111
|
}
|
|
864
|
-
if (i == -1)
|
|
1112
|
+
if (i == -1) {
|
|
865
1113
|
return -1;
|
|
1114
|
+
}
|
|
866
1115
|
int imin = i;
|
|
867
1116
|
float vmin = dis[i];
|
|
868
1117
|
i--;
|
|
@@ -873,14 +1122,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
|
873
1122
|
}
|
|
874
1123
|
i--;
|
|
875
1124
|
}
|
|
876
|
-
if (vmin_out)
|
|
1125
|
+
if (vmin_out) {
|
|
877
1126
|
*vmin_out = vmin;
|
|
1127
|
+
}
|
|
878
1128
|
int ret = ids[imin];
|
|
879
1129
|
ids[imin] = -1;
|
|
880
1130
|
--nvalid;
|
|
881
1131
|
|
|
882
1132
|
return ret;
|
|
883
1133
|
}
|
|
1134
|
+
#endif
|
|
884
1135
|
|
|
885
1136
|
int HNSW::MinimaxHeap::count_below(float thresh) {
|
|
886
1137
|
int n_below = 0;
|
|
@@ -5,8 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#pragma once
|
|
11
9
|
|
|
12
10
|
#include <queue>
|
|
@@ -42,6 +40,8 @@ namespace faiss {
|
|
|
42
40
|
struct VisitedTable;
|
|
43
41
|
struct DistanceComputer; // from AuxIndexStructures
|
|
44
42
|
struct HNSWStats;
|
|
43
|
+
template <class C>
|
|
44
|
+
struct ResultHandler;
|
|
45
45
|
|
|
46
46
|
struct SearchParametersHNSW : SearchParameters {
|
|
47
47
|
int efSearch = 16;
|
|
@@ -54,6 +54,9 @@ struct HNSW {
|
|
|
54
54
|
/// internal storage of vectors (32 bits: this is expensive)
|
|
55
55
|
using storage_idx_t = int32_t;
|
|
56
56
|
|
|
57
|
+
// for now we do only these distances
|
|
58
|
+
using C = CMax<float, int64_t>;
|
|
59
|
+
|
|
57
60
|
typedef std::pair<float, storage_idx_t> Node;
|
|
58
61
|
|
|
59
62
|
/** Heap structure that allows fast
|
|
@@ -195,18 +198,14 @@ struct HNSW {
|
|
|
195
198
|
/// search interface for 1 point, single thread
|
|
196
199
|
HNSWStats search(
|
|
197
200
|
DistanceComputer& qdis,
|
|
198
|
-
|
|
199
|
-
idx_t* I,
|
|
200
|
-
float* D,
|
|
201
|
+
ResultHandler<C>& res,
|
|
201
202
|
VisitedTable& vt,
|
|
202
203
|
const SearchParametersHNSW* params = nullptr) const;
|
|
203
204
|
|
|
204
205
|
/// search only in level 0 from a given vertex
|
|
205
206
|
void search_level_0(
|
|
206
207
|
DistanceComputer& qdis,
|
|
207
|
-
|
|
208
|
-
idx_t* idxi,
|
|
209
|
-
float* simi,
|
|
208
|
+
ResultHandler<C>& res,
|
|
210
209
|
idx_t nprobe,
|
|
211
210
|
const storage_idx_t* nearest_i,
|
|
212
211
|
const float* nearest_d,
|
|
@@ -226,6 +225,8 @@ struct HNSW {
|
|
|
226
225
|
std::priority_queue<NodeDistFarther>& input,
|
|
227
226
|
std::vector<NodeDistFarther>& output,
|
|
228
227
|
int max_size);
|
|
228
|
+
|
|
229
|
+
void permute_entries(const idx_t* map);
|
|
229
230
|
};
|
|
230
231
|
|
|
231
232
|
struct HNSWStats {
|
|
@@ -10,7 +10,7 @@
|
|
|
10
10
|
#include <unordered_set>
|
|
11
11
|
#include <vector>
|
|
12
12
|
|
|
13
|
-
#include <faiss/
|
|
13
|
+
#include <faiss/MetricType.h>
|
|
14
14
|
|
|
15
15
|
/** IDSelector is intended to define a subset of vectors to handle (for removal
|
|
16
16
|
* or as subset to search) */
|
|
@@ -140,7 +140,7 @@ struct IDSelectorAnd : IDSelector {
|
|
|
140
140
|
: lhs(lhs), rhs(rhs) {}
|
|
141
141
|
bool is_member(idx_t id) const final {
|
|
142
142
|
return lhs->is_member(id) && rhs->is_member(id);
|
|
143
|
-
}
|
|
143
|
+
}
|
|
144
144
|
virtual ~IDSelectorAnd() {}
|
|
145
145
|
};
|
|
146
146
|
|
|
@@ -153,7 +153,7 @@ struct IDSelectorOr : IDSelector {
|
|
|
153
153
|
: lhs(lhs), rhs(rhs) {}
|
|
154
154
|
bool is_member(idx_t id) const final {
|
|
155
155
|
return lhs->is_member(id) || rhs->is_member(id);
|
|
156
|
-
}
|
|
156
|
+
}
|
|
157
157
|
virtual ~IDSelectorOr() {}
|
|
158
158
|
};
|
|
159
159
|
|
|
@@ -166,7 +166,7 @@ struct IDSelectorXOr : IDSelector {
|
|
|
166
166
|
: lhs(lhs), rhs(rhs) {}
|
|
167
167
|
bool is_member(idx_t id) const final {
|
|
168
168
|
return lhs->is_member(id) ^ rhs->is_member(id);
|
|
169
|
-
}
|
|
169
|
+
}
|
|
170
170
|
virtual ~IDSelectorXOr() {}
|
|
171
171
|
};
|
|
172
172
|
|
|
@@ -628,7 +628,9 @@ void LocalSearchQuantizer::icm_encode_step(
|
|
|
628
628
|
{
|
|
629
629
|
size_t binary_idx = (other_m + 1) * M * K * K +
|
|
630
630
|
m * K * K + code2 * K + code;
|
|
631
|
-
_mm_prefetch(
|
|
631
|
+
_mm_prefetch(
|
|
632
|
+
(const char*)(binaries + binary_idx),
|
|
633
|
+
_MM_HINT_T0);
|
|
632
634
|
}
|
|
633
635
|
}
|
|
634
636
|
#endif
|