faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -8,9 +8,9 @@
|
|
|
8
8
|
#include <faiss/IndexIVF.h>
|
|
9
9
|
|
|
10
10
|
#include <omp.h>
|
|
11
|
+
#include <atomic>
|
|
11
12
|
#include <cstdint>
|
|
12
13
|
#include <memory>
|
|
13
|
-
#include <mutex>
|
|
14
14
|
|
|
15
15
|
#include <algorithm>
|
|
16
16
|
#include <cinttypes>
|
|
@@ -37,8 +37,8 @@ using ScopedCodes = InvertedLists::ScopedCodes;
|
|
|
37
37
|
* Level1Quantizer implementation
|
|
38
38
|
******************************************/
|
|
39
39
|
|
|
40
|
-
Level1Quantizer::Level1Quantizer(Index*
|
|
41
|
-
: quantizer(
|
|
40
|
+
Level1Quantizer::Level1Quantizer(Index* quantizer_in, size_t nlist_in)
|
|
41
|
+
: quantizer(quantizer_in), nlist(nlist_in) {
|
|
42
42
|
// here we set a low # iterations because this is typically used
|
|
43
43
|
// for large clusterings (nb this is not used for the MultiIndex,
|
|
44
44
|
// for which quantizer_trains_alone = true)
|
|
@@ -58,8 +58,10 @@ void Level1Quantizer::train_q1(
|
|
|
58
58
|
const float* x,
|
|
59
59
|
bool verbose,
|
|
60
60
|
MetricType metric_type) {
|
|
61
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
61
62
|
size_t d = quantizer->d;
|
|
62
|
-
if (quantizer->is_trained &&
|
|
63
|
+
if (quantizer->is_trained &&
|
|
64
|
+
(static_cast<size_t>(quantizer->ntotal) == nlist)) {
|
|
63
65
|
if (verbose) {
|
|
64
66
|
printf("IVF quantizer does not need training.\n");
|
|
65
67
|
}
|
|
@@ -70,14 +72,14 @@ void Level1Quantizer::train_q1(
|
|
|
70
72
|
quantizer->verbose = verbose;
|
|
71
73
|
quantizer->train(n, x);
|
|
72
74
|
FAISS_THROW_IF_NOT_MSG(
|
|
73
|
-
quantizer->ntotal == nlist,
|
|
75
|
+
static_cast<size_t>(quantizer->ntotal) == nlist,
|
|
74
76
|
"nlist not consistent with quantizer size");
|
|
75
77
|
} else if (quantizer_trains_alone == 0) {
|
|
76
78
|
if (verbose) {
|
|
77
79
|
printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
|
|
78
80
|
}
|
|
79
81
|
|
|
80
|
-
Clustering clus(d, nlist, cp);
|
|
82
|
+
Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
|
|
81
83
|
quantizer->reset();
|
|
82
84
|
if (clustering_index) {
|
|
83
85
|
clus.train(n, x, *clustering_index);
|
|
@@ -99,7 +101,7 @@ void Level1Quantizer::train_q1(
|
|
|
99
101
|
metric_type == METRIC_L2 ||
|
|
100
102
|
(metric_type == METRIC_INNER_PRODUCT && cp.spherical));
|
|
101
103
|
|
|
102
|
-
Clustering clus(d, nlist, cp);
|
|
104
|
+
Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
|
|
103
105
|
if (!clustering_index) {
|
|
104
106
|
IndexFlatL2 assigner(d);
|
|
105
107
|
clus.train(n, x, assigner);
|
|
@@ -148,7 +150,7 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
|
|
|
148
150
|
nbit += 8;
|
|
149
151
|
nl >>= 8;
|
|
150
152
|
}
|
|
151
|
-
FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
|
|
153
|
+
FAISS_THROW_IF_NOT(list_no >= 0 && static_cast<size_t>(list_no) < nlist);
|
|
152
154
|
return list_no;
|
|
153
155
|
}
|
|
154
156
|
|
|
@@ -157,21 +159,23 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
|
|
|
157
159
|
******************************************/
|
|
158
160
|
|
|
159
161
|
IndexIVF::IndexIVF(
|
|
160
|
-
Index*
|
|
161
|
-
size_t
|
|
162
|
-
size_t
|
|
163
|
-
size_t
|
|
162
|
+
Index* quantizer_in,
|
|
163
|
+
size_t d_in,
|
|
164
|
+
size_t nlist_in,
|
|
165
|
+
size_t code_size_in,
|
|
164
166
|
MetricType metric,
|
|
165
|
-
bool
|
|
166
|
-
: Index(
|
|
167
|
-
IndexIVFInterface(
|
|
167
|
+
bool own_invlists_in)
|
|
168
|
+
: Index(d_in, metric),
|
|
169
|
+
IndexIVFInterface(quantizer_in, nlist_in),
|
|
168
170
|
invlists(
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
171
|
+
own_invlists_in
|
|
172
|
+
? new ArrayInvertedLists(nlist_in, code_size_in)
|
|
173
|
+
: nullptr),
|
|
174
|
+
own_invlists(own_invlists_in),
|
|
175
|
+
code_size(code_size_in) {
|
|
176
|
+
FAISS_THROW_IF_NOT(static_cast<int>(d_in) == quantizer_in->d);
|
|
177
|
+
is_trained = quantizer_in->is_trained &&
|
|
178
|
+
(static_cast<size_t>(quantizer_in->ntotal) == nlist_in);
|
|
175
179
|
// Spherical by default if the metric is inner_product
|
|
176
180
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
177
181
|
cp.spherical = true;
|
|
@@ -185,6 +189,8 @@ void IndexIVF::add(idx_t n, const float* x) {
|
|
|
185
189
|
}
|
|
186
190
|
|
|
187
191
|
void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
|
|
192
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
193
|
+
FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
|
|
188
194
|
std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
|
|
189
195
|
quantizer->assign(n, x, coarse_idx.get());
|
|
190
196
|
add_core(n, x, xids, coarse_idx.get());
|
|
@@ -235,7 +241,7 @@ void IndexIVF::add_core(
|
|
|
235
241
|
|
|
236
242
|
size_t nadd = 0, nminus1 = 0;
|
|
237
243
|
|
|
238
|
-
for (
|
|
244
|
+
for (idx_t i = 0; i < n; i++) {
|
|
239
245
|
if (coarse_idx[i] < 0) {
|
|
240
246
|
nminus1++;
|
|
241
247
|
}
|
|
@@ -252,7 +258,7 @@ void IndexIVF::add_core(
|
|
|
252
258
|
int rank = omp_get_thread_num();
|
|
253
259
|
|
|
254
260
|
// each thread takes care of a subset of lists
|
|
255
|
-
for (
|
|
261
|
+
for (idx_t i = 0; i < n; i++) {
|
|
256
262
|
idx_t list_no = coarse_idx[i];
|
|
257
263
|
if (list_no >= 0 && list_no % nt == rank) {
|
|
258
264
|
idx_t id = xids ? xids[i] : ntotal + i;
|
|
@@ -305,45 +311,49 @@ void IndexIVF::search(
|
|
|
305
311
|
idx_t* labels,
|
|
306
312
|
const SearchParameters* params_in) const {
|
|
307
313
|
FAISS_THROW_IF_NOT(k > 0);
|
|
314
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
315
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
316
|
+
FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
|
|
308
317
|
const IVFSearchParameters* params = nullptr;
|
|
309
318
|
if (params_in) {
|
|
310
319
|
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
311
320
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
312
321
|
}
|
|
313
|
-
const size_t
|
|
322
|
+
const size_t cur_nprobe =
|
|
314
323
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
315
|
-
FAISS_THROW_IF_NOT(
|
|
324
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
316
325
|
|
|
317
326
|
// search function for a subset of queries
|
|
318
|
-
auto sub_search_func = [this, k,
|
|
319
|
-
idx_t
|
|
320
|
-
const float*
|
|
321
|
-
float*
|
|
322
|
-
idx_t*
|
|
327
|
+
auto sub_search_func = [this, k, cur_nprobe, params](
|
|
328
|
+
idx_t sub_n,
|
|
329
|
+
const float* sub_x,
|
|
330
|
+
float* sub_distances,
|
|
331
|
+
idx_t* sub_labels,
|
|
323
332
|
IndexIVFStats* ivf_stats) {
|
|
324
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[
|
|
325
|
-
std::unique_ptr<float[]> coarse_dis(new float[
|
|
333
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[sub_n * cur_nprobe]);
|
|
334
|
+
std::unique_ptr<float[]> coarse_dis(new float[sub_n * cur_nprobe]);
|
|
326
335
|
|
|
327
336
|
double t0 = getmillisecs();
|
|
328
337
|
quantizer->search(
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
338
|
+
sub_n,
|
|
339
|
+
sub_x,
|
|
340
|
+
cur_nprobe,
|
|
332
341
|
coarse_dis.get(),
|
|
333
342
|
idx.get(),
|
|
334
343
|
params ? params->quantizer_params : nullptr);
|
|
335
344
|
|
|
336
345
|
double t1 = getmillisecs();
|
|
337
|
-
invlists->prefetch_lists(
|
|
346
|
+
invlists->prefetch_lists(
|
|
347
|
+
idx.get(), static_cast<int>(sub_n * cur_nprobe));
|
|
338
348
|
|
|
339
349
|
search_preassigned(
|
|
340
|
-
|
|
341
|
-
|
|
350
|
+
sub_n,
|
|
351
|
+
sub_x,
|
|
342
352
|
k,
|
|
343
353
|
idx.get(),
|
|
344
354
|
coarse_dis.get(),
|
|
345
|
-
|
|
346
|
-
|
|
355
|
+
sub_distances,
|
|
356
|
+
sub_labels,
|
|
347
357
|
false,
|
|
348
358
|
params,
|
|
349
359
|
ivf_stats);
|
|
@@ -355,32 +365,28 @@ void IndexIVF::search(
|
|
|
355
365
|
if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
|
|
356
366
|
int nt = std::min(omp_get_max_threads(), int(n));
|
|
357
367
|
std::vector<IndexIVFStats> stats(nt);
|
|
358
|
-
std::
|
|
359
|
-
std::string exception_string;
|
|
368
|
+
std::exception_ptr ex;
|
|
360
369
|
|
|
361
370
|
#pragma omp parallel for if (nt > 1)
|
|
362
371
|
for (idx_t slice = 0; slice < nt; slice++) {
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
372
|
+
try {
|
|
373
|
+
IndexIVFStats local_stats;
|
|
374
|
+
idx_t i0 = n * slice / nt;
|
|
375
|
+
idx_t i1 = n * (slice + 1) / nt;
|
|
376
|
+
if (i1 > i0) {
|
|
368
377
|
sub_search_func(
|
|
369
378
|
i1 - i0,
|
|
370
379
|
x + i0 * d,
|
|
371
380
|
distances + i0 * k,
|
|
372
381
|
labels + i0 * k,
|
|
373
382
|
&stats[slice]);
|
|
374
|
-
} catch (const std::exception& e) {
|
|
375
|
-
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
376
|
-
exception_string = e.what();
|
|
377
383
|
}
|
|
384
|
+
} catch (...) {
|
|
385
|
+
omp_capture_exception(ex);
|
|
378
386
|
}
|
|
379
387
|
}
|
|
380
388
|
|
|
381
|
-
|
|
382
|
-
FAISS_THROW_MSG(exception_string.c_str());
|
|
383
|
-
}
|
|
389
|
+
omp_rethrow_if_exception(ex);
|
|
384
390
|
|
|
385
391
|
// collect stats
|
|
386
392
|
for (idx_t slice = 0; slice < nt; slice++) {
|
|
@@ -405,13 +411,17 @@ void IndexIVF::search_preassigned(
|
|
|
405
411
|
const IVFSearchParameters* params,
|
|
406
412
|
IndexIVFStats* ivf_stats) const {
|
|
407
413
|
FAISS_THROW_IF_NOT(k > 0);
|
|
414
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
415
|
+
FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
|
|
408
416
|
|
|
409
|
-
idx_t
|
|
410
|
-
|
|
411
|
-
FAISS_THROW_IF_NOT(
|
|
417
|
+
idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
|
|
418
|
+
cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
|
|
419
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
412
420
|
|
|
413
421
|
const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
|
|
414
|
-
idx_t
|
|
422
|
+
idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
|
|
423
|
+
const bool ensure_topk_full = params ? params->ensure_topk_full : false;
|
|
424
|
+
|
|
415
425
|
IDSelector* sel = params ? params->sel : nullptr;
|
|
416
426
|
const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
|
|
417
427
|
if (selr) {
|
|
@@ -427,7 +437,8 @@ void IndexIVF::search_preassigned(
|
|
|
427
437
|
"selector and store_pairs cannot be combined");
|
|
428
438
|
|
|
429
439
|
FAISS_THROW_IF_NOT_MSG(
|
|
430
|
-
!invlists->use_iterator ||
|
|
440
|
+
!invlists->use_iterator ||
|
|
441
|
+
(cur_max_codes == 0 && store_pairs == false),
|
|
431
442
|
"iterable inverted lists don't support max_codes and store_pairs");
|
|
432
443
|
|
|
433
444
|
size_t nlistv = 0, ndis = 0, nheap = 0;
|
|
@@ -435,106 +446,119 @@ void IndexIVF::search_preassigned(
|
|
|
435
446
|
using HeapForIP = CMin<float, idx_t>;
|
|
436
447
|
using HeapForL2 = CMax<float, idx_t>;
|
|
437
448
|
|
|
438
|
-
|
|
439
|
-
std::
|
|
440
|
-
std::string exception_string;
|
|
449
|
+
std::exception_ptr ex;
|
|
450
|
+
std::atomic<bool> interrupt{false};
|
|
441
451
|
|
|
442
452
|
int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
|
|
443
453
|
bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
|
|
444
454
|
|
|
445
455
|
FAISS_THROW_IF_NOT_MSG(
|
|
446
|
-
|
|
456
|
+
cur_max_codes == 0 || pmode == 0 || pmode == 3,
|
|
447
457
|
"max_codes supported only for parallel_mode = 0 or 3");
|
|
448
458
|
|
|
449
|
-
|
|
450
|
-
|
|
459
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
460
|
+
!ensure_topk_full || pmode == 0 || pmode == 3,
|
|
461
|
+
"ensure_topk_full supported only for parallel_mode = 0 or 3");
|
|
462
|
+
|
|
463
|
+
if (cur_max_codes == 0) {
|
|
464
|
+
cur_max_codes = unlimited_list_size;
|
|
451
465
|
}
|
|
466
|
+
// Budget used by the probe loop below. ensure_topk_full makes a small
|
|
467
|
+
// max_codes budget large enough to give k post-filter candidates a chance.
|
|
468
|
+
idx_t effective_max_codes =
|
|
469
|
+
ensure_topk_full ? std::max(cur_max_codes, k) : cur_max_codes;
|
|
452
470
|
|
|
453
471
|
[[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
454
472
|
(pmode == 0 ? false
|
|
455
473
|
: pmode == 3 ? n > 1
|
|
456
|
-
: pmode == 1 ?
|
|
457
|
-
:
|
|
474
|
+
: pmode == 1 ? cur_nprobe > 1
|
|
475
|
+
: cur_nprobe * n > 1);
|
|
458
476
|
|
|
459
477
|
void* inverted_list_context =
|
|
460
478
|
params ? params->inverted_list_context : nullptr;
|
|
461
479
|
|
|
462
480
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
|
|
463
481
|
{
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
482
|
+
// C++ exceptions that escape an OpenMP parallel region without being
|
|
483
|
+
// caught inside it call std::terminate — they cannot propagate across
|
|
484
|
+
// thread boundaries. The outer try/catch covers per-thread setup
|
|
485
|
+
// (scanner creation, set_query); the inner try/catch in scan_one_list
|
|
486
|
+
// covers per-list operations. Both set interrupt=true to stop further
|
|
487
|
+
// work and re-throw after the parallel region exits.
|
|
488
|
+
try {
|
|
489
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
490
|
+
get_InvertedListScanner(store_pairs, sel, params));
|
|
491
|
+
|
|
492
|
+
/*****************************************************
|
|
493
|
+
* Depending on parallel_mode, there are two possible ways
|
|
494
|
+
* to organize the search. Here we define local functions
|
|
495
|
+
* that are in common between the two
|
|
496
|
+
******************************************************/
|
|
497
|
+
|
|
498
|
+
// initialize + reorder a result heap
|
|
499
|
+
|
|
500
|
+
auto init_result = [&](float* simi, idx_t* idxi) {
|
|
501
|
+
if (!do_heap_init) {
|
|
502
|
+
return;
|
|
503
|
+
}
|
|
504
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
505
|
+
heap_heapify<HeapForIP>(k, simi, idxi);
|
|
506
|
+
} else {
|
|
507
|
+
heap_heapify<HeapForL2>(k, simi, idxi);
|
|
508
|
+
}
|
|
509
|
+
};
|
|
510
|
+
|
|
511
|
+
auto add_local_results = [&](const float* local_dis,
|
|
512
|
+
const idx_t* local_idx,
|
|
513
|
+
float* simi,
|
|
514
|
+
idx_t* idxi) {
|
|
515
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
516
|
+
heap_addn<HeapForIP>(
|
|
517
|
+
k, simi, idxi, local_dis, local_idx, k);
|
|
518
|
+
} else {
|
|
519
|
+
heap_addn<HeapForL2>(
|
|
520
|
+
k, simi, idxi, local_dis, local_idx, k);
|
|
521
|
+
}
|
|
522
|
+
};
|
|
474
523
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
524
|
+
auto reorder_result = [&](float* simi, idx_t* idxi) {
|
|
525
|
+
if (!do_heap_init) {
|
|
526
|
+
return;
|
|
527
|
+
}
|
|
528
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
529
|
+
heap_reorder<HeapForIP>(k, simi, idxi);
|
|
530
|
+
} else {
|
|
531
|
+
heap_reorder<HeapForL2>(k, simi, idxi);
|
|
532
|
+
}
|
|
533
|
+
};
|
|
485
534
|
|
|
486
|
-
|
|
487
|
-
|
|
535
|
+
// single list scan using the current scanner (with query
|
|
536
|
+
// set properly) and storing results in simi and idxi
|
|
537
|
+
auto scan_one_list = [&](idx_t key,
|
|
538
|
+
float coarse_dis_i,
|
|
488
539
|
float* simi,
|
|
489
|
-
idx_t* idxi
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
}
|
|
506
|
-
};
|
|
507
|
-
|
|
508
|
-
// single list scan using the current scanner (with query
|
|
509
|
-
// set properly) and storing results in simi and idxi
|
|
510
|
-
auto scan_one_list = [&](idx_t key,
|
|
511
|
-
float coarse_dis_i,
|
|
512
|
-
float* simi,
|
|
513
|
-
idx_t* idxi,
|
|
514
|
-
idx_t list_size_max) {
|
|
515
|
-
if (key < 0) {
|
|
516
|
-
// not enough centroids for multiprobe
|
|
517
|
-
return (size_t)0;
|
|
518
|
-
}
|
|
519
|
-
FAISS_THROW_IF_NOT_FMT(
|
|
520
|
-
key < (idx_t)nlist,
|
|
521
|
-
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
522
|
-
key,
|
|
523
|
-
nlist);
|
|
524
|
-
|
|
525
|
-
// don't waste time on empty lists
|
|
526
|
-
if (invlists->is_empty(key, inverted_list_context)) {
|
|
527
|
-
return (size_t)0;
|
|
528
|
-
}
|
|
529
|
-
|
|
530
|
-
scanner->set_list(key, coarse_dis_i);
|
|
540
|
+
idx_t* idxi,
|
|
541
|
+
idx_t list_size_max) {
|
|
542
|
+
if (key < 0) {
|
|
543
|
+
// not enough centroids for multiprobe
|
|
544
|
+
return (size_t)0;
|
|
545
|
+
}
|
|
546
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
547
|
+
key < (idx_t)nlist,
|
|
548
|
+
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
549
|
+
key,
|
|
550
|
+
nlist);
|
|
551
|
+
|
|
552
|
+
// don't waste time on empty lists
|
|
553
|
+
if (invlists->is_empty(key, inverted_list_context)) {
|
|
554
|
+
return (size_t)0;
|
|
555
|
+
}
|
|
531
556
|
|
|
532
|
-
|
|
557
|
+
scanner->set_list(key, coarse_dis_i);
|
|
533
558
|
|
|
534
|
-
|
|
559
|
+
nlistv++;
|
|
535
560
|
if (invlists->use_iterator) {
|
|
536
561
|
size_t list_size = 0;
|
|
537
|
-
|
|
538
562
|
std::unique_ptr<InvertedListsIterator> it(
|
|
539
563
|
invlists->get_iterator(key, inverted_list_context));
|
|
540
564
|
|
|
@@ -544,8 +568,8 @@ void IndexIVF::search_preassigned(
|
|
|
544
568
|
return list_size;
|
|
545
569
|
} else {
|
|
546
570
|
size_t list_size = invlists->list_size(key);
|
|
547
|
-
if (list_size > list_size_max) {
|
|
548
|
-
list_size = list_size_max;
|
|
571
|
+
if (list_size > static_cast<size_t>(list_size_max)) {
|
|
572
|
+
list_size = static_cast<size_t>(list_size_max);
|
|
549
573
|
}
|
|
550
574
|
|
|
551
575
|
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
@@ -573,144 +597,167 @@ void IndexIVF::search_preassigned(
|
|
|
573
597
|
ids += jmin;
|
|
574
598
|
}
|
|
575
599
|
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
600
|
+
size_t old_scan_cnt = 0;
|
|
601
|
+
size_t old_heap_updates = 0;
|
|
602
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
603
|
+
HeapResultHandler<HeapForIP, false> handler(
|
|
604
|
+
k, simi, idxi);
|
|
605
|
+
old_scan_cnt = handler.stats.scan_cnt;
|
|
606
|
+
old_heap_updates = handler.stats.nheap_updates;
|
|
607
|
+
scanner->scan_codes(list_size, codes, ids, handler);
|
|
608
|
+
nheap += handler.stats.nheap_updates - old_heap_updates;
|
|
609
|
+
return handler.stats.scan_cnt - old_scan_cnt;
|
|
610
|
+
} else {
|
|
611
|
+
HeapResultHandler<HeapForL2, false> handler(
|
|
612
|
+
k, simi, idxi);
|
|
613
|
+
old_scan_cnt = handler.stats.scan_cnt;
|
|
614
|
+
old_heap_updates = handler.stats.nheap_updates;
|
|
615
|
+
scanner->scan_codes(list_size, codes, ids, handler);
|
|
616
|
+
nheap += handler.stats.nheap_updates - old_heap_updates;
|
|
617
|
+
return handler.stats.scan_cnt - old_scan_cnt;
|
|
618
|
+
}
|
|
580
619
|
}
|
|
581
|
-
}
|
|
582
|
-
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
583
|
-
exception_string =
|
|
584
|
-
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
|
|
585
|
-
interrupt = true;
|
|
586
|
-
return size_t(0);
|
|
587
|
-
}
|
|
588
|
-
};
|
|
620
|
+
};
|
|
589
621
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
622
|
+
/****************************************************
|
|
623
|
+
* Actual loops, depending on parallel_mode
|
|
624
|
+
****************************************************/
|
|
593
625
|
|
|
594
|
-
|
|
626
|
+
if (pmode == 0 || pmode == 3) {
|
|
595
627
|
#pragma omp for
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
}
|
|
600
|
-
|
|
601
|
-
// loop over queries
|
|
602
|
-
scanner->set_query(x + i * d);
|
|
603
|
-
float* simi = distances + i * k;
|
|
604
|
-
idx_t* idxi = labels + i * k;
|
|
605
|
-
|
|
606
|
-
init_result(simi, idxi);
|
|
607
|
-
|
|
608
|
-
idx_t nscan = 0;
|
|
609
|
-
|
|
610
|
-
// loop over probes
|
|
611
|
-
for (size_t ik = 0; ik < nprobe; ik++) {
|
|
612
|
-
nscan += scan_one_list(
|
|
613
|
-
keys[i * nprobe + ik],
|
|
614
|
-
coarse_dis[i * nprobe + ik],
|
|
615
|
-
simi,
|
|
616
|
-
idxi,
|
|
617
|
-
max_codes - nscan);
|
|
618
|
-
if (nscan >= max_codes) {
|
|
619
|
-
break;
|
|
628
|
+
for (idx_t i = 0; i < n; i++) {
|
|
629
|
+
if (interrupt.load(std::memory_order_relaxed)) {
|
|
630
|
+
continue;
|
|
620
631
|
}
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
632
|
+
try {
|
|
633
|
+
// loop over queries
|
|
634
|
+
scanner->set_query(x + i * d);
|
|
635
|
+
float* simi = distances + i * k;
|
|
636
|
+
idx_t* idxi = labels + i * k;
|
|
637
|
+
|
|
638
|
+
init_result(simi, idxi);
|
|
639
|
+
|
|
640
|
+
idx_t nscan = 0;
|
|
641
|
+
|
|
642
|
+
// loop over probes
|
|
643
|
+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
|
|
644
|
+
// For soft budgets, scan whole lists so
|
|
645
|
+
// IDSelector-filtered rows do not consume the
|
|
646
|
+
// remaining code budget.
|
|
647
|
+
const idx_t list_size_max = ensure_topk_full
|
|
648
|
+
? unlimited_list_size
|
|
649
|
+
: effective_max_codes - nscan;
|
|
650
|
+
nscan += scan_one_list(
|
|
651
|
+
keys[i * cur_nprobe + ik],
|
|
652
|
+
coarse_dis[i * cur_nprobe + ik],
|
|
653
|
+
simi,
|
|
654
|
+
idxi,
|
|
655
|
+
list_size_max);
|
|
656
|
+
|
|
657
|
+
// Early-stop check: apply max_codes after each
|
|
658
|
+
// list. nscan is the number of distances
|
|
659
|
+
// actually computed.
|
|
660
|
+
if (nscan >= effective_max_codes) {
|
|
661
|
+
break;
|
|
662
|
+
}
|
|
663
|
+
}
|
|
625
664
|
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
}
|
|
665
|
+
ndis += nscan;
|
|
666
|
+
reorder_result(simi, idxi);
|
|
629
667
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
668
|
+
InterruptCallback::check();
|
|
669
|
+
} catch (...) {
|
|
670
|
+
omp_capture_exception(ex, [&] { interrupt = true; });
|
|
671
|
+
}
|
|
672
|
+
} // parallel for
|
|
673
|
+
} else if (pmode == 1) {
|
|
674
|
+
std::vector<idx_t> local_idx(k);
|
|
675
|
+
std::vector<float> local_dis(k);
|
|
634
676
|
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
677
|
+
for (idx_t i = 0; i < n; i++) {
|
|
678
|
+
scanner->set_query(x + i * d);
|
|
679
|
+
init_result(local_dis.data(), local_idx.data());
|
|
638
680
|
|
|
639
681
|
#pragma omp for schedule(dynamic)
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
682
|
+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
|
|
683
|
+
try {
|
|
684
|
+
ndis += scan_one_list(
|
|
685
|
+
keys[i * cur_nprobe + ik],
|
|
686
|
+
coarse_dis[i * cur_nprobe + ik],
|
|
687
|
+
local_dis.data(),
|
|
688
|
+
local_idx.data(),
|
|
689
|
+
unlimited_list_size);
|
|
690
|
+
|
|
691
|
+
// can't do the test on max_codes
|
|
692
|
+
} catch (...) {
|
|
693
|
+
omp_capture_exception(
|
|
694
|
+
ex, [&] { interrupt = true; });
|
|
695
|
+
}
|
|
696
|
+
}
|
|
697
|
+
// merge thread-local results
|
|
651
698
|
|
|
652
|
-
|
|
653
|
-
|
|
699
|
+
float* simi = distances + i * k;
|
|
700
|
+
idx_t* idxi = labels + i * k;
|
|
654
701
|
#pragma omp single
|
|
655
|
-
|
|
702
|
+
init_result(simi, idxi);
|
|
656
703
|
|
|
657
704
|
#pragma omp barrier
|
|
658
705
|
#pragma omp critical
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
706
|
+
{
|
|
707
|
+
add_local_results(
|
|
708
|
+
local_dis.data(), local_idx.data(), simi, idxi);
|
|
709
|
+
}
|
|
663
710
|
#pragma omp barrier
|
|
664
711
|
#pragma omp single
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
712
|
+
reorder_result(simi, idxi);
|
|
713
|
+
}
|
|
714
|
+
} else if (pmode == 2) {
|
|
715
|
+
std::vector<idx_t> local_idx(k);
|
|
716
|
+
std::vector<float> local_dis(k);
|
|
670
717
|
|
|
671
718
|
#pragma omp single
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
719
|
+
for (int64_t i = 0; i < n; i++) {
|
|
720
|
+
init_result(distances + i * k, labels + i * k);
|
|
721
|
+
}
|
|
675
722
|
|
|
676
723
|
#pragma omp for schedule(dynamic)
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
724
|
+
for (int64_t ij = 0; ij < n * cur_nprobe; ij++) {
|
|
725
|
+
try {
|
|
726
|
+
size_t i = ij / cur_nprobe;
|
|
727
|
+
|
|
728
|
+
scanner->set_query(x + i * d);
|
|
729
|
+
init_result(local_dis.data(), local_idx.data());
|
|
730
|
+
ndis += scan_one_list(
|
|
731
|
+
keys[ij],
|
|
732
|
+
coarse_dis[ij],
|
|
733
|
+
local_dis.data(),
|
|
734
|
+
local_idx.data(),
|
|
735
|
+
unlimited_list_size);
|
|
688
736
|
#pragma omp critical
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
737
|
+
{
|
|
738
|
+
add_local_results(
|
|
739
|
+
local_dis.data(),
|
|
740
|
+
local_idx.data(),
|
|
741
|
+
distances + i * k,
|
|
742
|
+
labels + i * k);
|
|
743
|
+
}
|
|
744
|
+
} catch (...) {
|
|
745
|
+
omp_capture_exception(ex, [&] { interrupt = true; });
|
|
746
|
+
}
|
|
695
747
|
}
|
|
696
|
-
}
|
|
697
748
|
#pragma omp single
|
|
698
|
-
|
|
699
|
-
|
|
749
|
+
for (int64_t i = 0; i < n; i++) {
|
|
750
|
+
reorder_result(distances + i * k, labels + i * k);
|
|
751
|
+
}
|
|
752
|
+
} else {
|
|
753
|
+
FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
|
|
700
754
|
}
|
|
701
|
-
}
|
|
702
|
-
|
|
755
|
+
} catch (...) {
|
|
756
|
+
omp_capture_exception(ex, [&] { interrupt = true; });
|
|
703
757
|
}
|
|
704
758
|
} // parallel section
|
|
705
759
|
|
|
706
|
-
|
|
707
|
-
if (!exception_string.empty()) {
|
|
708
|
-
FAISS_THROW_FMT(
|
|
709
|
-
"search interrupted with: %s", exception_string.c_str());
|
|
710
|
-
} else {
|
|
711
|
-
FAISS_THROW_MSG("computation interrupted");
|
|
712
|
-
}
|
|
713
|
-
}
|
|
760
|
+
omp_rethrow_if_exception(ex);
|
|
714
761
|
|
|
715
762
|
if (ivf_stats == nullptr) {
|
|
716
763
|
ivf_stats = &indexIVF_stats;
|
|
@@ -727,6 +774,8 @@ void IndexIVF::range_search(
|
|
|
727
774
|
float radius,
|
|
728
775
|
RangeSearchResult* result,
|
|
729
776
|
const SearchParameters* params_in) const {
|
|
777
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
778
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
730
779
|
const IVFSearchParameters* params = nullptr;
|
|
731
780
|
const SearchParameters* quantizer_params = nullptr;
|
|
732
781
|
if (params_in) {
|
|
@@ -734,18 +783,18 @@ void IndexIVF::range_search(
|
|
|
734
783
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
735
784
|
quantizer_params = params->quantizer_params;
|
|
736
785
|
}
|
|
737
|
-
const size_t
|
|
786
|
+
const size_t cur_nprobe =
|
|
738
787
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
739
|
-
std::unique_ptr<idx_t[]> keys(new idx_t[nx *
|
|
740
|
-
std::unique_ptr<float[]> coarse_dis(new float[nx *
|
|
788
|
+
std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
|
|
789
|
+
std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
|
|
741
790
|
|
|
742
791
|
double t0 = getmillisecs();
|
|
743
792
|
quantizer->search(
|
|
744
|
-
nx, x,
|
|
793
|
+
nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
|
|
745
794
|
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
|
746
795
|
|
|
747
796
|
t0 = getmillisecs();
|
|
748
|
-
invlists->prefetch_lists(keys.get(), nx *
|
|
797
|
+
invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
|
|
749
798
|
|
|
750
799
|
range_search_preassigned(
|
|
751
800
|
nx,
|
|
@@ -771,22 +820,29 @@ void IndexIVF::range_search_preassigned(
|
|
|
771
820
|
bool store_pairs,
|
|
772
821
|
const IVFSearchParameters* params,
|
|
773
822
|
IndexIVFStats* stats) const {
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
823
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
824
|
+
idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
|
|
825
|
+
cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
|
|
826
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
827
|
+
|
|
828
|
+
idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
|
|
829
|
+
// Range-search early-stop budget. 0 disables the empty-bucket stop.
|
|
830
|
+
const size_t max_empty_result_buckets =
|
|
831
|
+
params ? params->max_empty_result_buckets : 0;
|
|
779
832
|
IDSelector* sel = params ? params->sel : nullptr;
|
|
780
833
|
|
|
781
834
|
FAISS_THROW_IF_NOT_MSG(
|
|
782
|
-
!invlists->use_iterator ||
|
|
835
|
+
!invlists->use_iterator ||
|
|
836
|
+
(cur_max_codes == 0 && store_pairs == false),
|
|
783
837
|
"iterable inverted lists don't support max_codes and store_pairs");
|
|
784
838
|
|
|
839
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
840
|
+
max_empty_result_buckets == 0 || parallel_mode == 0,
|
|
841
|
+
"max_empty_result_buckets supported only for parallel_mode = 0");
|
|
842
|
+
|
|
785
843
|
size_t nlistv = 0, ndis = 0;
|
|
786
844
|
|
|
787
|
-
|
|
788
|
-
std::mutex exception_mutex;
|
|
789
|
-
std::string exception_string;
|
|
845
|
+
std::exception_ptr ex;
|
|
790
846
|
|
|
791
847
|
std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
|
|
792
848
|
|
|
@@ -795,122 +851,142 @@ void IndexIVF::range_search_preassigned(
|
|
|
795
851
|
[[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
796
852
|
(pmode == 3 ? false
|
|
797
853
|
: pmode == 0 ? nx > 1
|
|
798
|
-
: pmode == 1 ?
|
|
799
|
-
:
|
|
854
|
+
: pmode == 1 ? cur_nprobe > 1
|
|
855
|
+
: cur_nprobe * nx > 1);
|
|
800
856
|
|
|
801
857
|
void* inverted_list_context =
|
|
802
858
|
params ? params->inverted_list_context : nullptr;
|
|
803
859
|
|
|
804
860
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
|
|
805
861
|
{
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
key,
|
|
823
|
-
ik,
|
|
824
|
-
nlist);
|
|
825
|
-
|
|
826
|
-
if (invlists->is_empty(key, inverted_list_context)) {
|
|
827
|
-
return;
|
|
828
|
-
}
|
|
862
|
+
try {
|
|
863
|
+
RangeSearchPartialResult pres(result);
|
|
864
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
865
|
+
get_InvertedListScanner(store_pairs, sel, params));
|
|
866
|
+
FAISS_THROW_IF_NOT(scanner.get());
|
|
867
|
+
all_pres[omp_get_thread_num()] = &pres;
|
|
868
|
+
|
|
869
|
+
// prepare the list scanning function
|
|
870
|
+
|
|
871
|
+
auto scan_list_func = [&](size_t i,
|
|
872
|
+
size_t ik,
|
|
873
|
+
RangeQueryResult& qres) {
|
|
874
|
+
idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
|
|
875
|
+
if (key < 0) {
|
|
876
|
+
return;
|
|
877
|
+
}
|
|
829
878
|
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
879
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
880
|
+
key < (idx_t)nlist,
|
|
881
|
+
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
|
882
|
+
key,
|
|
883
|
+
ik,
|
|
884
|
+
nlist);
|
|
885
|
+
|
|
886
|
+
if (invlists->is_empty(key, inverted_list_context)) {
|
|
887
|
+
return;
|
|
888
|
+
}
|
|
889
|
+
|
|
890
|
+
scanner->set_list(key, coarse_dis[i * cur_nprobe + ik]);
|
|
891
|
+
const size_t scan_cnt0 = qres.stats.scan_cnt;
|
|
833
892
|
if (invlists->use_iterator) {
|
|
893
|
+
size_t list_size = 0;
|
|
834
894
|
std::unique_ptr<InvertedListsIterator> it(
|
|
835
895
|
invlists->get_iterator(key, inverted_list_context));
|
|
836
896
|
|
|
837
897
|
scanner->iterate_codes_range(
|
|
838
898
|
it.get(), radius, qres, list_size);
|
|
899
|
+
qres.stats.scan_cnt += list_size;
|
|
839
900
|
} else {
|
|
840
901
|
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
841
902
|
InvertedLists::ScopedIds ids(invlists, key);
|
|
842
|
-
list_size = invlists->list_size(key);
|
|
903
|
+
size_t list_size = invlists->list_size(key);
|
|
843
904
|
|
|
844
905
|
scanner->scan_codes_range(
|
|
845
906
|
list_size, scodes.get(), ids.get(), radius, qres);
|
|
846
907
|
}
|
|
847
908
|
nlistv++;
|
|
848
|
-
ndis +=
|
|
849
|
-
}
|
|
850
|
-
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
851
|
-
exception_string =
|
|
852
|
-
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
|
|
853
|
-
interrupt = true;
|
|
854
|
-
}
|
|
855
|
-
};
|
|
909
|
+
ndis += qres.stats.scan_cnt - scan_cnt0;
|
|
910
|
+
};
|
|
856
911
|
|
|
857
|
-
|
|
912
|
+
if (parallel_mode == 0) {
|
|
858
913
|
#pragma omp for
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
914
|
+
for (idx_t i = 0; i < nx; i++) {
|
|
915
|
+
try {
|
|
916
|
+
scanner->set_query(x + i * d);
|
|
917
|
+
RangeQueryResult& qres = pres.new_result(i);
|
|
918
|
+
|
|
919
|
+
// Stop after enough consecutive probes add no range
|
|
920
|
+
// results. A hit resets the counter.
|
|
921
|
+
size_t prev_nres = qres.nres;
|
|
922
|
+
size_t ndup = 0;
|
|
923
|
+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
|
|
924
|
+
scan_list_func(i, ik, qres);
|
|
925
|
+
if (max_empty_result_buckets > 0) {
|
|
926
|
+
// Early-stop check: stop range search after
|
|
927
|
+
// enough consecutive empty probes.
|
|
928
|
+
ndup = (qres.nres == prev_nres) ? ndup + 1 : 0;
|
|
929
|
+
if (ndup >= max_empty_result_buckets) {
|
|
930
|
+
break;
|
|
931
|
+
}
|
|
932
|
+
prev_nres = qres.nres;
|
|
933
|
+
}
|
|
934
|
+
}
|
|
935
|
+
} catch (...) {
|
|
936
|
+
omp_capture_exception(ex);
|
|
937
|
+
}
|
|
866
938
|
}
|
|
867
|
-
}
|
|
868
939
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
940
|
+
} else if (parallel_mode == 1) {
|
|
941
|
+
for (idx_t i = 0; i < nx; i++) {
|
|
942
|
+
scanner->set_query(x + i * d);
|
|
872
943
|
|
|
873
|
-
|
|
944
|
+
RangeQueryResult& qres = pres.new_result(i);
|
|
874
945
|
|
|
875
946
|
#pragma omp for schedule(dynamic)
|
|
876
|
-
|
|
877
|
-
|
|
947
|
+
for (int64_t ik = 0; ik < cur_nprobe; ik++) {
|
|
948
|
+
try {
|
|
949
|
+
scan_list_func(i, ik, qres);
|
|
950
|
+
} catch (...) {
|
|
951
|
+
omp_capture_exception(ex);
|
|
952
|
+
}
|
|
953
|
+
}
|
|
878
954
|
}
|
|
879
|
-
}
|
|
880
|
-
|
|
881
|
-
RangeQueryResult* qres = nullptr;
|
|
955
|
+
} else if (parallel_mode == 2) {
|
|
956
|
+
RangeQueryResult* qres = nullptr;
|
|
882
957
|
|
|
883
958
|
#pragma omp for schedule(dynamic)
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
959
|
+
for (idx_t iik = 0; iik < nx * (idx_t)cur_nprobe; iik++) {
|
|
960
|
+
try {
|
|
961
|
+
idx_t i = iik / (idx_t)cur_nprobe;
|
|
962
|
+
idx_t ik = iik % (idx_t)cur_nprobe;
|
|
963
|
+
if (qres == nullptr || qres->qno != i) {
|
|
964
|
+
qres = &pres.new_result(i);
|
|
965
|
+
scanner->set_query(x + i * d);
|
|
966
|
+
}
|
|
967
|
+
scan_list_func(i, ik, *qres);
|
|
968
|
+
} catch (...) {
|
|
969
|
+
omp_capture_exception(ex);
|
|
970
|
+
}
|
|
890
971
|
}
|
|
891
|
-
|
|
972
|
+
} else {
|
|
973
|
+
FAISS_THROW_FMT(
|
|
974
|
+
"parallel_mode %d not supported\n", parallel_mode);
|
|
892
975
|
}
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
if (parallel_mode == 0) {
|
|
897
|
-
pres.finalize();
|
|
898
|
-
} else {
|
|
976
|
+
if (parallel_mode == 0) {
|
|
977
|
+
pres.finalize();
|
|
978
|
+
} else {
|
|
899
979
|
#pragma omp barrier
|
|
900
980
|
#pragma omp single
|
|
901
|
-
|
|
981
|
+
RangeSearchPartialResult::merge(all_pres, false);
|
|
902
982
|
#pragma omp barrier
|
|
983
|
+
}
|
|
984
|
+
} catch (...) {
|
|
985
|
+
omp_capture_exception(ex);
|
|
903
986
|
}
|
|
904
987
|
}
|
|
905
988
|
|
|
906
|
-
|
|
907
|
-
if (!exception_string.empty()) {
|
|
908
|
-
FAISS_THROW_FMT(
|
|
909
|
-
"search interrupted with: %s", exception_string.c_str());
|
|
910
|
-
} else {
|
|
911
|
-
FAISS_THROW_MSG("computation interrupted");
|
|
912
|
-
}
|
|
913
|
-
}
|
|
989
|
+
omp_rethrow_if_exception(ex);
|
|
914
990
|
|
|
915
991
|
if (stats == nullptr) {
|
|
916
992
|
stats = &indexIVF_stats;
|
|
@@ -931,26 +1007,31 @@ void IndexIVF::search1(
|
|
|
931
1007
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
932
1008
|
quantizer_params = params->quantizer_params;
|
|
933
1009
|
}
|
|
934
|
-
const size_t
|
|
1010
|
+
const size_t cur_nprobe =
|
|
935
1011
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
936
1012
|
size_t nx = 1;
|
|
937
|
-
std::unique_ptr<idx_t[]> keys(new idx_t[nx *
|
|
938
|
-
std::unique_ptr<float[]> coarse_dis(new float[nx *
|
|
1013
|
+
std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
|
|
1014
|
+
std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
|
|
939
1015
|
|
|
940
1016
|
double t0 = getmillisecs();
|
|
941
1017
|
quantizer->search(
|
|
942
|
-
nx, x,
|
|
1018
|
+
nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
|
|
943
1019
|
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
|
944
1020
|
|
|
945
1021
|
t0 = getmillisecs();
|
|
946
|
-
invlists->prefetch_lists(keys.get(), nx *
|
|
1022
|
+
invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
|
|
947
1023
|
|
|
948
1024
|
std::unique_ptr<InvertedListScanner> scanner(
|
|
949
1025
|
get_InvertedListScanner(false, nullptr, params));
|
|
950
1026
|
scanner->set_query(x);
|
|
951
1027
|
|
|
952
|
-
for (
|
|
1028
|
+
for (size_t i = 0; i < cur_nprobe; i++) {
|
|
953
1029
|
idx_t key = keys[i];
|
|
1030
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
1031
|
+
key < (idx_t)nlist,
|
|
1032
|
+
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
1033
|
+
key,
|
|
1034
|
+
nlist);
|
|
954
1035
|
if (key < 0 || invlists->is_empty(key)) {
|
|
955
1036
|
continue;
|
|
956
1037
|
}
|
|
@@ -981,11 +1062,11 @@ void IndexIVF::reconstruct(idx_t key, float* recons) const {
|
|
|
981
1062
|
void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
|
982
1063
|
FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
|
983
1064
|
|
|
984
|
-
for (
|
|
1065
|
+
for (size_t list_no = 0; list_no < nlist; list_no++) {
|
|
985
1066
|
size_t list_size = invlists->list_size(list_no);
|
|
986
1067
|
ScopedIds idlist(invlists, list_no);
|
|
987
1068
|
|
|
988
|
-
for (
|
|
1069
|
+
for (size_t offset = 0; offset < list_size; offset++) {
|
|
989
1070
|
idx_t id = idlist[offset];
|
|
990
1071
|
if (!(id >= i0 && id < i0 + ni)) {
|
|
991
1072
|
continue;
|
|
@@ -1046,16 +1127,16 @@ void IndexIVF::search_and_reconstruct(
|
|
|
1046
1127
|
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
1047
1128
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
1048
1129
|
}
|
|
1049
|
-
const size_t
|
|
1130
|
+
const size_t cur_nprobe =
|
|
1050
1131
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
1051
|
-
FAISS_THROW_IF_NOT(
|
|
1132
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
1052
1133
|
|
|
1053
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[n *
|
|
1054
|
-
std::unique_ptr<float[]> coarse_dis(new float[n *
|
|
1134
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
|
|
1135
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
|
|
1055
1136
|
|
|
1056
|
-
quantizer->search(n, x,
|
|
1137
|
+
quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
|
|
1057
1138
|
|
|
1058
|
-
invlists->prefetch_lists(idx.get(), n *
|
|
1139
|
+
invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
|
|
1059
1140
|
|
|
1060
1141
|
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
|
1061
1142
|
// and offset into `codes` for reconstruction
|
|
@@ -1077,8 +1158,8 @@ void IndexIVF::search_and_reconstruct(
|
|
|
1077
1158
|
// Fill with NaNs
|
|
1078
1159
|
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
|
1079
1160
|
} else {
|
|
1080
|
-
|
|
1081
|
-
|
|
1161
|
+
size_t list_no = lo_listno(key);
|
|
1162
|
+
size_t offset = lo_offset(key);
|
|
1082
1163
|
|
|
1083
1164
|
// Update label to the actual id
|
|
1084
1165
|
labels[ij] = invlists->get_single_id(list_no, offset);
|
|
@@ -1102,16 +1183,16 @@ void IndexIVF::search_and_return_codes(
|
|
|
1102
1183
|
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
1103
1184
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
1104
1185
|
}
|
|
1105
|
-
const size_t
|
|
1186
|
+
const size_t cur_nprobe =
|
|
1106
1187
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
1107
|
-
FAISS_THROW_IF_NOT(
|
|
1188
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
1108
1189
|
|
|
1109
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[n *
|
|
1110
|
-
std::unique_ptr<float[]> coarse_dis(new float[n *
|
|
1190
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
|
|
1191
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
|
|
1111
1192
|
|
|
1112
|
-
quantizer->search(n, x,
|
|
1193
|
+
quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
|
|
1113
1194
|
|
|
1114
|
-
invlists->prefetch_lists(idx.get(), n *
|
|
1195
|
+
invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
|
|
1115
1196
|
|
|
1116
1197
|
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
|
1117
1198
|
// and offset into `codes` for reconstruction
|
|
@@ -1140,8 +1221,8 @@ void IndexIVF::search_and_return_codes(
|
|
|
1140
1221
|
// Fill with 0xff
|
|
1141
1222
|
memset(code1, -1, code_size_1);
|
|
1142
1223
|
} else {
|
|
1143
|
-
|
|
1144
|
-
|
|
1224
|
+
size_t list_no = lo_listno(key);
|
|
1225
|
+
size_t offset = lo_offset(key);
|
|
1145
1226
|
const uint8_t* cc = invlists->get_single_code(list_no, offset);
|
|
1146
1227
|
|
|
1147
1228
|
labels[ij] = invlists->get_single_id(list_no, offset);
|
|
@@ -1180,7 +1261,8 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
|
|
|
1180
1261
|
IDSelectorArray sel(n, new_ids);
|
|
1181
1262
|
size_t nremove = remove_ids(sel);
|
|
1182
1263
|
FAISS_THROW_IF_NOT_MSG(
|
|
1183
|
-
nremove == n,
|
|
1264
|
+
nremove == static_cast<size_t>(n),
|
|
1265
|
+
"did not find all entries to remove");
|
|
1184
1266
|
add_with_ids(n, x, new_ids);
|
|
1185
1267
|
return;
|
|
1186
1268
|
}
|
|
@@ -1242,7 +1324,7 @@ idx_t IndexIVF::train_encoder_num_vectors() const {
|
|
|
1242
1324
|
void IndexIVF::train_encoder(
|
|
1243
1325
|
idx_t /*n*/,
|
|
1244
1326
|
const float* /*x*/,
|
|
1245
|
-
const idx_t* assign) {
|
|
1327
|
+
const idx_t* /*assign*/) {
|
|
1246
1328
|
// does nothing by default
|
|
1247
1329
|
if (verbose) {
|
|
1248
1330
|
printf("IndexIVF: no residual training\n");
|
|
@@ -1385,11 +1467,19 @@ size_t InvertedListScanner::iterate_codes(
|
|
|
1385
1467
|
size_t nup = 0;
|
|
1386
1468
|
list_size = 0;
|
|
1387
1469
|
|
|
1470
|
+
const bool has_cb = it->has_search_callbacks_;
|
|
1471
|
+
|
|
1388
1472
|
if (!keep_max) {
|
|
1389
1473
|
for (; it->is_available(); it->next()) {
|
|
1390
1474
|
auto id_and_codes = it->get_id_and_codes();
|
|
1391
1475
|
float dis = distance_to_code(id_and_codes.second);
|
|
1476
|
+
if (has_cb) {
|
|
1477
|
+
it->on_distance_computed(id_and_codes.first, dis);
|
|
1478
|
+
}
|
|
1392
1479
|
if (dis < simi[0]) {
|
|
1480
|
+
if (has_cb) {
|
|
1481
|
+
it->on_heap_changed(id_and_codes.first, idxi[0]);
|
|
1482
|
+
}
|
|
1393
1483
|
maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
|
|
1394
1484
|
nup++;
|
|
1395
1485
|
}
|
|
@@ -1399,7 +1489,13 @@ size_t InvertedListScanner::iterate_codes(
|
|
|
1399
1489
|
for (; it->is_available(); it->next()) {
|
|
1400
1490
|
auto id_and_codes = it->get_id_and_codes();
|
|
1401
1491
|
float dis = distance_to_code(id_and_codes.second);
|
|
1492
|
+
if (has_cb) {
|
|
1493
|
+
it->on_distance_computed(id_and_codes.first, dis);
|
|
1494
|
+
}
|
|
1402
1495
|
if (dis > simi[0]) {
|
|
1496
|
+
if (has_cb) {
|
|
1497
|
+
it->on_heap_changed(id_and_codes.first, idxi[0]);
|
|
1498
|
+
}
|
|
1403
1499
|
minheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
|
|
1404
1500
|
nup++;
|
|
1405
1501
|
}
|
|
@@ -1419,10 +1515,14 @@ void InvertedListScanner::scan_codes_range(
|
|
|
1419
1515
|
using C = CMax<float, idx_t>;
|
|
1420
1516
|
RangeResultHandler<C, false> handler(&res, radius);
|
|
1421
1517
|
scan_codes(list_size, codes, ids, handler);
|
|
1518
|
+
res.stats.scan_cnt += handler.stats.scan_cnt;
|
|
1519
|
+
res.stats.nheap_updates += handler.stats.nheap_updates;
|
|
1422
1520
|
} else {
|
|
1423
1521
|
using C = CMin<float, idx_t>;
|
|
1424
1522
|
RangeResultHandler<C, false> handler(&res, radius);
|
|
1425
1523
|
scan_codes(list_size, codes, ids, handler);
|
|
1524
|
+
res.stats.scan_cnt += handler.stats.scan_cnt;
|
|
1525
|
+
res.stats.nheap_updates += handler.stats.nheap_updates;
|
|
1426
1526
|
}
|
|
1427
1527
|
}
|
|
1428
1528
|
|
|
@@ -1431,6 +1531,7 @@ void InvertedListScanner::iterate_codes_range(
|
|
|
1431
1531
|
float radius,
|
|
1432
1532
|
RangeQueryResult& res,
|
|
1433
1533
|
size_t& list_size) const {
|
|
1534
|
+
size_t nup = 0;
|
|
1434
1535
|
list_size = 0;
|
|
1435
1536
|
for (; it->is_available(); it->next()) {
|
|
1436
1537
|
auto id_and_codes = it->get_id_and_codes();
|
|
@@ -1440,9 +1541,11 @@ void InvertedListScanner::iterate_codes_range(
|
|
|
1440
1541
|
: dis > radius; // TODO templatize to remove this test
|
|
1441
1542
|
if (keep) {
|
|
1442
1543
|
res.add(dis, id_and_codes.first);
|
|
1544
|
+
nup++;
|
|
1443
1545
|
}
|
|
1444
1546
|
list_size++;
|
|
1445
1547
|
}
|
|
1548
|
+
res.stats.nheap_updates += nup;
|
|
1446
1549
|
}
|
|
1447
1550
|
|
|
1448
1551
|
} // namespace faiss
|