faiss 0.3.1 → 0.3.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -2
- data/vendor/faiss/faiss/AutoTune.h +3 -3
- data/vendor/faiss/faiss/Clustering.cpp +37 -6
- data/vendor/faiss/faiss/Clustering.h +12 -3
- data/vendor/faiss/faiss/IVFlib.cpp +6 -3
- data/vendor/faiss/faiss/IVFlib.h +2 -2
- data/vendor/faiss/faiss/Index.cpp +6 -2
- data/vendor/faiss/faiss/Index.h +30 -8
- data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
- data/vendor/faiss/faiss/IndexBinary.h +8 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
- data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
- data/vendor/faiss/faiss/IndexFastScan.h +11 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
- data/vendor/faiss/faiss/IndexFlat.h +2 -2
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
- data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
- data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
- data/vendor/faiss/faiss/IndexHNSW.h +54 -5
- data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
- data/vendor/faiss/faiss/IndexIDMap.h +5 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
- data/vendor/faiss/faiss/IndexIVF.h +13 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
- data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
- data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
- data/vendor/faiss/faiss/IndexLSH.h +2 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
- data/vendor/faiss/faiss/IndexLattice.h +5 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
- data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
- data/vendor/faiss/faiss/IndexNSG.h +3 -3
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
- data/vendor/faiss/faiss/IndexPQ.h +2 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
- data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
- data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
- data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
- data/vendor/faiss/faiss/IndexRefine.h +9 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexShards.cpp +2 -2
- data/vendor/faiss/faiss/IndexShards.h +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
- data/vendor/faiss/faiss/MatrixStats.h +2 -2
- data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
- data/vendor/faiss/faiss/MetaIndexes.h +2 -2
- data/vendor/faiss/faiss/MetricType.h +9 -4
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +2 -2
- data/vendor/faiss/faiss/clone_index.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
- data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
- data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
- data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
- data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
- data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
- data/vendor/faiss/faiss/impl/FaissException.h +2 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
- data/vendor/faiss/faiss/impl/HNSW.h +55 -24
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
- data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
- data/vendor/faiss/faiss/impl/NSG.h +20 -8
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
- data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
- data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
- data/vendor/faiss/faiss/impl/io.cpp +15 -7
- data/vendor/faiss/faiss/impl/io.h +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +8 -9
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
- data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
- data/vendor/faiss/faiss/index_factory.cpp +51 -34
- data/vendor/faiss/faiss/index_factory.h +2 -2
- data/vendor/faiss/faiss/index_io.h +14 -7
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
- data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
- data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
- data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +107 -2
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
- data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +249 -90
- data/vendor/faiss/faiss/utils/distances.h +8 -8
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
- data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
- data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
- data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
- data/vendor/faiss/faiss/utils/fp16.h +2 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
- data/vendor/faiss/faiss/utils/hamming.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
- data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
- data/vendor/faiss/faiss/utils/random.cpp +45 -2
- data/vendor/faiss/faiss/utils/random.h +27 -2
- data/vendor/faiss/faiss/utils/simdlib.h +12 -3
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
- data/vendor/faiss/faiss/utils/sorting.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +17 -10
- data/vendor/faiss/faiss/utils/utils.h +7 -3
- metadata +22 -11
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
|
|
1
|
-
|
2
|
-
* Copyright (c)
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
3
|
*
|
4
4
|
* This source code is licensed under the MIT license found in the
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
@@ -7,6 +7,7 @@
|
|
7
7
|
|
8
8
|
#include <faiss/impl/HNSW.h>
|
9
9
|
|
10
|
+
#include <cstddef>
|
10
11
|
#include <string>
|
11
12
|
|
12
13
|
#include <faiss/impl/AuxIndexStructures.h>
|
@@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
|
|
110
111
|
level,
|
111
112
|
nb_neighbors(level));
|
112
113
|
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
|
113
|
-
#pragma omp parallel for reduction(
|
114
|
-
|
114
|
+
#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
|
115
|
+
reduction(+ : tot_reciprocal) reduction(+ : n_node)
|
115
116
|
for (int i = 0; i < levels.size(); i++) {
|
116
117
|
if (levels[i] > level) {
|
117
118
|
n_node++;
|
@@ -165,10 +166,10 @@ void HNSW::print_neighbor_stats(int level) const {
|
|
165
166
|
}
|
166
167
|
|
167
168
|
void HNSW::fill_with_random_links(size_t n) {
|
168
|
-
int
|
169
|
+
int max_level_2 = prepare_level_tab(n);
|
169
170
|
RandomGenerator rng2(456);
|
170
171
|
|
171
|
-
for (int level =
|
172
|
+
for (int level = max_level_2 - 1; level >= 0; --level) {
|
172
173
|
std::vector<int> elts;
|
173
174
|
for (int i = 0; i < n; i++) {
|
174
175
|
if (levels[i] > level) {
|
@@ -209,16 +210,16 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
209
210
|
}
|
210
211
|
}
|
211
212
|
|
212
|
-
int
|
213
|
+
int max_level_2 = 0;
|
213
214
|
for (int i = 0; i < n; i++) {
|
214
215
|
int pt_level = levels[i + n0] - 1;
|
215
|
-
if (pt_level >
|
216
|
-
|
216
|
+
if (pt_level > max_level_2)
|
217
|
+
max_level_2 = pt_level;
|
217
218
|
offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
|
218
|
-
neighbors.resize(offsets.back(), -1);
|
219
219
|
}
|
220
|
+
neighbors.resize(offsets.back(), -1);
|
220
221
|
|
221
|
-
return
|
222
|
+
return max_level_2;
|
222
223
|
}
|
223
224
|
|
224
225
|
/** Enumerate vertices from nearest to farthest from query, keep a
|
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
|
|
229
230
|
DistanceComputer& qdis,
|
230
231
|
std::priority_queue<NodeDistFarther>& input,
|
231
232
|
std::vector<NodeDistFarther>& output,
|
232
|
-
int max_size
|
233
|
+
int max_size,
|
234
|
+
bool keep_max_size_level0) {
|
235
|
+
// This prevents number of neighbors at
|
236
|
+
// level 0 from being shrunk to less than 2 * M.
|
237
|
+
// This is essential in making sure
|
238
|
+
// `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
|
239
|
+
std::vector<NodeDistFarther> outsiders;
|
240
|
+
|
233
241
|
while (input.size() > 0) {
|
234
242
|
NodeDistFarther v1 = input.top();
|
235
243
|
input.pop();
|
@@ -250,8 +258,15 @@ void HNSW::shrink_neighbor_list(
|
|
250
258
|
if (output.size() >= max_size) {
|
251
259
|
return;
|
252
260
|
}
|
261
|
+
} else if (keep_max_size_level0) {
|
262
|
+
outsiders.push_back(v1);
|
253
263
|
}
|
254
264
|
}
|
265
|
+
size_t idx = 0;
|
266
|
+
while (keep_max_size_level0 && (output.size() < max_size) &&
|
267
|
+
(idx < outsiders.size())) {
|
268
|
+
output.push_back(outsiders[idx++]);
|
269
|
+
}
|
255
270
|
}
|
256
271
|
|
257
272
|
namespace {
|
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
|
|
268
283
|
void shrink_neighbor_list(
|
269
284
|
DistanceComputer& qdis,
|
270
285
|
std::priority_queue<NodeDistCloser>& resultSet1,
|
271
|
-
int max_size
|
286
|
+
int max_size,
|
287
|
+
bool keep_max_size_level0 = false) {
|
272
288
|
if (resultSet1.size() < max_size) {
|
273
289
|
return;
|
274
290
|
}
|
@@ -280,7 +296,8 @@ void shrink_neighbor_list(
|
|
280
296
|
resultSet1.pop();
|
281
297
|
}
|
282
298
|
|
283
|
-
HNSW::shrink_neighbor_list(
|
299
|
+
HNSW::shrink_neighbor_list(
|
300
|
+
qdis, resultSet, returnlist, max_size, keep_max_size_level0);
|
284
301
|
|
285
302
|
for (NodeDistFarther curen2 : returnlist) {
|
286
303
|
resultSet1.emplace(curen2.d, curen2.id);
|
@@ -294,7 +311,8 @@ void add_link(
|
|
294
311
|
DistanceComputer& qdis,
|
295
312
|
storage_idx_t src,
|
296
313
|
storage_idx_t dest,
|
297
|
-
int level
|
314
|
+
int level,
|
315
|
+
bool keep_max_size_level0 = false) {
|
298
316
|
size_t begin, end;
|
299
317
|
hnsw.neighbor_range(src, level, &begin, &end);
|
300
318
|
if (hnsw.neighbors[end - 1] == -1) {
|
@@ -319,7 +337,7 @@ void add_link(
|
|
319
337
|
resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
|
320
338
|
}
|
321
339
|
|
322
|
-
shrink_neighbor_list(qdis, resultSet, end - begin);
|
340
|
+
shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
|
323
341
|
|
324
342
|
// ...and back
|
325
343
|
size_t i = begin;
|
@@ -333,6 +351,8 @@ void add_link(
|
|
333
351
|
}
|
334
352
|
}
|
335
353
|
|
354
|
+
} // namespace
|
355
|
+
|
336
356
|
/// search neighbors on a single level, starting from an entry point
|
337
357
|
void search_neighbors_to_add(
|
338
358
|
HNSW& hnsw,
|
@@ -341,7 +361,8 @@ void search_neighbors_to_add(
|
|
341
361
|
int entry_point,
|
342
362
|
float d_entry_point,
|
343
363
|
int level,
|
344
|
-
VisitedTable& vt
|
364
|
+
VisitedTable& vt,
|
365
|
+
bool reference_version) {
|
345
366
|
// top is nearest candidate
|
346
367
|
std::priority_queue<NodeDistFarther> candidates;
|
347
368
|
|
@@ -363,62 +384,98 @@ void search_neighbors_to_add(
|
|
363
384
|
// loop over neighbors
|
364
385
|
size_t begin, end;
|
365
386
|
hnsw.neighbor_range(currNode, level, &begin, &end);
|
366
|
-
for (size_t i = begin; i < end; i++) {
|
367
|
-
storage_idx_t nodeId = hnsw.neighbors[i];
|
368
|
-
if (nodeId < 0)
|
369
|
-
break;
|
370
|
-
if (vt.get(nodeId))
|
371
|
-
continue;
|
372
|
-
vt.set(nodeId);
|
373
387
|
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
388
|
+
// The reference version is not used, but kept here because:
|
389
|
+
// 1. It is easier to switch back if the optimized version has a problem
|
390
|
+
// 2. It serves as a starting point for new optimizations
|
391
|
+
// 3. It helps understand the code
|
392
|
+
// 4. It ensures the reference version is still compilable if the
|
393
|
+
// optimized version changes
|
394
|
+
// The reference and the optimized versions' results are compared in
|
395
|
+
// test_hnsw.cpp
|
396
|
+
if (reference_version) {
|
397
|
+
// a reference version
|
398
|
+
for (size_t i = begin; i < end; i++) {
|
399
|
+
storage_idx_t nodeId = hnsw.neighbors[i];
|
400
|
+
if (nodeId < 0)
|
401
|
+
break;
|
402
|
+
if (vt.get(nodeId))
|
403
|
+
continue;
|
404
|
+
vt.set(nodeId);
|
405
|
+
|
406
|
+
float dis = qdis(nodeId);
|
407
|
+
NodeDistFarther evE1(dis, nodeId);
|
408
|
+
|
409
|
+
if (results.size() < hnsw.efConstruction ||
|
410
|
+
results.top().d > dis) {
|
411
|
+
results.emplace(dis, nodeId);
|
412
|
+
candidates.emplace(dis, nodeId);
|
413
|
+
if (results.size() > hnsw.efConstruction) {
|
414
|
+
results.pop();
|
415
|
+
}
|
382
416
|
}
|
383
417
|
}
|
384
|
-
}
|
385
|
-
|
386
|
-
|
387
|
-
|
418
|
+
} else {
|
419
|
+
// a faster version
|
420
|
+
|
421
|
+
// the following version processes 4 neighbors at a time
|
422
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
423
|
+
const float dis) {
|
424
|
+
if (results.size() < hnsw.efConstruction ||
|
425
|
+
results.top().d > dis) {
|
426
|
+
results.emplace(dis, idx);
|
427
|
+
candidates.emplace(dis, idx);
|
428
|
+
if (results.size() > hnsw.efConstruction) {
|
429
|
+
results.pop();
|
430
|
+
}
|
431
|
+
}
|
432
|
+
};
|
388
433
|
|
389
|
-
|
390
|
-
|
391
|
-
**************************************************************/
|
434
|
+
int n_buffered = 0;
|
435
|
+
storage_idx_t buffered_ids[4];
|
392
436
|
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
437
|
+
for (size_t j = begin; j < end; j++) {
|
438
|
+
storage_idx_t nodeId = hnsw.neighbors[j];
|
439
|
+
if (nodeId < 0)
|
440
|
+
break;
|
441
|
+
if (vt.get(nodeId)) {
|
442
|
+
continue;
|
443
|
+
}
|
444
|
+
vt.set(nodeId);
|
445
|
+
|
446
|
+
buffered_ids[n_buffered] = nodeId;
|
447
|
+
n_buffered += 1;
|
448
|
+
|
449
|
+
if (n_buffered == 4) {
|
450
|
+
float dis[4];
|
451
|
+
qdis.distances_batch_4(
|
452
|
+
buffered_ids[0],
|
453
|
+
buffered_ids[1],
|
454
|
+
buffered_ids[2],
|
455
|
+
buffered_ids[3],
|
456
|
+
dis[0],
|
457
|
+
dis[1],
|
458
|
+
dis[2],
|
459
|
+
dis[3]);
|
460
|
+
|
461
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
462
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
463
|
+
}
|
402
464
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
nearest = v;
|
412
|
-
d_nearest = dis;
|
465
|
+
n_buffered = 0;
|
466
|
+
}
|
467
|
+
}
|
468
|
+
|
469
|
+
// process leftovers
|
470
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
471
|
+
float dis = qdis(buffered_ids[icnt]);
|
472
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
413
473
|
}
|
414
|
-
}
|
415
|
-
if (nearest == prev_nearest) {
|
416
|
-
return;
|
417
474
|
}
|
418
475
|
}
|
419
|
-
}
|
420
476
|
|
421
|
-
|
477
|
+
vt.advance();
|
478
|
+
}
|
422
479
|
|
423
480
|
/// Finds neighbors and builds links with them, starting from an entry
|
424
481
|
/// point. The own neighbor list is assumed to be locked.
|
@@ -429,7 +486,8 @@ void HNSW::add_links_starting_from(
|
|
429
486
|
float d_nearest,
|
430
487
|
int level,
|
431
488
|
omp_lock_t* locks,
|
432
|
-
VisitedTable& vt
|
489
|
+
VisitedTable& vt,
|
490
|
+
bool keep_max_size_level0) {
|
433
491
|
std::priority_queue<NodeDistCloser> link_targets;
|
434
492
|
|
435
493
|
search_neighbors_to_add(
|
@@ -438,21 +496,21 @@ void HNSW::add_links_starting_from(
|
|
438
496
|
// but we can afford only this many neighbors
|
439
497
|
int M = nb_neighbors(level);
|
440
498
|
|
441
|
-
::faiss::shrink_neighbor_list(ptdis, link_targets, M);
|
499
|
+
::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
|
442
500
|
|
443
|
-
std::vector<storage_idx_t>
|
444
|
-
|
501
|
+
std::vector<storage_idx_t> neighbors_to_add;
|
502
|
+
neighbors_to_add.reserve(link_targets.size());
|
445
503
|
while (!link_targets.empty()) {
|
446
504
|
storage_idx_t other_id = link_targets.top().id;
|
447
|
-
add_link(*this, ptdis, pt_id, other_id, level);
|
448
|
-
|
505
|
+
add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
|
506
|
+
neighbors_to_add.push_back(other_id);
|
449
507
|
link_targets.pop();
|
450
508
|
}
|
451
509
|
|
452
510
|
omp_unset_lock(&locks[pt_id]);
|
453
|
-
for (storage_idx_t other_id :
|
511
|
+
for (storage_idx_t other_id : neighbors_to_add) {
|
454
512
|
omp_set_lock(&locks[other_id]);
|
455
|
-
add_link(*this, ptdis, other_id, pt_id, level);
|
513
|
+
add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
|
456
514
|
omp_unset_lock(&locks[other_id]);
|
457
515
|
}
|
458
516
|
omp_set_lock(&locks[pt_id]);
|
@@ -467,7 +525,8 @@ void HNSW::add_with_locks(
|
|
467
525
|
int pt_level,
|
468
526
|
int pt_id,
|
469
527
|
std::vector<omp_lock_t>& locks,
|
470
|
-
VisitedTable& vt
|
528
|
+
VisitedTable& vt,
|
529
|
+
bool keep_max_size_level0) {
|
471
530
|
// greedy search on upper levels
|
472
531
|
|
473
532
|
storage_idx_t nearest;
|
@@ -496,7 +555,14 @@ void HNSW::add_with_locks(
|
|
496
555
|
|
497
556
|
for (; level >= 0; level--) {
|
498
557
|
add_links_starting_from(
|
499
|
-
ptdis,
|
558
|
+
ptdis,
|
559
|
+
pt_id,
|
560
|
+
nearest,
|
561
|
+
d_nearest,
|
562
|
+
level,
|
563
|
+
locks.data(),
|
564
|
+
vt,
|
565
|
+
keep_max_size_level0);
|
500
566
|
}
|
501
567
|
|
502
568
|
omp_unset_lock(&locks[pt_id]);
|
@@ -511,12 +577,10 @@ void HNSW::add_with_locks(
|
|
511
577
|
* Searching
|
512
578
|
**************************************************************/
|
513
579
|
|
514
|
-
namespace {
|
515
580
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
516
581
|
using Node = HNSW::Node;
|
517
582
|
using C = HNSW::C;
|
518
583
|
/** Do a BFS on the candidates list */
|
519
|
-
|
520
584
|
int search_from_candidates(
|
521
585
|
const HNSW& hnsw,
|
522
586
|
DistanceComputer& qdis,
|
@@ -525,8 +589,8 @@ int search_from_candidates(
|
|
525
589
|
VisitedTable& vt,
|
526
590
|
HNSWStats& stats,
|
527
591
|
int level,
|
528
|
-
int nres_in
|
529
|
-
const SearchParametersHNSW* params
|
592
|
+
int nres_in,
|
593
|
+
const SearchParametersHNSW* params) {
|
530
594
|
int nres = nres_in;
|
531
595
|
int ndis = 0;
|
532
596
|
|
@@ -571,27 +635,7 @@ int search_from_candidates(
|
|
571
635
|
size_t begin, end;
|
572
636
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
573
637
|
|
574
|
-
//
|
575
|
-
// for (size_t j = begin; j < end; j++) {
|
576
|
-
// int v1 = hnsw.neighbors[j];
|
577
|
-
// if (v1 < 0)
|
578
|
-
// break;
|
579
|
-
// if (vt.get(v1)) {
|
580
|
-
// continue;
|
581
|
-
// }
|
582
|
-
// vt.set(v1);
|
583
|
-
// ndis++;
|
584
|
-
// float d = qdis(v1);
|
585
|
-
// if (!sel || sel->is_member(v1)) {
|
586
|
-
// if (nres < k) {
|
587
|
-
// faiss::maxheap_push(++nres, D, I, d, v1);
|
588
|
-
// } else if (d < D[0]) {
|
589
|
-
// faiss::maxheap_replace_top(nres, D, I, d, v1);
|
590
|
-
// }
|
591
|
-
// }
|
592
|
-
// candidates.push(v1, d);
|
593
|
-
// }
|
594
|
-
|
638
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
595
639
|
// the following version processes 4 neighbors at a time
|
596
640
|
size_t jmax = begin;
|
597
641
|
for (size_t j = begin; j < end; j++) {
|
@@ -606,7 +650,6 @@ int search_from_candidates(
|
|
606
650
|
int counter = 0;
|
607
651
|
size_t saved_j[4];
|
608
652
|
|
609
|
-
ndis += jmax - begin;
|
610
653
|
threshold = res.threshold;
|
611
654
|
|
612
655
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
@@ -614,6 +657,7 @@ int search_from_candidates(
|
|
614
657
|
if (dis < threshold) {
|
615
658
|
if (res.add_result(dis, idx)) {
|
616
659
|
threshold = res.threshold;
|
660
|
+
nres += 1;
|
617
661
|
}
|
618
662
|
}
|
619
663
|
}
|
@@ -644,6 +688,8 @@ int search_from_candidates(
|
|
644
688
|
add_to_heap(saved_j[id4], dis[id4]);
|
645
689
|
}
|
646
690
|
|
691
|
+
ndis += 4;
|
692
|
+
|
647
693
|
counter = 0;
|
648
694
|
}
|
649
695
|
}
|
@@ -651,6 +697,8 @@ int search_from_candidates(
|
|
651
697
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
652
698
|
float dis = qdis(saved_j[icnt]);
|
653
699
|
add_to_heap(saved_j[icnt], dis);
|
700
|
+
|
701
|
+
ndis += 1;
|
654
702
|
}
|
655
703
|
|
656
704
|
nstep++;
|
@@ -664,7 +712,8 @@ int search_from_candidates(
|
|
664
712
|
if (candidates.size() == 0) {
|
665
713
|
stats.n2++;
|
666
714
|
}
|
667
|
-
stats.
|
715
|
+
stats.ndis += ndis;
|
716
|
+
stats.nhops += nstep;
|
668
717
|
}
|
669
718
|
|
670
719
|
return nres;
|
@@ -700,33 +749,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
700
749
|
size_t begin, end;
|
701
750
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
702
751
|
|
703
|
-
//
|
704
|
-
// for (size_t j = begin; j < end; ++j) {
|
705
|
-
// int v1 = hnsw.neighbors[j];
|
706
|
-
//
|
707
|
-
// if (v1 < 0) {
|
708
|
-
// break;
|
709
|
-
// }
|
710
|
-
// if (vt->get(v1)) {
|
711
|
-
// continue;
|
712
|
-
// }
|
713
|
-
//
|
714
|
-
// vt->set(v1);
|
715
|
-
//
|
716
|
-
// float d1 = qdis(v1);
|
717
|
-
// ++ndis;
|
718
|
-
//
|
719
|
-
// if (top_candidates.top().first > d1 ||
|
720
|
-
// top_candidates.size() < ef) {
|
721
|
-
// candidates.emplace(d1, v1);
|
722
|
-
// top_candidates.emplace(d1, v1);
|
723
|
-
//
|
724
|
-
// if (top_candidates.size() > ef) {
|
725
|
-
// top_candidates.pop();
|
726
|
-
// }
|
727
|
-
// }
|
728
|
-
// }
|
729
|
-
|
752
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
730
753
|
// the following version processes 4 neighbors at a time
|
731
754
|
size_t jmax = begin;
|
732
755
|
for (size_t j = begin; j < end; j++) {
|
@@ -741,8 +764,6 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
741
764
|
int counter = 0;
|
742
765
|
size_t saved_j[4];
|
743
766
|
|
744
|
-
ndis += jmax - begin;
|
745
|
-
|
746
767
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
747
768
|
if (top_candidates.top().first > dis ||
|
748
769
|
top_candidates.size() < ef) {
|
@@ -779,6 +800,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
779
800
|
add_to_heap(saved_j[id4], dis[id4]);
|
780
801
|
}
|
781
802
|
|
803
|
+
ndis += 4;
|
804
|
+
|
782
805
|
counter = 0;
|
783
806
|
}
|
784
807
|
}
|
@@ -786,18 +809,102 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
786
809
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
787
810
|
float dis = qdis(saved_j[icnt]);
|
788
811
|
add_to_heap(saved_j[icnt], dis);
|
812
|
+
|
813
|
+
ndis += 1;
|
789
814
|
}
|
815
|
+
|
816
|
+
stats.nhops += 1;
|
790
817
|
}
|
791
818
|
|
792
819
|
++stats.n1;
|
793
820
|
if (candidates.size() == 0) {
|
794
821
|
++stats.n2;
|
795
822
|
}
|
796
|
-
stats.
|
823
|
+
stats.ndis += ndis;
|
797
824
|
|
798
825
|
return top_candidates;
|
799
826
|
}
|
800
827
|
|
828
|
+
/// greedily update a nearest vector at a given level
|
829
|
+
HNSWStats greedy_update_nearest(
|
830
|
+
const HNSW& hnsw,
|
831
|
+
DistanceComputer& qdis,
|
832
|
+
int level,
|
833
|
+
storage_idx_t& nearest,
|
834
|
+
float& d_nearest) {
|
835
|
+
HNSWStats stats;
|
836
|
+
|
837
|
+
for (;;) {
|
838
|
+
storage_idx_t prev_nearest = nearest;
|
839
|
+
|
840
|
+
size_t begin, end;
|
841
|
+
hnsw.neighbor_range(nearest, level, &begin, &end);
|
842
|
+
|
843
|
+
size_t ndis = 0;
|
844
|
+
|
845
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
846
|
+
// the following version processes 4 neighbors at a time
|
847
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
848
|
+
const float dis) {
|
849
|
+
if (dis < d_nearest) {
|
850
|
+
nearest = idx;
|
851
|
+
d_nearest = dis;
|
852
|
+
}
|
853
|
+
};
|
854
|
+
|
855
|
+
int n_buffered = 0;
|
856
|
+
storage_idx_t buffered_ids[4];
|
857
|
+
|
858
|
+
for (size_t j = begin; j < end; j++) {
|
859
|
+
storage_idx_t v = hnsw.neighbors[j];
|
860
|
+
if (v < 0)
|
861
|
+
break;
|
862
|
+
ndis += 1;
|
863
|
+
|
864
|
+
buffered_ids[n_buffered] = v;
|
865
|
+
n_buffered += 1;
|
866
|
+
|
867
|
+
if (n_buffered == 4) {
|
868
|
+
float dis[4];
|
869
|
+
qdis.distances_batch_4(
|
870
|
+
buffered_ids[0],
|
871
|
+
buffered_ids[1],
|
872
|
+
buffered_ids[2],
|
873
|
+
buffered_ids[3],
|
874
|
+
dis[0],
|
875
|
+
dis[1],
|
876
|
+
dis[2],
|
877
|
+
dis[3]);
|
878
|
+
|
879
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
880
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
881
|
+
}
|
882
|
+
|
883
|
+
n_buffered = 0;
|
884
|
+
}
|
885
|
+
}
|
886
|
+
|
887
|
+
// process leftovers
|
888
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
889
|
+
float dis = qdis(buffered_ids[icnt]);
|
890
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
891
|
+
}
|
892
|
+
|
893
|
+
// update stats
|
894
|
+
stats.ndis += ndis;
|
895
|
+
stats.nhops += 1;
|
896
|
+
|
897
|
+
if (nearest == prev_nearest) {
|
898
|
+
return stats;
|
899
|
+
}
|
900
|
+
}
|
901
|
+
}
|
902
|
+
|
903
|
+
namespace {
|
904
|
+
using MinimaxHeap = HNSW::MinimaxHeap;
|
905
|
+
using Node = HNSW::Node;
|
906
|
+
using C = HNSW::C;
|
907
|
+
|
801
908
|
// just used as a lower bound for the minmaxheap, but it is set for heap search
|
802
909
|
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
803
910
|
using RH = HeapBlockResultHandler<C>;
|
@@ -807,7 +914,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
|
807
914
|
return 1;
|
808
915
|
}
|
809
916
|
|
810
|
-
} //
|
917
|
+
} // namespace
|
811
918
|
|
812
919
|
HNSWStats HNSW::search(
|
813
920
|
DistanceComputer& qdis,
|
@@ -820,85 +927,47 @@ HNSWStats HNSW::search(
|
|
820
927
|
}
|
821
928
|
int k = extract_k_from_ResultHandler(res);
|
822
929
|
|
823
|
-
|
824
|
-
|
825
|
-
storage_idx_t nearest = entry_point;
|
826
|
-
float d_nearest = qdis(nearest);
|
827
|
-
|
828
|
-
for (int level = max_level; level >= 1; level--) {
|
829
|
-
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
830
|
-
}
|
831
|
-
|
832
|
-
int ef = std::max(params ? params->efSearch : efSearch, k);
|
833
|
-
if (search_bounded_queue) { // this is the most common branch
|
834
|
-
MinimaxHeap candidates(ef);
|
930
|
+
bool bounded_queue =
|
931
|
+
params ? params->bounded_queue : this->search_bounded_queue;
|
835
932
|
|
836
|
-
|
933
|
+
// greedy search on upper levels
|
934
|
+
storage_idx_t nearest = entry_point;
|
935
|
+
float d_nearest = qdis(nearest);
|
837
936
|
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
*this,
|
844
|
-
Node(d_nearest, nearest),
|
845
|
-
qdis,
|
846
|
-
ef,
|
847
|
-
&vt,
|
848
|
-
stats);
|
849
|
-
|
850
|
-
while (top_candidates.size() > k) {
|
851
|
-
top_candidates.pop();
|
852
|
-
}
|
937
|
+
for (int level = max_level; level >= 1; level--) {
|
938
|
+
HNSWStats local_stats =
|
939
|
+
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
940
|
+
stats.combine(local_stats);
|
941
|
+
}
|
853
942
|
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
std::tie(d, label) = top_candidates.top();
|
858
|
-
res.add_result(d, label);
|
859
|
-
top_candidates.pop();
|
860
|
-
}
|
861
|
-
}
|
943
|
+
int ef = std::max(params ? params->efSearch : efSearch, k);
|
944
|
+
if (bounded_queue) { // this is the most common branch
|
945
|
+
MinimaxHeap candidates(ef);
|
862
946
|
|
863
|
-
|
947
|
+
candidates.push(nearest, d_nearest);
|
864
948
|
|
949
|
+
search_from_candidates(
|
950
|
+
*this, qdis, res, candidates, vt, stats, 0, 0, params);
|
865
951
|
} else {
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
std::vector<idx_t> I_to_next(candidates_size);
|
870
|
-
std::vector<float> D_to_next(candidates_size);
|
871
|
-
|
872
|
-
HeapBlockResultHandler<C> block_resh(
|
873
|
-
1, D_to_next.data(), I_to_next.data(), candidates_size);
|
874
|
-
HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
|
875
|
-
|
876
|
-
int nres = 1;
|
877
|
-
I_to_next[0] = entry_point;
|
878
|
-
D_to_next[0] = qdis(entry_point);
|
879
|
-
|
880
|
-
for (int level = max_level; level >= 0; level--) {
|
881
|
-
// copy I, D -> candidates
|
952
|
+
std::priority_queue<Node> top_candidates =
|
953
|
+
search_from_candidate_unbounded(
|
954
|
+
*this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
|
882
955
|
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
candidates.push(I_to_next[i], D_to_next[i]);
|
887
|
-
}
|
956
|
+
while (top_candidates.size() > k) {
|
957
|
+
top_candidates.pop();
|
958
|
+
}
|
888
959
|
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
*this, qdis, resh, candidates, vt, stats, level);
|
896
|
-
resh.end();
|
897
|
-
}
|
898
|
-
vt.advance();
|
960
|
+
while (!top_candidates.empty()) {
|
961
|
+
float d;
|
962
|
+
storage_idx_t label;
|
963
|
+
std::tie(d, label) = top_candidates.top();
|
964
|
+
res.add_result(d, label);
|
965
|
+
top_candidates.pop();
|
899
966
|
}
|
900
967
|
}
|
901
968
|
|
969
|
+
vt.advance();
|
970
|
+
|
902
971
|
return stats;
|
903
972
|
}
|
904
973
|
|
@@ -910,9 +979,12 @@ void HNSW::search_level_0(
|
|
910
979
|
const float* nearest_d,
|
911
980
|
int search_type,
|
912
981
|
HNSWStats& search_stats,
|
913
|
-
VisitedTable& vt
|
982
|
+
VisitedTable& vt,
|
983
|
+
const SearchParametersHNSW* params) const {
|
914
984
|
const HNSW& hnsw = *this;
|
985
|
+
auto efSearch = params ? params->efSearch : hnsw.efSearch;
|
915
986
|
int k = extract_k_from_ResultHandler(res);
|
987
|
+
|
916
988
|
if (search_type == 1) {
|
917
989
|
int nres = 0;
|
918
990
|
|
@@ -925,16 +997,25 @@ void HNSW::search_level_0(
|
|
925
997
|
if (vt.get(cj))
|
926
998
|
continue;
|
927
999
|
|
928
|
-
int candidates_size = std::max(
|
1000
|
+
int candidates_size = std::max(efSearch, k);
|
929
1001
|
MinimaxHeap candidates(candidates_size);
|
930
1002
|
|
931
1003
|
candidates.push(cj, nearest_d[j]);
|
932
1004
|
|
933
1005
|
nres = search_from_candidates(
|
934
|
-
hnsw,
|
1006
|
+
hnsw,
|
1007
|
+
qdis,
|
1008
|
+
res,
|
1009
|
+
candidates,
|
1010
|
+
vt,
|
1011
|
+
search_stats,
|
1012
|
+
0,
|
1013
|
+
nres,
|
1014
|
+
params);
|
1015
|
+
nres = std::min(nres, candidates_size);
|
935
1016
|
}
|
936
1017
|
} else if (search_type == 2) {
|
937
|
-
int candidates_size = std::max(
|
1018
|
+
int candidates_size = std::max(efSearch, int(k));
|
938
1019
|
candidates_size = std::max(candidates_size, int(nprobe));
|
939
1020
|
|
940
1021
|
MinimaxHeap candidates(candidates_size);
|
@@ -947,7 +1028,7 @@ void HNSW::search_level_0(
|
|
947
1028
|
}
|
948
1029
|
|
949
1030
|
search_from_candidates(
|
950
|
-
hnsw, qdis, res, candidates, vt, search_stats, 0);
|
1031
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
|
951
1032
|
}
|
952
1033
|
}
|
953
1034
|
|
@@ -1013,7 +1094,99 @@ void HNSW::MinimaxHeap::clear() {
|
|
1013
1094
|
nvalid = k = 0;
|
1014
1095
|
}
|
1015
1096
|
|
1016
|
-
#ifdef
|
1097
|
+
#ifdef __AVX512F__
|
1098
|
+
|
1099
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1100
|
+
assert(k > 0);
|
1101
|
+
static_assert(
|
1102
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
1103
|
+
"This code expects storage_idx_t to be int32_t");
|
1104
|
+
|
1105
|
+
int32_t min_idx = -1;
|
1106
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
1107
|
+
|
1108
|
+
__m512i min_indices = _mm512_set1_epi32(-1);
|
1109
|
+
__m512 min_distances =
|
1110
|
+
_mm512_set1_ps(std::numeric_limits<float>::infinity());
|
1111
|
+
__m512i current_indices = _mm512_setr_epi32(
|
1112
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
1113
|
+
__m512i offset = _mm512_set1_epi32(16);
|
1114
|
+
|
1115
|
+
// The following loop tracks the rightmost index with the min distance.
|
1116
|
+
// -1 index values are ignored.
|
1117
|
+
const int k16 = (k / 16) * 16;
|
1118
|
+
for (size_t iii = 0; iii < k16; iii += 16) {
|
1119
|
+
__m512i indices =
|
1120
|
+
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
|
1121
|
+
__m512 distances = _mm512_loadu_ps(dis.data() + iii);
|
1122
|
+
|
1123
|
+
// This mask filters out -1 values among indices.
|
1124
|
+
__mmask16 m1mask =
|
1125
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
1126
|
+
|
1127
|
+
__mmask16 dmask =
|
1128
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
1129
|
+
__mmask16 finalmask = m1mask | dmask;
|
1130
|
+
|
1131
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
1132
|
+
finalmask, current_indices, min_indices);
|
1133
|
+
const __m512 min_distances_new =
|
1134
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
1135
|
+
|
1136
|
+
min_indices = min_indices_new;
|
1137
|
+
min_distances = min_distances_new;
|
1138
|
+
|
1139
|
+
current_indices = _mm512_add_epi32(current_indices, offset);
|
1140
|
+
}
|
1141
|
+
|
1142
|
+
// leftovers
|
1143
|
+
if (k16 != k) {
|
1144
|
+
const __mmask16 kmask = (1 << (k - k16)) - 1;
|
1145
|
+
|
1146
|
+
__m512i indices = _mm512_mask_loadu_epi32(
|
1147
|
+
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
|
1148
|
+
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
|
1149
|
+
|
1150
|
+
// This mask filters out -1 values among indices.
|
1151
|
+
__mmask16 m1mask =
|
1152
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
1153
|
+
|
1154
|
+
__mmask16 dmask =
|
1155
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
1156
|
+
__mmask16 finalmask = m1mask | dmask;
|
1157
|
+
|
1158
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
1159
|
+
finalmask, current_indices, min_indices);
|
1160
|
+
const __m512 min_distances_new =
|
1161
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
1162
|
+
|
1163
|
+
min_indices = min_indices_new;
|
1164
|
+
min_distances = min_distances_new;
|
1165
|
+
}
|
1166
|
+
|
1167
|
+
// grab min distance
|
1168
|
+
min_dis = _mm512_reduce_min_ps(min_distances);
|
1169
|
+
// blend
|
1170
|
+
__mmask16 mindmask =
|
1171
|
+
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
|
1172
|
+
// pick the max one
|
1173
|
+
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
|
1174
|
+
|
1175
|
+
if (min_idx == -1) {
|
1176
|
+
return -1;
|
1177
|
+
}
|
1178
|
+
|
1179
|
+
if (vmin_out) {
|
1180
|
+
*vmin_out = min_dis;
|
1181
|
+
}
|
1182
|
+
int ret = ids[min_idx];
|
1183
|
+
ids[min_idx] = -1;
|
1184
|
+
--nvalid;
|
1185
|
+
return ret;
|
1186
|
+
}
|
1187
|
+
|
1188
|
+
#elif __AVX2__
|
1189
|
+
|
1017
1190
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1018
1191
|
assert(k > 0);
|
1019
1192
|
static_assert(
|