faiss 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -2
- data/vendor/faiss/faiss/AutoTune.h +3 -3
- data/vendor/faiss/faiss/Clustering.cpp +37 -6
- data/vendor/faiss/faiss/Clustering.h +12 -3
- data/vendor/faiss/faiss/IVFlib.cpp +6 -3
- data/vendor/faiss/faiss/IVFlib.h +2 -2
- data/vendor/faiss/faiss/Index.cpp +6 -2
- data/vendor/faiss/faiss/Index.h +30 -8
- data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
- data/vendor/faiss/faiss/IndexBinary.h +8 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
- data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
- data/vendor/faiss/faiss/IndexFastScan.h +11 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
- data/vendor/faiss/faiss/IndexFlat.h +2 -2
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
- data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
- data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
- data/vendor/faiss/faiss/IndexHNSW.h +54 -5
- data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
- data/vendor/faiss/faiss/IndexIDMap.h +5 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
- data/vendor/faiss/faiss/IndexIVF.h +13 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
- data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
- data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
- data/vendor/faiss/faiss/IndexLSH.h +2 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
- data/vendor/faiss/faiss/IndexLattice.h +5 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
- data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
- data/vendor/faiss/faiss/IndexNSG.h +3 -3
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
- data/vendor/faiss/faiss/IndexPQ.h +2 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
- data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
- data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
- data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
- data/vendor/faiss/faiss/IndexRefine.h +9 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexShards.cpp +2 -2
- data/vendor/faiss/faiss/IndexShards.h +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
- data/vendor/faiss/faiss/MatrixStats.h +2 -2
- data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
- data/vendor/faiss/faiss/MetaIndexes.h +2 -2
- data/vendor/faiss/faiss/MetricType.h +9 -4
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +2 -2
- data/vendor/faiss/faiss/clone_index.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
- data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
- data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
- data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
- data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
- data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
- data/vendor/faiss/faiss/impl/FaissException.h +2 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
- data/vendor/faiss/faiss/impl/HNSW.h +55 -24
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
- data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
- data/vendor/faiss/faiss/impl/NSG.h +20 -8
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
- data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
- data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
- data/vendor/faiss/faiss/impl/io.cpp +15 -7
- data/vendor/faiss/faiss/impl/io.h +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +8 -9
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
- data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
- data/vendor/faiss/faiss/index_factory.cpp +51 -34
- data/vendor/faiss/faiss/index_factory.h +2 -2
- data/vendor/faiss/faiss/index_io.h +14 -7
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
- data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
- data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
- data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +107 -2
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
- data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +249 -90
- data/vendor/faiss/faiss/utils/distances.h +8 -8
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
- data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
- data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
- data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
- data/vendor/faiss/faiss/utils/fp16.h +2 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
- data/vendor/faiss/faiss/utils/hamming.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
- data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
- data/vendor/faiss/faiss/utils/random.cpp +45 -2
- data/vendor/faiss/faiss/utils/random.h +27 -2
- data/vendor/faiss/faiss/utils/simdlib.h +12 -3
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
- data/vendor/faiss/faiss/utils/sorting.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +17 -10
- data/vendor/faiss/faiss/utils/utils.h +7 -3
- metadata +22 -11
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
|
|
1
|
-
|
2
|
-
* Copyright (c)
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
3
|
*
|
4
4
|
* This source code is licensed under the MIT license found in the
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
@@ -7,6 +7,7 @@
|
|
7
7
|
|
8
8
|
#include <faiss/impl/HNSW.h>
|
9
9
|
|
10
|
+
#include <cstddef>
|
10
11
|
#include <string>
|
11
12
|
|
12
13
|
#include <faiss/impl/AuxIndexStructures.h>
|
@@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
|
|
110
111
|
level,
|
111
112
|
nb_neighbors(level));
|
112
113
|
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
|
113
|
-
#pragma omp parallel for reduction(
|
114
|
-
|
114
|
+
#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
|
115
|
+
reduction(+ : tot_reciprocal) reduction(+ : n_node)
|
115
116
|
for (int i = 0; i < levels.size(); i++) {
|
116
117
|
if (levels[i] > level) {
|
117
118
|
n_node++;
|
@@ -165,10 +166,10 @@ void HNSW::print_neighbor_stats(int level) const {
|
|
165
166
|
}
|
166
167
|
|
167
168
|
void HNSW::fill_with_random_links(size_t n) {
|
168
|
-
int
|
169
|
+
int max_level_2 = prepare_level_tab(n);
|
169
170
|
RandomGenerator rng2(456);
|
170
171
|
|
171
|
-
for (int level =
|
172
|
+
for (int level = max_level_2 - 1; level >= 0; --level) {
|
172
173
|
std::vector<int> elts;
|
173
174
|
for (int i = 0; i < n; i++) {
|
174
175
|
if (levels[i] > level) {
|
@@ -209,16 +210,16 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
209
210
|
}
|
210
211
|
}
|
211
212
|
|
212
|
-
int
|
213
|
+
int max_level_2 = 0;
|
213
214
|
for (int i = 0; i < n; i++) {
|
214
215
|
int pt_level = levels[i + n0] - 1;
|
215
|
-
if (pt_level >
|
216
|
-
|
216
|
+
if (pt_level > max_level_2)
|
217
|
+
max_level_2 = pt_level;
|
217
218
|
offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
|
218
|
-
neighbors.resize(offsets.back(), -1);
|
219
219
|
}
|
220
|
+
neighbors.resize(offsets.back(), -1);
|
220
221
|
|
221
|
-
return
|
222
|
+
return max_level_2;
|
222
223
|
}
|
223
224
|
|
224
225
|
/** Enumerate vertices from nearest to farthest from query, keep a
|
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
|
|
229
230
|
DistanceComputer& qdis,
|
230
231
|
std::priority_queue<NodeDistFarther>& input,
|
231
232
|
std::vector<NodeDistFarther>& output,
|
232
|
-
int max_size
|
233
|
+
int max_size,
|
234
|
+
bool keep_max_size_level0) {
|
235
|
+
// This prevents number of neighbors at
|
236
|
+
// level 0 from being shrunk to less than 2 * M.
|
237
|
+
// This is essential in making sure
|
238
|
+
// `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
|
239
|
+
std::vector<NodeDistFarther> outsiders;
|
240
|
+
|
233
241
|
while (input.size() > 0) {
|
234
242
|
NodeDistFarther v1 = input.top();
|
235
243
|
input.pop();
|
@@ -250,8 +258,15 @@ void HNSW::shrink_neighbor_list(
|
|
250
258
|
if (output.size() >= max_size) {
|
251
259
|
return;
|
252
260
|
}
|
261
|
+
} else if (keep_max_size_level0) {
|
262
|
+
outsiders.push_back(v1);
|
253
263
|
}
|
254
264
|
}
|
265
|
+
size_t idx = 0;
|
266
|
+
while (keep_max_size_level0 && (output.size() < max_size) &&
|
267
|
+
(idx < outsiders.size())) {
|
268
|
+
output.push_back(outsiders[idx++]);
|
269
|
+
}
|
255
270
|
}
|
256
271
|
|
257
272
|
namespace {
|
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
|
|
268
283
|
void shrink_neighbor_list(
|
269
284
|
DistanceComputer& qdis,
|
270
285
|
std::priority_queue<NodeDistCloser>& resultSet1,
|
271
|
-
int max_size
|
286
|
+
int max_size,
|
287
|
+
bool keep_max_size_level0 = false) {
|
272
288
|
if (resultSet1.size() < max_size) {
|
273
289
|
return;
|
274
290
|
}
|
@@ -280,7 +296,8 @@ void shrink_neighbor_list(
|
|
280
296
|
resultSet1.pop();
|
281
297
|
}
|
282
298
|
|
283
|
-
HNSW::shrink_neighbor_list(
|
299
|
+
HNSW::shrink_neighbor_list(
|
300
|
+
qdis, resultSet, returnlist, max_size, keep_max_size_level0);
|
284
301
|
|
285
302
|
for (NodeDistFarther curen2 : returnlist) {
|
286
303
|
resultSet1.emplace(curen2.d, curen2.id);
|
@@ -294,7 +311,8 @@ void add_link(
|
|
294
311
|
DistanceComputer& qdis,
|
295
312
|
storage_idx_t src,
|
296
313
|
storage_idx_t dest,
|
297
|
-
int level
|
314
|
+
int level,
|
315
|
+
bool keep_max_size_level0 = false) {
|
298
316
|
size_t begin, end;
|
299
317
|
hnsw.neighbor_range(src, level, &begin, &end);
|
300
318
|
if (hnsw.neighbors[end - 1] == -1) {
|
@@ -319,7 +337,7 @@ void add_link(
|
|
319
337
|
resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
|
320
338
|
}
|
321
339
|
|
322
|
-
shrink_neighbor_list(qdis, resultSet, end - begin);
|
340
|
+
shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
|
323
341
|
|
324
342
|
// ...and back
|
325
343
|
size_t i = begin;
|
@@ -333,6 +351,8 @@ void add_link(
|
|
333
351
|
}
|
334
352
|
}
|
335
353
|
|
354
|
+
} // namespace
|
355
|
+
|
336
356
|
/// search neighbors on a single level, starting from an entry point
|
337
357
|
void search_neighbors_to_add(
|
338
358
|
HNSW& hnsw,
|
@@ -341,7 +361,8 @@ void search_neighbors_to_add(
|
|
341
361
|
int entry_point,
|
342
362
|
float d_entry_point,
|
343
363
|
int level,
|
344
|
-
VisitedTable& vt
|
364
|
+
VisitedTable& vt,
|
365
|
+
bool reference_version) {
|
345
366
|
// top is nearest candidate
|
346
367
|
std::priority_queue<NodeDistFarther> candidates;
|
347
368
|
|
@@ -363,62 +384,98 @@ void search_neighbors_to_add(
|
|
363
384
|
// loop over neighbors
|
364
385
|
size_t begin, end;
|
365
386
|
hnsw.neighbor_range(currNode, level, &begin, &end);
|
366
|
-
for (size_t i = begin; i < end; i++) {
|
367
|
-
storage_idx_t nodeId = hnsw.neighbors[i];
|
368
|
-
if (nodeId < 0)
|
369
|
-
break;
|
370
|
-
if (vt.get(nodeId))
|
371
|
-
continue;
|
372
|
-
vt.set(nodeId);
|
373
387
|
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
388
|
+
// The reference version is not used, but kept here because:
|
389
|
+
// 1. It is easier to switch back if the optimized version has a problem
|
390
|
+
// 2. It serves as a starting point for new optimizations
|
391
|
+
// 3. It helps understand the code
|
392
|
+
// 4. It ensures the reference version is still compilable if the
|
393
|
+
// optimized version changes
|
394
|
+
// The reference and the optimized versions' results are compared in
|
395
|
+
// test_hnsw.cpp
|
396
|
+
if (reference_version) {
|
397
|
+
// a reference version
|
398
|
+
for (size_t i = begin; i < end; i++) {
|
399
|
+
storage_idx_t nodeId = hnsw.neighbors[i];
|
400
|
+
if (nodeId < 0)
|
401
|
+
break;
|
402
|
+
if (vt.get(nodeId))
|
403
|
+
continue;
|
404
|
+
vt.set(nodeId);
|
405
|
+
|
406
|
+
float dis = qdis(nodeId);
|
407
|
+
NodeDistFarther evE1(dis, nodeId);
|
408
|
+
|
409
|
+
if (results.size() < hnsw.efConstruction ||
|
410
|
+
results.top().d > dis) {
|
411
|
+
results.emplace(dis, nodeId);
|
412
|
+
candidates.emplace(dis, nodeId);
|
413
|
+
if (results.size() > hnsw.efConstruction) {
|
414
|
+
results.pop();
|
415
|
+
}
|
382
416
|
}
|
383
417
|
}
|
384
|
-
}
|
385
|
-
|
386
|
-
|
387
|
-
|
418
|
+
} else {
|
419
|
+
// a faster version
|
420
|
+
|
421
|
+
// the following version processes 4 neighbors at a time
|
422
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
423
|
+
const float dis) {
|
424
|
+
if (results.size() < hnsw.efConstruction ||
|
425
|
+
results.top().d > dis) {
|
426
|
+
results.emplace(dis, idx);
|
427
|
+
candidates.emplace(dis, idx);
|
428
|
+
if (results.size() > hnsw.efConstruction) {
|
429
|
+
results.pop();
|
430
|
+
}
|
431
|
+
}
|
432
|
+
};
|
388
433
|
|
389
|
-
|
390
|
-
|
391
|
-
**************************************************************/
|
434
|
+
int n_buffered = 0;
|
435
|
+
storage_idx_t buffered_ids[4];
|
392
436
|
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
437
|
+
for (size_t j = begin; j < end; j++) {
|
438
|
+
storage_idx_t nodeId = hnsw.neighbors[j];
|
439
|
+
if (nodeId < 0)
|
440
|
+
break;
|
441
|
+
if (vt.get(nodeId)) {
|
442
|
+
continue;
|
443
|
+
}
|
444
|
+
vt.set(nodeId);
|
445
|
+
|
446
|
+
buffered_ids[n_buffered] = nodeId;
|
447
|
+
n_buffered += 1;
|
448
|
+
|
449
|
+
if (n_buffered == 4) {
|
450
|
+
float dis[4];
|
451
|
+
qdis.distances_batch_4(
|
452
|
+
buffered_ids[0],
|
453
|
+
buffered_ids[1],
|
454
|
+
buffered_ids[2],
|
455
|
+
buffered_ids[3],
|
456
|
+
dis[0],
|
457
|
+
dis[1],
|
458
|
+
dis[2],
|
459
|
+
dis[3]);
|
460
|
+
|
461
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
462
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
463
|
+
}
|
402
464
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
nearest = v;
|
412
|
-
d_nearest = dis;
|
465
|
+
n_buffered = 0;
|
466
|
+
}
|
467
|
+
}
|
468
|
+
|
469
|
+
// process leftovers
|
470
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
471
|
+
float dis = qdis(buffered_ids[icnt]);
|
472
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
413
473
|
}
|
414
|
-
}
|
415
|
-
if (nearest == prev_nearest) {
|
416
|
-
return;
|
417
474
|
}
|
418
475
|
}
|
419
|
-
}
|
420
476
|
|
421
|
-
|
477
|
+
vt.advance();
|
478
|
+
}
|
422
479
|
|
423
480
|
/// Finds neighbors and builds links with them, starting from an entry
|
424
481
|
/// point. The own neighbor list is assumed to be locked.
|
@@ -429,7 +486,8 @@ void HNSW::add_links_starting_from(
|
|
429
486
|
float d_nearest,
|
430
487
|
int level,
|
431
488
|
omp_lock_t* locks,
|
432
|
-
VisitedTable& vt
|
489
|
+
VisitedTable& vt,
|
490
|
+
bool keep_max_size_level0) {
|
433
491
|
std::priority_queue<NodeDistCloser> link_targets;
|
434
492
|
|
435
493
|
search_neighbors_to_add(
|
@@ -438,21 +496,21 @@ void HNSW::add_links_starting_from(
|
|
438
496
|
// but we can afford only this many neighbors
|
439
497
|
int M = nb_neighbors(level);
|
440
498
|
|
441
|
-
::faiss::shrink_neighbor_list(ptdis, link_targets, M);
|
499
|
+
::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
|
442
500
|
|
443
|
-
std::vector<storage_idx_t>
|
444
|
-
|
501
|
+
std::vector<storage_idx_t> neighbors_to_add;
|
502
|
+
neighbors_to_add.reserve(link_targets.size());
|
445
503
|
while (!link_targets.empty()) {
|
446
504
|
storage_idx_t other_id = link_targets.top().id;
|
447
|
-
add_link(*this, ptdis, pt_id, other_id, level);
|
448
|
-
|
505
|
+
add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
|
506
|
+
neighbors_to_add.push_back(other_id);
|
449
507
|
link_targets.pop();
|
450
508
|
}
|
451
509
|
|
452
510
|
omp_unset_lock(&locks[pt_id]);
|
453
|
-
for (storage_idx_t other_id :
|
511
|
+
for (storage_idx_t other_id : neighbors_to_add) {
|
454
512
|
omp_set_lock(&locks[other_id]);
|
455
|
-
add_link(*this, ptdis, other_id, pt_id, level);
|
513
|
+
add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
|
456
514
|
omp_unset_lock(&locks[other_id]);
|
457
515
|
}
|
458
516
|
omp_set_lock(&locks[pt_id]);
|
@@ -467,7 +525,8 @@ void HNSW::add_with_locks(
|
|
467
525
|
int pt_level,
|
468
526
|
int pt_id,
|
469
527
|
std::vector<omp_lock_t>& locks,
|
470
|
-
VisitedTable& vt
|
528
|
+
VisitedTable& vt,
|
529
|
+
bool keep_max_size_level0) {
|
471
530
|
// greedy search on upper levels
|
472
531
|
|
473
532
|
storage_idx_t nearest;
|
@@ -496,7 +555,14 @@ void HNSW::add_with_locks(
|
|
496
555
|
|
497
556
|
for (; level >= 0; level--) {
|
498
557
|
add_links_starting_from(
|
499
|
-
ptdis,
|
558
|
+
ptdis,
|
559
|
+
pt_id,
|
560
|
+
nearest,
|
561
|
+
d_nearest,
|
562
|
+
level,
|
563
|
+
locks.data(),
|
564
|
+
vt,
|
565
|
+
keep_max_size_level0);
|
500
566
|
}
|
501
567
|
|
502
568
|
omp_unset_lock(&locks[pt_id]);
|
@@ -511,12 +577,10 @@ void HNSW::add_with_locks(
|
|
511
577
|
* Searching
|
512
578
|
**************************************************************/
|
513
579
|
|
514
|
-
namespace {
|
515
580
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
516
581
|
using Node = HNSW::Node;
|
517
582
|
using C = HNSW::C;
|
518
583
|
/** Do a BFS on the candidates list */
|
519
|
-
|
520
584
|
int search_from_candidates(
|
521
585
|
const HNSW& hnsw,
|
522
586
|
DistanceComputer& qdis,
|
@@ -525,8 +589,8 @@ int search_from_candidates(
|
|
525
589
|
VisitedTable& vt,
|
526
590
|
HNSWStats& stats,
|
527
591
|
int level,
|
528
|
-
int nres_in
|
529
|
-
const SearchParametersHNSW* params
|
592
|
+
int nres_in,
|
593
|
+
const SearchParametersHNSW* params) {
|
530
594
|
int nres = nres_in;
|
531
595
|
int ndis = 0;
|
532
596
|
|
@@ -571,27 +635,7 @@ int search_from_candidates(
|
|
571
635
|
size_t begin, end;
|
572
636
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
573
637
|
|
574
|
-
//
|
575
|
-
// for (size_t j = begin; j < end; j++) {
|
576
|
-
// int v1 = hnsw.neighbors[j];
|
577
|
-
// if (v1 < 0)
|
578
|
-
// break;
|
579
|
-
// if (vt.get(v1)) {
|
580
|
-
// continue;
|
581
|
-
// }
|
582
|
-
// vt.set(v1);
|
583
|
-
// ndis++;
|
584
|
-
// float d = qdis(v1);
|
585
|
-
// if (!sel || sel->is_member(v1)) {
|
586
|
-
// if (nres < k) {
|
587
|
-
// faiss::maxheap_push(++nres, D, I, d, v1);
|
588
|
-
// } else if (d < D[0]) {
|
589
|
-
// faiss::maxheap_replace_top(nres, D, I, d, v1);
|
590
|
-
// }
|
591
|
-
// }
|
592
|
-
// candidates.push(v1, d);
|
593
|
-
// }
|
594
|
-
|
638
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
595
639
|
// the following version processes 4 neighbors at a time
|
596
640
|
size_t jmax = begin;
|
597
641
|
for (size_t j = begin; j < end; j++) {
|
@@ -606,7 +650,6 @@ int search_from_candidates(
|
|
606
650
|
int counter = 0;
|
607
651
|
size_t saved_j[4];
|
608
652
|
|
609
|
-
ndis += jmax - begin;
|
610
653
|
threshold = res.threshold;
|
611
654
|
|
612
655
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
@@ -614,6 +657,7 @@ int search_from_candidates(
|
|
614
657
|
if (dis < threshold) {
|
615
658
|
if (res.add_result(dis, idx)) {
|
616
659
|
threshold = res.threshold;
|
660
|
+
nres += 1;
|
617
661
|
}
|
618
662
|
}
|
619
663
|
}
|
@@ -644,6 +688,8 @@ int search_from_candidates(
|
|
644
688
|
add_to_heap(saved_j[id4], dis[id4]);
|
645
689
|
}
|
646
690
|
|
691
|
+
ndis += 4;
|
692
|
+
|
647
693
|
counter = 0;
|
648
694
|
}
|
649
695
|
}
|
@@ -651,6 +697,8 @@ int search_from_candidates(
|
|
651
697
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
652
698
|
float dis = qdis(saved_j[icnt]);
|
653
699
|
add_to_heap(saved_j[icnt], dis);
|
700
|
+
|
701
|
+
ndis += 1;
|
654
702
|
}
|
655
703
|
|
656
704
|
nstep++;
|
@@ -664,7 +712,8 @@ int search_from_candidates(
|
|
664
712
|
if (candidates.size() == 0) {
|
665
713
|
stats.n2++;
|
666
714
|
}
|
667
|
-
stats.
|
715
|
+
stats.ndis += ndis;
|
716
|
+
stats.nhops += nstep;
|
668
717
|
}
|
669
718
|
|
670
719
|
return nres;
|
@@ -700,33 +749,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
700
749
|
size_t begin, end;
|
701
750
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
702
751
|
|
703
|
-
//
|
704
|
-
// for (size_t j = begin; j < end; ++j) {
|
705
|
-
// int v1 = hnsw.neighbors[j];
|
706
|
-
//
|
707
|
-
// if (v1 < 0) {
|
708
|
-
// break;
|
709
|
-
// }
|
710
|
-
// if (vt->get(v1)) {
|
711
|
-
// continue;
|
712
|
-
// }
|
713
|
-
//
|
714
|
-
// vt->set(v1);
|
715
|
-
//
|
716
|
-
// float d1 = qdis(v1);
|
717
|
-
// ++ndis;
|
718
|
-
//
|
719
|
-
// if (top_candidates.top().first > d1 ||
|
720
|
-
// top_candidates.size() < ef) {
|
721
|
-
// candidates.emplace(d1, v1);
|
722
|
-
// top_candidates.emplace(d1, v1);
|
723
|
-
//
|
724
|
-
// if (top_candidates.size() > ef) {
|
725
|
-
// top_candidates.pop();
|
726
|
-
// }
|
727
|
-
// }
|
728
|
-
// }
|
729
|
-
|
752
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
730
753
|
// the following version processes 4 neighbors at a time
|
731
754
|
size_t jmax = begin;
|
732
755
|
for (size_t j = begin; j < end; j++) {
|
@@ -741,8 +764,6 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
741
764
|
int counter = 0;
|
742
765
|
size_t saved_j[4];
|
743
766
|
|
744
|
-
ndis += jmax - begin;
|
745
|
-
|
746
767
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
747
768
|
if (top_candidates.top().first > dis ||
|
748
769
|
top_candidates.size() < ef) {
|
@@ -779,6 +800,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
779
800
|
add_to_heap(saved_j[id4], dis[id4]);
|
780
801
|
}
|
781
802
|
|
803
|
+
ndis += 4;
|
804
|
+
|
782
805
|
counter = 0;
|
783
806
|
}
|
784
807
|
}
|
@@ -786,18 +809,102 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
786
809
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
787
810
|
float dis = qdis(saved_j[icnt]);
|
788
811
|
add_to_heap(saved_j[icnt], dis);
|
812
|
+
|
813
|
+
ndis += 1;
|
789
814
|
}
|
815
|
+
|
816
|
+
stats.nhops += 1;
|
790
817
|
}
|
791
818
|
|
792
819
|
++stats.n1;
|
793
820
|
if (candidates.size() == 0) {
|
794
821
|
++stats.n2;
|
795
822
|
}
|
796
|
-
stats.
|
823
|
+
stats.ndis += ndis;
|
797
824
|
|
798
825
|
return top_candidates;
|
799
826
|
}
|
800
827
|
|
828
|
+
/// greedily update a nearest vector at a given level
|
829
|
+
HNSWStats greedy_update_nearest(
|
830
|
+
const HNSW& hnsw,
|
831
|
+
DistanceComputer& qdis,
|
832
|
+
int level,
|
833
|
+
storage_idx_t& nearest,
|
834
|
+
float& d_nearest) {
|
835
|
+
HNSWStats stats;
|
836
|
+
|
837
|
+
for (;;) {
|
838
|
+
storage_idx_t prev_nearest = nearest;
|
839
|
+
|
840
|
+
size_t begin, end;
|
841
|
+
hnsw.neighbor_range(nearest, level, &begin, &end);
|
842
|
+
|
843
|
+
size_t ndis = 0;
|
844
|
+
|
845
|
+
// a faster version: reference version in unit test test_hnsw.cpp
|
846
|
+
// the following version processes 4 neighbors at a time
|
847
|
+
auto update_with_candidate = [&](const storage_idx_t idx,
|
848
|
+
const float dis) {
|
849
|
+
if (dis < d_nearest) {
|
850
|
+
nearest = idx;
|
851
|
+
d_nearest = dis;
|
852
|
+
}
|
853
|
+
};
|
854
|
+
|
855
|
+
int n_buffered = 0;
|
856
|
+
storage_idx_t buffered_ids[4];
|
857
|
+
|
858
|
+
for (size_t j = begin; j < end; j++) {
|
859
|
+
storage_idx_t v = hnsw.neighbors[j];
|
860
|
+
if (v < 0)
|
861
|
+
break;
|
862
|
+
ndis += 1;
|
863
|
+
|
864
|
+
buffered_ids[n_buffered] = v;
|
865
|
+
n_buffered += 1;
|
866
|
+
|
867
|
+
if (n_buffered == 4) {
|
868
|
+
float dis[4];
|
869
|
+
qdis.distances_batch_4(
|
870
|
+
buffered_ids[0],
|
871
|
+
buffered_ids[1],
|
872
|
+
buffered_ids[2],
|
873
|
+
buffered_ids[3],
|
874
|
+
dis[0],
|
875
|
+
dis[1],
|
876
|
+
dis[2],
|
877
|
+
dis[3]);
|
878
|
+
|
879
|
+
for (size_t id4 = 0; id4 < 4; id4++) {
|
880
|
+
update_with_candidate(buffered_ids[id4], dis[id4]);
|
881
|
+
}
|
882
|
+
|
883
|
+
n_buffered = 0;
|
884
|
+
}
|
885
|
+
}
|
886
|
+
|
887
|
+
// process leftovers
|
888
|
+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
|
889
|
+
float dis = qdis(buffered_ids[icnt]);
|
890
|
+
update_with_candidate(buffered_ids[icnt], dis);
|
891
|
+
}
|
892
|
+
|
893
|
+
// update stats
|
894
|
+
stats.ndis += ndis;
|
895
|
+
stats.nhops += 1;
|
896
|
+
|
897
|
+
if (nearest == prev_nearest) {
|
898
|
+
return stats;
|
899
|
+
}
|
900
|
+
}
|
901
|
+
}
|
902
|
+
|
903
|
+
namespace {
|
904
|
+
using MinimaxHeap = HNSW::MinimaxHeap;
|
905
|
+
using Node = HNSW::Node;
|
906
|
+
using C = HNSW::C;
|
907
|
+
|
801
908
|
// just used as a lower bound for the minmaxheap, but it is set for heap search
|
802
909
|
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
803
910
|
using RH = HeapBlockResultHandler<C>;
|
@@ -807,7 +914,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
|
|
807
914
|
return 1;
|
808
915
|
}
|
809
916
|
|
810
|
-
} //
|
917
|
+
} // namespace
|
811
918
|
|
812
919
|
HNSWStats HNSW::search(
|
813
920
|
DistanceComputer& qdis,
|
@@ -820,85 +927,47 @@ HNSWStats HNSW::search(
|
|
820
927
|
}
|
821
928
|
int k = extract_k_from_ResultHandler(res);
|
822
929
|
|
823
|
-
|
824
|
-
|
825
|
-
storage_idx_t nearest = entry_point;
|
826
|
-
float d_nearest = qdis(nearest);
|
827
|
-
|
828
|
-
for (int level = max_level; level >= 1; level--) {
|
829
|
-
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
830
|
-
}
|
831
|
-
|
832
|
-
int ef = std::max(params ? params->efSearch : efSearch, k);
|
833
|
-
if (search_bounded_queue) { // this is the most common branch
|
834
|
-
MinimaxHeap candidates(ef);
|
930
|
+
bool bounded_queue =
|
931
|
+
params ? params->bounded_queue : this->search_bounded_queue;
|
835
932
|
|
836
|
-
|
933
|
+
// greedy search on upper levels
|
934
|
+
storage_idx_t nearest = entry_point;
|
935
|
+
float d_nearest = qdis(nearest);
|
837
936
|
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
*this,
|
844
|
-
Node(d_nearest, nearest),
|
845
|
-
qdis,
|
846
|
-
ef,
|
847
|
-
&vt,
|
848
|
-
stats);
|
849
|
-
|
850
|
-
while (top_candidates.size() > k) {
|
851
|
-
top_candidates.pop();
|
852
|
-
}
|
937
|
+
for (int level = max_level; level >= 1; level--) {
|
938
|
+
HNSWStats local_stats =
|
939
|
+
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
|
940
|
+
stats.combine(local_stats);
|
941
|
+
}
|
853
942
|
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
std::tie(d, label) = top_candidates.top();
|
858
|
-
res.add_result(d, label);
|
859
|
-
top_candidates.pop();
|
860
|
-
}
|
861
|
-
}
|
943
|
+
int ef = std::max(params ? params->efSearch : efSearch, k);
|
944
|
+
if (bounded_queue) { // this is the most common branch
|
945
|
+
MinimaxHeap candidates(ef);
|
862
946
|
|
863
|
-
|
947
|
+
candidates.push(nearest, d_nearest);
|
864
948
|
|
949
|
+
search_from_candidates(
|
950
|
+
*this, qdis, res, candidates, vt, stats, 0, 0, params);
|
865
951
|
} else {
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
std::vector<idx_t> I_to_next(candidates_size);
|
870
|
-
std::vector<float> D_to_next(candidates_size);
|
871
|
-
|
872
|
-
HeapBlockResultHandler<C> block_resh(
|
873
|
-
1, D_to_next.data(), I_to_next.data(), candidates_size);
|
874
|
-
HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
|
875
|
-
|
876
|
-
int nres = 1;
|
877
|
-
I_to_next[0] = entry_point;
|
878
|
-
D_to_next[0] = qdis(entry_point);
|
879
|
-
|
880
|
-
for (int level = max_level; level >= 0; level--) {
|
881
|
-
// copy I, D -> candidates
|
952
|
+
std::priority_queue<Node> top_candidates =
|
953
|
+
search_from_candidate_unbounded(
|
954
|
+
*this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
|
882
955
|
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
candidates.push(I_to_next[i], D_to_next[i]);
|
887
|
-
}
|
956
|
+
while (top_candidates.size() > k) {
|
957
|
+
top_candidates.pop();
|
958
|
+
}
|
888
959
|
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
*this, qdis, resh, candidates, vt, stats, level);
|
896
|
-
resh.end();
|
897
|
-
}
|
898
|
-
vt.advance();
|
960
|
+
while (!top_candidates.empty()) {
|
961
|
+
float d;
|
962
|
+
storage_idx_t label;
|
963
|
+
std::tie(d, label) = top_candidates.top();
|
964
|
+
res.add_result(d, label);
|
965
|
+
top_candidates.pop();
|
899
966
|
}
|
900
967
|
}
|
901
968
|
|
969
|
+
vt.advance();
|
970
|
+
|
902
971
|
return stats;
|
903
972
|
}
|
904
973
|
|
@@ -910,9 +979,12 @@ void HNSW::search_level_0(
|
|
910
979
|
const float* nearest_d,
|
911
980
|
int search_type,
|
912
981
|
HNSWStats& search_stats,
|
913
|
-
VisitedTable& vt
|
982
|
+
VisitedTable& vt,
|
983
|
+
const SearchParametersHNSW* params) const {
|
914
984
|
const HNSW& hnsw = *this;
|
985
|
+
auto efSearch = params ? params->efSearch : hnsw.efSearch;
|
915
986
|
int k = extract_k_from_ResultHandler(res);
|
987
|
+
|
916
988
|
if (search_type == 1) {
|
917
989
|
int nres = 0;
|
918
990
|
|
@@ -925,16 +997,25 @@ void HNSW::search_level_0(
|
|
925
997
|
if (vt.get(cj))
|
926
998
|
continue;
|
927
999
|
|
928
|
-
int candidates_size = std::max(
|
1000
|
+
int candidates_size = std::max(efSearch, k);
|
929
1001
|
MinimaxHeap candidates(candidates_size);
|
930
1002
|
|
931
1003
|
candidates.push(cj, nearest_d[j]);
|
932
1004
|
|
933
1005
|
nres = search_from_candidates(
|
934
|
-
hnsw,
|
1006
|
+
hnsw,
|
1007
|
+
qdis,
|
1008
|
+
res,
|
1009
|
+
candidates,
|
1010
|
+
vt,
|
1011
|
+
search_stats,
|
1012
|
+
0,
|
1013
|
+
nres,
|
1014
|
+
params);
|
1015
|
+
nres = std::min(nres, candidates_size);
|
935
1016
|
}
|
936
1017
|
} else if (search_type == 2) {
|
937
|
-
int candidates_size = std::max(
|
1018
|
+
int candidates_size = std::max(efSearch, int(k));
|
938
1019
|
candidates_size = std::max(candidates_size, int(nprobe));
|
939
1020
|
|
940
1021
|
MinimaxHeap candidates(candidates_size);
|
@@ -947,7 +1028,7 @@ void HNSW::search_level_0(
|
|
947
1028
|
}
|
948
1029
|
|
949
1030
|
search_from_candidates(
|
950
|
-
hnsw, qdis, res, candidates, vt, search_stats, 0);
|
1031
|
+
hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
|
951
1032
|
}
|
952
1033
|
}
|
953
1034
|
|
@@ -1013,7 +1094,99 @@ void HNSW::MinimaxHeap::clear() {
|
|
1013
1094
|
nvalid = k = 0;
|
1014
1095
|
}
|
1015
1096
|
|
1016
|
-
#ifdef
|
1097
|
+
#ifdef __AVX512F__
|
1098
|
+
|
1099
|
+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1100
|
+
assert(k > 0);
|
1101
|
+
static_assert(
|
1102
|
+
std::is_same<storage_idx_t, int32_t>::value,
|
1103
|
+
"This code expects storage_idx_t to be int32_t");
|
1104
|
+
|
1105
|
+
int32_t min_idx = -1;
|
1106
|
+
float min_dis = std::numeric_limits<float>::infinity();
|
1107
|
+
|
1108
|
+
__m512i min_indices = _mm512_set1_epi32(-1);
|
1109
|
+
__m512 min_distances =
|
1110
|
+
_mm512_set1_ps(std::numeric_limits<float>::infinity());
|
1111
|
+
__m512i current_indices = _mm512_setr_epi32(
|
1112
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
1113
|
+
__m512i offset = _mm512_set1_epi32(16);
|
1114
|
+
|
1115
|
+
// The following loop tracks the rightmost index with the min distance.
|
1116
|
+
// -1 index values are ignored.
|
1117
|
+
const int k16 = (k / 16) * 16;
|
1118
|
+
for (size_t iii = 0; iii < k16; iii += 16) {
|
1119
|
+
__m512i indices =
|
1120
|
+
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
|
1121
|
+
__m512 distances = _mm512_loadu_ps(dis.data() + iii);
|
1122
|
+
|
1123
|
+
// This mask filters out -1 values among indices.
|
1124
|
+
__mmask16 m1mask =
|
1125
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
1126
|
+
|
1127
|
+
__mmask16 dmask =
|
1128
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
1129
|
+
__mmask16 finalmask = m1mask | dmask;
|
1130
|
+
|
1131
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
1132
|
+
finalmask, current_indices, min_indices);
|
1133
|
+
const __m512 min_distances_new =
|
1134
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
1135
|
+
|
1136
|
+
min_indices = min_indices_new;
|
1137
|
+
min_distances = min_distances_new;
|
1138
|
+
|
1139
|
+
current_indices = _mm512_add_epi32(current_indices, offset);
|
1140
|
+
}
|
1141
|
+
|
1142
|
+
// leftovers
|
1143
|
+
if (k16 != k) {
|
1144
|
+
const __mmask16 kmask = (1 << (k - k16)) - 1;
|
1145
|
+
|
1146
|
+
__m512i indices = _mm512_mask_loadu_epi32(
|
1147
|
+
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
|
1148
|
+
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
|
1149
|
+
|
1150
|
+
// This mask filters out -1 values among indices.
|
1151
|
+
__mmask16 m1mask =
|
1152
|
+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
|
1153
|
+
|
1154
|
+
__mmask16 dmask =
|
1155
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
1156
|
+
__mmask16 finalmask = m1mask | dmask;
|
1157
|
+
|
1158
|
+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
|
1159
|
+
finalmask, current_indices, min_indices);
|
1160
|
+
const __m512 min_distances_new =
|
1161
|
+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
|
1162
|
+
|
1163
|
+
min_indices = min_indices_new;
|
1164
|
+
min_distances = min_distances_new;
|
1165
|
+
}
|
1166
|
+
|
1167
|
+
// grab min distance
|
1168
|
+
min_dis = _mm512_reduce_min_ps(min_distances);
|
1169
|
+
// blend
|
1170
|
+
__mmask16 mindmask =
|
1171
|
+
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
|
1172
|
+
// pick the max one
|
1173
|
+
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
|
1174
|
+
|
1175
|
+
if (min_idx == -1) {
|
1176
|
+
return -1;
|
1177
|
+
}
|
1178
|
+
|
1179
|
+
if (vmin_out) {
|
1180
|
+
*vmin_out = min_dis;
|
1181
|
+
}
|
1182
|
+
int ret = ids[min_idx];
|
1183
|
+
ids[min_idx] = -1;
|
1184
|
+
--nvalid;
|
1185
|
+
return ret;
|
1186
|
+
}
|
1187
|
+
|
1188
|
+
#elif __AVX2__
|
1189
|
+
|
1017
1190
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
1018
1191
|
assert(k > 0);
|
1019
1192
|
static_assert(
|