faiss 0.5.3 → 0.6.1
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -5,14 +5,12 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/IndexIVF.h>
|
|
11
9
|
|
|
12
10
|
#include <omp.h>
|
|
11
|
+
#include <atomic>
|
|
13
12
|
#include <cstdint>
|
|
14
13
|
#include <memory>
|
|
15
|
-
#include <mutex>
|
|
16
14
|
|
|
17
15
|
#include <algorithm>
|
|
18
16
|
#include <cinttypes>
|
|
@@ -27,6 +25,8 @@
|
|
|
27
25
|
#include <faiss/impl/CodePacker.h>
|
|
28
26
|
#include <faiss/impl/FaissAssert.h>
|
|
29
27
|
#include <faiss/impl/IDSelector.h>
|
|
28
|
+
#include <faiss/impl/ResultHandler.h>
|
|
29
|
+
#include <faiss/impl/expanded_scanners.h>
|
|
30
30
|
|
|
31
31
|
namespace faiss {
|
|
32
32
|
|
|
@@ -37,8 +37,8 @@ using ScopedCodes = InvertedLists::ScopedCodes;
|
|
|
37
37
|
* Level1Quantizer implementation
|
|
38
38
|
******************************************/
|
|
39
39
|
|
|
40
|
-
Level1Quantizer::Level1Quantizer(Index*
|
|
41
|
-
: quantizer(
|
|
40
|
+
Level1Quantizer::Level1Quantizer(Index* quantizer_in, size_t nlist_in)
|
|
41
|
+
: quantizer(quantizer_in), nlist(nlist_in) {
|
|
42
42
|
// here we set a low # iterations because this is typically used
|
|
43
43
|
// for large clusterings (nb this is not used for the MultiIndex,
|
|
44
44
|
// for which quantizer_trains_alone = true)
|
|
@@ -58,8 +58,10 @@ void Level1Quantizer::train_q1(
|
|
|
58
58
|
const float* x,
|
|
59
59
|
bool verbose,
|
|
60
60
|
MetricType metric_type) {
|
|
61
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
61
62
|
size_t d = quantizer->d;
|
|
62
|
-
if (quantizer->is_trained &&
|
|
63
|
+
if (quantizer->is_trained &&
|
|
64
|
+
(static_cast<size_t>(quantizer->ntotal) == nlist)) {
|
|
63
65
|
if (verbose) {
|
|
64
66
|
printf("IVF quantizer does not need training.\n");
|
|
65
67
|
}
|
|
@@ -70,14 +72,14 @@ void Level1Quantizer::train_q1(
|
|
|
70
72
|
quantizer->verbose = verbose;
|
|
71
73
|
quantizer->train(n, x);
|
|
72
74
|
FAISS_THROW_IF_NOT_MSG(
|
|
73
|
-
quantizer->ntotal == nlist,
|
|
75
|
+
static_cast<size_t>(quantizer->ntotal) == nlist,
|
|
74
76
|
"nlist not consistent with quantizer size");
|
|
75
77
|
} else if (quantizer_trains_alone == 0) {
|
|
76
78
|
if (verbose) {
|
|
77
79
|
printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
|
|
78
80
|
}
|
|
79
81
|
|
|
80
|
-
Clustering clus(d, nlist, cp);
|
|
82
|
+
Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
|
|
81
83
|
quantizer->reset();
|
|
82
84
|
if (clustering_index) {
|
|
83
85
|
clus.train(n, x, *clustering_index);
|
|
@@ -99,7 +101,7 @@ void Level1Quantizer::train_q1(
|
|
|
99
101
|
metric_type == METRIC_L2 ||
|
|
100
102
|
(metric_type == METRIC_INNER_PRODUCT && cp.spherical));
|
|
101
103
|
|
|
102
|
-
Clustering clus(d, nlist, cp);
|
|
104
|
+
Clustering clus(static_cast<int>(d), static_cast<int>(nlist), cp);
|
|
103
105
|
if (!clustering_index) {
|
|
104
106
|
IndexFlatL2 assigner(d);
|
|
105
107
|
clus.train(n, x, assigner);
|
|
@@ -148,7 +150,7 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
|
|
|
148
150
|
nbit += 8;
|
|
149
151
|
nl >>= 8;
|
|
150
152
|
}
|
|
151
|
-
FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
|
|
153
|
+
FAISS_THROW_IF_NOT(list_no >= 0 && static_cast<size_t>(list_no) < nlist);
|
|
152
154
|
return list_no;
|
|
153
155
|
}
|
|
154
156
|
|
|
@@ -157,21 +159,23 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
|
|
|
157
159
|
******************************************/
|
|
158
160
|
|
|
159
161
|
IndexIVF::IndexIVF(
|
|
160
|
-
Index*
|
|
161
|
-
size_t
|
|
162
|
-
size_t
|
|
163
|
-
size_t
|
|
162
|
+
Index* quantizer_in,
|
|
163
|
+
size_t d_in,
|
|
164
|
+
size_t nlist_in,
|
|
165
|
+
size_t code_size_in,
|
|
164
166
|
MetricType metric,
|
|
165
|
-
bool
|
|
166
|
-
: Index(
|
|
167
|
-
IndexIVFInterface(
|
|
167
|
+
bool own_invlists_in)
|
|
168
|
+
: Index(d_in, metric),
|
|
169
|
+
IndexIVFInterface(quantizer_in, nlist_in),
|
|
168
170
|
invlists(
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
171
|
+
own_invlists_in
|
|
172
|
+
? new ArrayInvertedLists(nlist_in, code_size_in)
|
|
173
|
+
: nullptr),
|
|
174
|
+
own_invlists(own_invlists_in),
|
|
175
|
+
code_size(code_size_in) {
|
|
176
|
+
FAISS_THROW_IF_NOT(static_cast<int>(d_in) == quantizer_in->d);
|
|
177
|
+
is_trained = quantizer_in->is_trained &&
|
|
178
|
+
(static_cast<size_t>(quantizer_in->ntotal) == nlist_in);
|
|
175
179
|
// Spherical by default if the metric is inner_product
|
|
176
180
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
177
181
|
cp.spherical = true;
|
|
@@ -185,6 +189,8 @@ void IndexIVF::add(idx_t n, const float* x) {
|
|
|
185
189
|
}
|
|
186
190
|
|
|
187
191
|
void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
|
|
192
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
193
|
+
FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
|
|
188
194
|
std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
|
|
189
195
|
quantizer->assign(n, x, coarse_idx.get());
|
|
190
196
|
add_core(n, x, xids, coarse_idx.get());
|
|
@@ -235,7 +241,7 @@ void IndexIVF::add_core(
|
|
|
235
241
|
|
|
236
242
|
size_t nadd = 0, nminus1 = 0;
|
|
237
243
|
|
|
238
|
-
for (
|
|
244
|
+
for (idx_t i = 0; i < n; i++) {
|
|
239
245
|
if (coarse_idx[i] < 0) {
|
|
240
246
|
nminus1++;
|
|
241
247
|
}
|
|
@@ -252,7 +258,7 @@ void IndexIVF::add_core(
|
|
|
252
258
|
int rank = omp_get_thread_num();
|
|
253
259
|
|
|
254
260
|
// each thread takes care of a subset of lists
|
|
255
|
-
for (
|
|
261
|
+
for (idx_t i = 0; i < n; i++) {
|
|
256
262
|
idx_t list_no = coarse_idx[i];
|
|
257
263
|
if (list_no >= 0 && list_no % nt == rank) {
|
|
258
264
|
idx_t id = xids ? xids[i] : ntotal + i;
|
|
@@ -305,45 +311,49 @@ void IndexIVF::search(
|
|
|
305
311
|
idx_t* labels,
|
|
306
312
|
const SearchParameters* params_in) const {
|
|
307
313
|
FAISS_THROW_IF_NOT(k > 0);
|
|
314
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
315
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
316
|
+
FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
|
|
308
317
|
const IVFSearchParameters* params = nullptr;
|
|
309
318
|
if (params_in) {
|
|
310
319
|
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
311
320
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
312
321
|
}
|
|
313
|
-
const size_t
|
|
322
|
+
const size_t cur_nprobe =
|
|
314
323
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
315
|
-
FAISS_THROW_IF_NOT(
|
|
324
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
316
325
|
|
|
317
326
|
// search function for a subset of queries
|
|
318
|
-
auto sub_search_func = [this, k,
|
|
319
|
-
idx_t
|
|
320
|
-
const float*
|
|
321
|
-
float*
|
|
322
|
-
idx_t*
|
|
327
|
+
auto sub_search_func = [this, k, cur_nprobe, params](
|
|
328
|
+
idx_t sub_n,
|
|
329
|
+
const float* sub_x,
|
|
330
|
+
float* sub_distances,
|
|
331
|
+
idx_t* sub_labels,
|
|
323
332
|
IndexIVFStats* ivf_stats) {
|
|
324
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[
|
|
325
|
-
std::unique_ptr<float[]> coarse_dis(new float[
|
|
333
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[sub_n * cur_nprobe]);
|
|
334
|
+
std::unique_ptr<float[]> coarse_dis(new float[sub_n * cur_nprobe]);
|
|
326
335
|
|
|
327
336
|
double t0 = getmillisecs();
|
|
328
337
|
quantizer->search(
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
338
|
+
sub_n,
|
|
339
|
+
sub_x,
|
|
340
|
+
cur_nprobe,
|
|
332
341
|
coarse_dis.get(),
|
|
333
342
|
idx.get(),
|
|
334
343
|
params ? params->quantizer_params : nullptr);
|
|
335
344
|
|
|
336
345
|
double t1 = getmillisecs();
|
|
337
|
-
invlists->prefetch_lists(
|
|
346
|
+
invlists->prefetch_lists(
|
|
347
|
+
idx.get(), static_cast<int>(sub_n * cur_nprobe));
|
|
338
348
|
|
|
339
349
|
search_preassigned(
|
|
340
|
-
|
|
341
|
-
|
|
350
|
+
sub_n,
|
|
351
|
+
sub_x,
|
|
342
352
|
k,
|
|
343
353
|
idx.get(),
|
|
344
354
|
coarse_dis.get(),
|
|
345
|
-
|
|
346
|
-
|
|
355
|
+
sub_distances,
|
|
356
|
+
sub_labels,
|
|
347
357
|
false,
|
|
348
358
|
params,
|
|
349
359
|
ivf_stats);
|
|
@@ -355,32 +365,28 @@ void IndexIVF::search(
|
|
|
355
365
|
if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
|
|
356
366
|
int nt = std::min(omp_get_max_threads(), int(n));
|
|
357
367
|
std::vector<IndexIVFStats> stats(nt);
|
|
358
|
-
std::
|
|
359
|
-
std::string exception_string;
|
|
368
|
+
std::exception_ptr ex;
|
|
360
369
|
|
|
361
370
|
#pragma omp parallel for if (nt > 1)
|
|
362
371
|
for (idx_t slice = 0; slice < nt; slice++) {
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
372
|
+
try {
|
|
373
|
+
IndexIVFStats local_stats;
|
|
374
|
+
idx_t i0 = n * slice / nt;
|
|
375
|
+
idx_t i1 = n * (slice + 1) / nt;
|
|
376
|
+
if (i1 > i0) {
|
|
368
377
|
sub_search_func(
|
|
369
378
|
i1 - i0,
|
|
370
379
|
x + i0 * d,
|
|
371
380
|
distances + i0 * k,
|
|
372
381
|
labels + i0 * k,
|
|
373
382
|
&stats[slice]);
|
|
374
|
-
} catch (const std::exception& e) {
|
|
375
|
-
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
376
|
-
exception_string = e.what();
|
|
377
383
|
}
|
|
384
|
+
} catch (...) {
|
|
385
|
+
omp_capture_exception(ex);
|
|
378
386
|
}
|
|
379
387
|
}
|
|
380
388
|
|
|
381
|
-
|
|
382
|
-
FAISS_THROW_MSG(exception_string.c_str());
|
|
383
|
-
}
|
|
389
|
+
omp_rethrow_if_exception(ex);
|
|
384
390
|
|
|
385
391
|
// collect stats
|
|
386
392
|
for (idx_t slice = 0; slice < nt; slice++) {
|
|
@@ -405,13 +411,17 @@ void IndexIVF::search_preassigned(
|
|
|
405
411
|
const IVFSearchParameters* params,
|
|
406
412
|
IndexIVFStats* ivf_stats) const {
|
|
407
413
|
FAISS_THROW_IF_NOT(k > 0);
|
|
414
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
415
|
+
FAISS_THROW_IF_NOT_MSG(invlists, "IVF index has no inverted lists");
|
|
408
416
|
|
|
409
|
-
idx_t
|
|
410
|
-
|
|
411
|
-
FAISS_THROW_IF_NOT(
|
|
417
|
+
idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
|
|
418
|
+
cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
|
|
419
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
412
420
|
|
|
413
421
|
const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
|
|
414
|
-
idx_t
|
|
422
|
+
idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
|
|
423
|
+
const bool ensure_topk_full = params ? params->ensure_topk_full : false;
|
|
424
|
+
|
|
415
425
|
IDSelector* sel = params ? params->sel : nullptr;
|
|
416
426
|
const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
|
|
417
427
|
if (selr) {
|
|
@@ -427,7 +437,8 @@ void IndexIVF::search_preassigned(
|
|
|
427
437
|
"selector and store_pairs cannot be combined");
|
|
428
438
|
|
|
429
439
|
FAISS_THROW_IF_NOT_MSG(
|
|
430
|
-
!invlists->use_iterator ||
|
|
440
|
+
!invlists->use_iterator ||
|
|
441
|
+
(cur_max_codes == 0 && store_pairs == false),
|
|
431
442
|
"iterable inverted lists don't support max_codes and store_pairs");
|
|
432
443
|
|
|
433
444
|
size_t nlistv = 0, ndis = 0, nheap = 0;
|
|
@@ -435,106 +446,119 @@ void IndexIVF::search_preassigned(
|
|
|
435
446
|
using HeapForIP = CMin<float, idx_t>;
|
|
436
447
|
using HeapForL2 = CMax<float, idx_t>;
|
|
437
448
|
|
|
438
|
-
|
|
439
|
-
std::
|
|
440
|
-
std::string exception_string;
|
|
449
|
+
std::exception_ptr ex;
|
|
450
|
+
std::atomic<bool> interrupt{false};
|
|
441
451
|
|
|
442
452
|
int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
|
|
443
453
|
bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
|
|
444
454
|
|
|
445
455
|
FAISS_THROW_IF_NOT_MSG(
|
|
446
|
-
|
|
456
|
+
cur_max_codes == 0 || pmode == 0 || pmode == 3,
|
|
447
457
|
"max_codes supported only for parallel_mode = 0 or 3");
|
|
448
458
|
|
|
449
|
-
|
|
450
|
-
|
|
459
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
460
|
+
!ensure_topk_full || pmode == 0 || pmode == 3,
|
|
461
|
+
"ensure_topk_full supported only for parallel_mode = 0 or 3");
|
|
462
|
+
|
|
463
|
+
if (cur_max_codes == 0) {
|
|
464
|
+
cur_max_codes = unlimited_list_size;
|
|
451
465
|
}
|
|
466
|
+
// Budget used by the probe loop below. ensure_topk_full makes a small
|
|
467
|
+
// max_codes budget large enough to give k post-filter candidates a chance.
|
|
468
|
+
idx_t effective_max_codes =
|
|
469
|
+
ensure_topk_full ? std::max(cur_max_codes, k) : cur_max_codes;
|
|
452
470
|
|
|
453
471
|
[[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
454
472
|
(pmode == 0 ? false
|
|
455
473
|
: pmode == 3 ? n > 1
|
|
456
|
-
: pmode == 1 ?
|
|
457
|
-
:
|
|
474
|
+
: pmode == 1 ? cur_nprobe > 1
|
|
475
|
+
: cur_nprobe * n > 1);
|
|
458
476
|
|
|
459
477
|
void* inverted_list_context =
|
|
460
478
|
params ? params->inverted_list_context : nullptr;
|
|
461
479
|
|
|
462
480
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
|
|
463
481
|
{
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
482
|
+
// C++ exceptions that escape an OpenMP parallel region without being
|
|
483
|
+
// caught inside it call std::terminate — they cannot propagate across
|
|
484
|
+
// thread boundaries. The outer try/catch covers per-thread setup
|
|
485
|
+
// (scanner creation, set_query); the inner try/catch in scan_one_list
|
|
486
|
+
// covers per-list operations. Both set interrupt=true to stop further
|
|
487
|
+
// work and re-throw after the parallel region exits.
|
|
488
|
+
try {
|
|
489
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
490
|
+
get_InvertedListScanner(store_pairs, sel, params));
|
|
491
|
+
|
|
492
|
+
/*****************************************************
|
|
493
|
+
* Depending on parallel_mode, there are two possible ways
|
|
494
|
+
* to organize the search. Here we define local functions
|
|
495
|
+
* that are in common between the two
|
|
496
|
+
******************************************************/
|
|
497
|
+
|
|
498
|
+
// initialize + reorder a result heap
|
|
499
|
+
|
|
500
|
+
auto init_result = [&](float* simi, idx_t* idxi) {
|
|
501
|
+
if (!do_heap_init) {
|
|
502
|
+
return;
|
|
503
|
+
}
|
|
504
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
505
|
+
heap_heapify<HeapForIP>(k, simi, idxi);
|
|
506
|
+
} else {
|
|
507
|
+
heap_heapify<HeapForL2>(k, simi, idxi);
|
|
508
|
+
}
|
|
509
|
+
};
|
|
510
|
+
|
|
511
|
+
auto add_local_results = [&](const float* local_dis,
|
|
512
|
+
const idx_t* local_idx,
|
|
513
|
+
float* simi,
|
|
514
|
+
idx_t* idxi) {
|
|
515
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
516
|
+
heap_addn<HeapForIP>(
|
|
517
|
+
k, simi, idxi, local_dis, local_idx, k);
|
|
518
|
+
} else {
|
|
519
|
+
heap_addn<HeapForL2>(
|
|
520
|
+
k, simi, idxi, local_dis, local_idx, k);
|
|
521
|
+
}
|
|
522
|
+
};
|
|
474
523
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
524
|
+
auto reorder_result = [&](float* simi, idx_t* idxi) {
|
|
525
|
+
if (!do_heap_init) {
|
|
526
|
+
return;
|
|
527
|
+
}
|
|
528
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
529
|
+
heap_reorder<HeapForIP>(k, simi, idxi);
|
|
530
|
+
} else {
|
|
531
|
+
heap_reorder<HeapForL2>(k, simi, idxi);
|
|
532
|
+
}
|
|
533
|
+
};
|
|
485
534
|
|
|
486
|
-
|
|
487
|
-
|
|
535
|
+
// single list scan using the current scanner (with query
|
|
536
|
+
// set properly) and storing results in simi and idxi
|
|
537
|
+
auto scan_one_list = [&](idx_t key,
|
|
538
|
+
float coarse_dis_i,
|
|
488
539
|
float* simi,
|
|
489
|
-
idx_t* idxi
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
}
|
|
506
|
-
};
|
|
507
|
-
|
|
508
|
-
// single list scan using the current scanner (with query
|
|
509
|
-
// set properly) and storing results in simi and idxi
|
|
510
|
-
auto scan_one_list = [&](idx_t key,
|
|
511
|
-
float coarse_dis_i,
|
|
512
|
-
float* simi,
|
|
513
|
-
idx_t* idxi,
|
|
514
|
-
idx_t list_size_max) {
|
|
515
|
-
if (key < 0) {
|
|
516
|
-
// not enough centroids for multiprobe
|
|
517
|
-
return (size_t)0;
|
|
518
|
-
}
|
|
519
|
-
FAISS_THROW_IF_NOT_FMT(
|
|
520
|
-
key < (idx_t)nlist,
|
|
521
|
-
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
522
|
-
key,
|
|
523
|
-
nlist);
|
|
524
|
-
|
|
525
|
-
// don't waste time on empty lists
|
|
526
|
-
if (invlists->is_empty(key, inverted_list_context)) {
|
|
527
|
-
return (size_t)0;
|
|
528
|
-
}
|
|
529
|
-
|
|
530
|
-
scanner->set_list(key, coarse_dis_i);
|
|
540
|
+
idx_t* idxi,
|
|
541
|
+
idx_t list_size_max) {
|
|
542
|
+
if (key < 0) {
|
|
543
|
+
// not enough centroids for multiprobe
|
|
544
|
+
return (size_t)0;
|
|
545
|
+
}
|
|
546
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
547
|
+
key < (idx_t)nlist,
|
|
548
|
+
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
549
|
+
key,
|
|
550
|
+
nlist);
|
|
551
|
+
|
|
552
|
+
// don't waste time on empty lists
|
|
553
|
+
if (invlists->is_empty(key, inverted_list_context)) {
|
|
554
|
+
return (size_t)0;
|
|
555
|
+
}
|
|
531
556
|
|
|
532
|
-
|
|
557
|
+
scanner->set_list(key, coarse_dis_i);
|
|
533
558
|
|
|
534
|
-
|
|
559
|
+
nlistv++;
|
|
535
560
|
if (invlists->use_iterator) {
|
|
536
561
|
size_t list_size = 0;
|
|
537
|
-
|
|
538
562
|
std::unique_ptr<InvertedListsIterator> it(
|
|
539
563
|
invlists->get_iterator(key, inverted_list_context));
|
|
540
564
|
|
|
@@ -544,8 +568,8 @@ void IndexIVF::search_preassigned(
|
|
|
544
568
|
return list_size;
|
|
545
569
|
} else {
|
|
546
570
|
size_t list_size = invlists->list_size(key);
|
|
547
|
-
if (list_size > list_size_max) {
|
|
548
|
-
list_size = list_size_max;
|
|
571
|
+
if (list_size > static_cast<size_t>(list_size_max)) {
|
|
572
|
+
list_size = static_cast<size_t>(list_size_max);
|
|
549
573
|
}
|
|
550
574
|
|
|
551
575
|
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
@@ -573,144 +597,167 @@ void IndexIVF::search_preassigned(
|
|
|
573
597
|
ids += jmin;
|
|
574
598
|
}
|
|
575
599
|
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
600
|
+
size_t old_scan_cnt = 0;
|
|
601
|
+
size_t old_heap_updates = 0;
|
|
602
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
603
|
+
HeapResultHandler<HeapForIP, false> handler(
|
|
604
|
+
k, simi, idxi);
|
|
605
|
+
old_scan_cnt = handler.stats.scan_cnt;
|
|
606
|
+
old_heap_updates = handler.stats.nheap_updates;
|
|
607
|
+
scanner->scan_codes(list_size, codes, ids, handler);
|
|
608
|
+
nheap += handler.stats.nheap_updates - old_heap_updates;
|
|
609
|
+
return handler.stats.scan_cnt - old_scan_cnt;
|
|
610
|
+
} else {
|
|
611
|
+
HeapResultHandler<HeapForL2, false> handler(
|
|
612
|
+
k, simi, idxi);
|
|
613
|
+
old_scan_cnt = handler.stats.scan_cnt;
|
|
614
|
+
old_heap_updates = handler.stats.nheap_updates;
|
|
615
|
+
scanner->scan_codes(list_size, codes, ids, handler);
|
|
616
|
+
nheap += handler.stats.nheap_updates - old_heap_updates;
|
|
617
|
+
return handler.stats.scan_cnt - old_scan_cnt;
|
|
618
|
+
}
|
|
580
619
|
}
|
|
581
|
-
}
|
|
582
|
-
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
583
|
-
exception_string =
|
|
584
|
-
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
|
|
585
|
-
interrupt = true;
|
|
586
|
-
return size_t(0);
|
|
587
|
-
}
|
|
588
|
-
};
|
|
620
|
+
};
|
|
589
621
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
622
|
+
/****************************************************
|
|
623
|
+
* Actual loops, depending on parallel_mode
|
|
624
|
+
****************************************************/
|
|
593
625
|
|
|
594
|
-
|
|
626
|
+
if (pmode == 0 || pmode == 3) {
|
|
595
627
|
#pragma omp for
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
}
|
|
600
|
-
|
|
601
|
-
// loop over queries
|
|
602
|
-
scanner->set_query(x + i * d);
|
|
603
|
-
float* simi = distances + i * k;
|
|
604
|
-
idx_t* idxi = labels + i * k;
|
|
605
|
-
|
|
606
|
-
init_result(simi, idxi);
|
|
607
|
-
|
|
608
|
-
idx_t nscan = 0;
|
|
609
|
-
|
|
610
|
-
// loop over probes
|
|
611
|
-
for (size_t ik = 0; ik < nprobe; ik++) {
|
|
612
|
-
nscan += scan_one_list(
|
|
613
|
-
keys[i * nprobe + ik],
|
|
614
|
-
coarse_dis[i * nprobe + ik],
|
|
615
|
-
simi,
|
|
616
|
-
idxi,
|
|
617
|
-
max_codes - nscan);
|
|
618
|
-
if (nscan >= max_codes) {
|
|
619
|
-
break;
|
|
628
|
+
for (idx_t i = 0; i < n; i++) {
|
|
629
|
+
if (interrupt.load(std::memory_order_relaxed)) {
|
|
630
|
+
continue;
|
|
620
631
|
}
|
|
621
|
-
|
|
632
|
+
try {
|
|
633
|
+
// loop over queries
|
|
634
|
+
scanner->set_query(x + i * d);
|
|
635
|
+
float* simi = distances + i * k;
|
|
636
|
+
idx_t* idxi = labels + i * k;
|
|
637
|
+
|
|
638
|
+
init_result(simi, idxi);
|
|
639
|
+
|
|
640
|
+
idx_t nscan = 0;
|
|
641
|
+
|
|
642
|
+
// loop over probes
|
|
643
|
+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
|
|
644
|
+
// For soft budgets, scan whole lists so
|
|
645
|
+
// IDSelector-filtered rows do not consume the
|
|
646
|
+
// remaining code budget.
|
|
647
|
+
const idx_t list_size_max = ensure_topk_full
|
|
648
|
+
? unlimited_list_size
|
|
649
|
+
: effective_max_codes - nscan;
|
|
650
|
+
nscan += scan_one_list(
|
|
651
|
+
keys[i * cur_nprobe + ik],
|
|
652
|
+
coarse_dis[i * cur_nprobe + ik],
|
|
653
|
+
simi,
|
|
654
|
+
idxi,
|
|
655
|
+
list_size_max);
|
|
656
|
+
|
|
657
|
+
// Early-stop check: apply max_codes after each
|
|
658
|
+
// list. nscan is the number of distances
|
|
659
|
+
// actually computed.
|
|
660
|
+
if (nscan >= effective_max_codes) {
|
|
661
|
+
break;
|
|
662
|
+
}
|
|
663
|
+
}
|
|
622
664
|
|
|
623
|
-
|
|
624
|
-
|
|
665
|
+
ndis += nscan;
|
|
666
|
+
reorder_result(simi, idxi);
|
|
625
667
|
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
668
|
+
InterruptCallback::check();
|
|
669
|
+
} catch (...) {
|
|
670
|
+
omp_capture_exception(ex, [&] { interrupt = true; });
|
|
671
|
+
}
|
|
672
|
+
} // parallel for
|
|
673
|
+
} else if (pmode == 1) {
|
|
674
|
+
std::vector<idx_t> local_idx(k);
|
|
675
|
+
std::vector<float> local_dis(k);
|
|
634
676
|
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
677
|
+
for (idx_t i = 0; i < n; i++) {
|
|
678
|
+
scanner->set_query(x + i * d);
|
|
679
|
+
init_result(local_dis.data(), local_idx.data());
|
|
638
680
|
|
|
639
681
|
#pragma omp for schedule(dynamic)
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
682
|
+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
|
|
683
|
+
try {
|
|
684
|
+
ndis += scan_one_list(
|
|
685
|
+
keys[i * cur_nprobe + ik],
|
|
686
|
+
coarse_dis[i * cur_nprobe + ik],
|
|
687
|
+
local_dis.data(),
|
|
688
|
+
local_idx.data(),
|
|
689
|
+
unlimited_list_size);
|
|
690
|
+
|
|
691
|
+
// can't do the test on max_codes
|
|
692
|
+
} catch (...) {
|
|
693
|
+
omp_capture_exception(
|
|
694
|
+
ex, [&] { interrupt = true; });
|
|
695
|
+
}
|
|
696
|
+
}
|
|
697
|
+
// merge thread-local results
|
|
651
698
|
|
|
652
|
-
|
|
653
|
-
|
|
699
|
+
float* simi = distances + i * k;
|
|
700
|
+
idx_t* idxi = labels + i * k;
|
|
654
701
|
#pragma omp single
|
|
655
|
-
|
|
702
|
+
init_result(simi, idxi);
|
|
656
703
|
|
|
657
704
|
#pragma omp barrier
|
|
658
705
|
#pragma omp critical
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
706
|
+
{
|
|
707
|
+
add_local_results(
|
|
708
|
+
local_dis.data(), local_idx.data(), simi, idxi);
|
|
709
|
+
}
|
|
663
710
|
#pragma omp barrier
|
|
664
711
|
#pragma omp single
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
712
|
+
reorder_result(simi, idxi);
|
|
713
|
+
}
|
|
714
|
+
} else if (pmode == 2) {
|
|
715
|
+
std::vector<idx_t> local_idx(k);
|
|
716
|
+
std::vector<float> local_dis(k);
|
|
670
717
|
|
|
671
718
|
#pragma omp single
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
719
|
+
for (int64_t i = 0; i < n; i++) {
|
|
720
|
+
init_result(distances + i * k, labels + i * k);
|
|
721
|
+
}
|
|
675
722
|
|
|
676
723
|
#pragma omp for schedule(dynamic)
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
724
|
+
for (int64_t ij = 0; ij < n * cur_nprobe; ij++) {
|
|
725
|
+
try {
|
|
726
|
+
size_t i = ij / cur_nprobe;
|
|
727
|
+
|
|
728
|
+
scanner->set_query(x + i * d);
|
|
729
|
+
init_result(local_dis.data(), local_idx.data());
|
|
730
|
+
ndis += scan_one_list(
|
|
731
|
+
keys[ij],
|
|
732
|
+
coarse_dis[ij],
|
|
733
|
+
local_dis.data(),
|
|
734
|
+
local_idx.data(),
|
|
735
|
+
unlimited_list_size);
|
|
688
736
|
#pragma omp critical
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
737
|
+
{
|
|
738
|
+
add_local_results(
|
|
739
|
+
local_dis.data(),
|
|
740
|
+
local_idx.data(),
|
|
741
|
+
distances + i * k,
|
|
742
|
+
labels + i * k);
|
|
743
|
+
}
|
|
744
|
+
} catch (...) {
|
|
745
|
+
omp_capture_exception(ex, [&] { interrupt = true; });
|
|
746
|
+
}
|
|
695
747
|
}
|
|
696
|
-
}
|
|
697
748
|
#pragma omp single
|
|
698
|
-
|
|
699
|
-
|
|
749
|
+
for (int64_t i = 0; i < n; i++) {
|
|
750
|
+
reorder_result(distances + i * k, labels + i * k);
|
|
751
|
+
}
|
|
752
|
+
} else {
|
|
753
|
+
FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
|
|
700
754
|
}
|
|
701
|
-
}
|
|
702
|
-
|
|
755
|
+
} catch (...) {
|
|
756
|
+
omp_capture_exception(ex, [&] { interrupt = true; });
|
|
703
757
|
}
|
|
704
758
|
} // parallel section
|
|
705
759
|
|
|
706
|
-
|
|
707
|
-
if (!exception_string.empty()) {
|
|
708
|
-
FAISS_THROW_FMT(
|
|
709
|
-
"search interrupted with: %s", exception_string.c_str());
|
|
710
|
-
} else {
|
|
711
|
-
FAISS_THROW_MSG("computation interrupted");
|
|
712
|
-
}
|
|
713
|
-
}
|
|
760
|
+
omp_rethrow_if_exception(ex);
|
|
714
761
|
|
|
715
762
|
if (ivf_stats == nullptr) {
|
|
716
763
|
ivf_stats = &indexIVF_stats;
|
|
@@ -727,6 +774,8 @@ void IndexIVF::range_search(
|
|
|
727
774
|
float radius,
|
|
728
775
|
RangeSearchResult* result,
|
|
729
776
|
const SearchParameters* params_in) const {
|
|
777
|
+
FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
|
|
778
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
730
779
|
const IVFSearchParameters* params = nullptr;
|
|
731
780
|
const SearchParameters* quantizer_params = nullptr;
|
|
732
781
|
if (params_in) {
|
|
@@ -734,18 +783,18 @@ void IndexIVF::range_search(
|
|
|
734
783
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
735
784
|
quantizer_params = params->quantizer_params;
|
|
736
785
|
}
|
|
737
|
-
const size_t
|
|
786
|
+
const size_t cur_nprobe =
|
|
738
787
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
739
|
-
std::unique_ptr<idx_t[]> keys(new idx_t[nx *
|
|
740
|
-
std::unique_ptr<float[]> coarse_dis(new float[nx *
|
|
788
|
+
std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
|
|
789
|
+
std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
|
|
741
790
|
|
|
742
791
|
double t0 = getmillisecs();
|
|
743
792
|
quantizer->search(
|
|
744
|
-
nx, x,
|
|
793
|
+
nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
|
|
745
794
|
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
|
746
795
|
|
|
747
796
|
t0 = getmillisecs();
|
|
748
|
-
invlists->prefetch_lists(keys.get(), nx *
|
|
797
|
+
invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
|
|
749
798
|
|
|
750
799
|
range_search_preassigned(
|
|
751
800
|
nx,
|
|
@@ -771,22 +820,29 @@ void IndexIVF::range_search_preassigned(
|
|
|
771
820
|
bool store_pairs,
|
|
772
821
|
const IVFSearchParameters* params,
|
|
773
822
|
IndexIVFStats* stats) const {
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
823
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "IVF index is not trained");
|
|
824
|
+
idx_t cur_nprobe = params ? params->nprobe : this->nprobe;
|
|
825
|
+
cur_nprobe = std::min((idx_t)nlist, cur_nprobe);
|
|
826
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
827
|
+
|
|
828
|
+
idx_t cur_max_codes = params ? params->max_codes : this->max_codes;
|
|
829
|
+
// Range-search early-stop budget. 0 disables the empty-bucket stop.
|
|
830
|
+
const size_t max_empty_result_buckets =
|
|
831
|
+
params ? params->max_empty_result_buckets : 0;
|
|
779
832
|
IDSelector* sel = params ? params->sel : nullptr;
|
|
780
833
|
|
|
781
834
|
FAISS_THROW_IF_NOT_MSG(
|
|
782
|
-
!invlists->use_iterator ||
|
|
835
|
+
!invlists->use_iterator ||
|
|
836
|
+
(cur_max_codes == 0 && store_pairs == false),
|
|
783
837
|
"iterable inverted lists don't support max_codes and store_pairs");
|
|
784
838
|
|
|
839
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
840
|
+
max_empty_result_buckets == 0 || parallel_mode == 0,
|
|
841
|
+
"max_empty_result_buckets supported only for parallel_mode = 0");
|
|
842
|
+
|
|
785
843
|
size_t nlistv = 0, ndis = 0;
|
|
786
844
|
|
|
787
|
-
|
|
788
|
-
std::mutex exception_mutex;
|
|
789
|
-
std::string exception_string;
|
|
845
|
+
std::exception_ptr ex;
|
|
790
846
|
|
|
791
847
|
std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
|
|
792
848
|
|
|
@@ -795,122 +851,142 @@ void IndexIVF::range_search_preassigned(
|
|
|
795
851
|
[[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
796
852
|
(pmode == 3 ? false
|
|
797
853
|
: pmode == 0 ? nx > 1
|
|
798
|
-
: pmode == 1 ?
|
|
799
|
-
:
|
|
854
|
+
: pmode == 1 ? cur_nprobe > 1
|
|
855
|
+
: cur_nprobe * nx > 1);
|
|
800
856
|
|
|
801
857
|
void* inverted_list_context =
|
|
802
858
|
params ? params->inverted_list_context : nullptr;
|
|
803
859
|
|
|
804
860
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
|
|
805
861
|
{
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
key,
|
|
823
|
-
ik,
|
|
824
|
-
nlist);
|
|
825
|
-
|
|
826
|
-
if (invlists->is_empty(key, inverted_list_context)) {
|
|
827
|
-
return;
|
|
828
|
-
}
|
|
862
|
+
try {
|
|
863
|
+
RangeSearchPartialResult pres(result);
|
|
864
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
865
|
+
get_InvertedListScanner(store_pairs, sel, params));
|
|
866
|
+
FAISS_THROW_IF_NOT(scanner.get());
|
|
867
|
+
all_pres[omp_get_thread_num()] = &pres;
|
|
868
|
+
|
|
869
|
+
// prepare the list scanning function
|
|
870
|
+
|
|
871
|
+
auto scan_list_func = [&](size_t i,
|
|
872
|
+
size_t ik,
|
|
873
|
+
RangeQueryResult& qres) {
|
|
874
|
+
idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
|
|
875
|
+
if (key < 0) {
|
|
876
|
+
return;
|
|
877
|
+
}
|
|
829
878
|
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
879
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
880
|
+
key < (idx_t)nlist,
|
|
881
|
+
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
|
882
|
+
key,
|
|
883
|
+
ik,
|
|
884
|
+
nlist);
|
|
885
|
+
|
|
886
|
+
if (invlists->is_empty(key, inverted_list_context)) {
|
|
887
|
+
return;
|
|
888
|
+
}
|
|
889
|
+
|
|
890
|
+
scanner->set_list(key, coarse_dis[i * cur_nprobe + ik]);
|
|
891
|
+
const size_t scan_cnt0 = qres.stats.scan_cnt;
|
|
833
892
|
if (invlists->use_iterator) {
|
|
893
|
+
size_t list_size = 0;
|
|
834
894
|
std::unique_ptr<InvertedListsIterator> it(
|
|
835
895
|
invlists->get_iterator(key, inverted_list_context));
|
|
836
896
|
|
|
837
897
|
scanner->iterate_codes_range(
|
|
838
898
|
it.get(), radius, qres, list_size);
|
|
899
|
+
qres.stats.scan_cnt += list_size;
|
|
839
900
|
} else {
|
|
840
901
|
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
841
902
|
InvertedLists::ScopedIds ids(invlists, key);
|
|
842
|
-
list_size = invlists->list_size(key);
|
|
903
|
+
size_t list_size = invlists->list_size(key);
|
|
843
904
|
|
|
844
905
|
scanner->scan_codes_range(
|
|
845
906
|
list_size, scodes.get(), ids.get(), radius, qres);
|
|
846
907
|
}
|
|
847
908
|
nlistv++;
|
|
848
|
-
ndis +=
|
|
849
|
-
}
|
|
850
|
-
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
851
|
-
exception_string =
|
|
852
|
-
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
|
|
853
|
-
interrupt = true;
|
|
854
|
-
}
|
|
855
|
-
};
|
|
909
|
+
ndis += qres.stats.scan_cnt - scan_cnt0;
|
|
910
|
+
};
|
|
856
911
|
|
|
857
|
-
|
|
912
|
+
if (parallel_mode == 0) {
|
|
858
913
|
#pragma omp for
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
914
|
+
for (idx_t i = 0; i < nx; i++) {
|
|
915
|
+
try {
|
|
916
|
+
scanner->set_query(x + i * d);
|
|
917
|
+
RangeQueryResult& qres = pres.new_result(i);
|
|
918
|
+
|
|
919
|
+
// Stop after enough consecutive probes add no range
|
|
920
|
+
// results. A hit resets the counter.
|
|
921
|
+
size_t prev_nres = qres.nres;
|
|
922
|
+
size_t ndup = 0;
|
|
923
|
+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
|
|
924
|
+
scan_list_func(i, ik, qres);
|
|
925
|
+
if (max_empty_result_buckets > 0) {
|
|
926
|
+
// Early-stop check: stop range search after
|
|
927
|
+
// enough consecutive empty probes.
|
|
928
|
+
ndup = (qres.nres == prev_nres) ? ndup + 1 : 0;
|
|
929
|
+
if (ndup >= max_empty_result_buckets) {
|
|
930
|
+
break;
|
|
931
|
+
}
|
|
932
|
+
prev_nres = qres.nres;
|
|
933
|
+
}
|
|
934
|
+
}
|
|
935
|
+
} catch (...) {
|
|
936
|
+
omp_capture_exception(ex);
|
|
937
|
+
}
|
|
866
938
|
}
|
|
867
|
-
}
|
|
868
939
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
940
|
+
} else if (parallel_mode == 1) {
|
|
941
|
+
for (idx_t i = 0; i < nx; i++) {
|
|
942
|
+
scanner->set_query(x + i * d);
|
|
872
943
|
|
|
873
|
-
|
|
944
|
+
RangeQueryResult& qres = pres.new_result(i);
|
|
874
945
|
|
|
875
946
|
#pragma omp for schedule(dynamic)
|
|
876
|
-
|
|
877
|
-
|
|
947
|
+
for (int64_t ik = 0; ik < cur_nprobe; ik++) {
|
|
948
|
+
try {
|
|
949
|
+
scan_list_func(i, ik, qres);
|
|
950
|
+
} catch (...) {
|
|
951
|
+
omp_capture_exception(ex);
|
|
952
|
+
}
|
|
953
|
+
}
|
|
878
954
|
}
|
|
879
|
-
}
|
|
880
|
-
|
|
881
|
-
RangeQueryResult* qres = nullptr;
|
|
955
|
+
} else if (parallel_mode == 2) {
|
|
956
|
+
RangeQueryResult* qres = nullptr;
|
|
882
957
|
|
|
883
958
|
#pragma omp for schedule(dynamic)
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
959
|
+
for (idx_t iik = 0; iik < nx * (idx_t)cur_nprobe; iik++) {
|
|
960
|
+
try {
|
|
961
|
+
idx_t i = iik / (idx_t)cur_nprobe;
|
|
962
|
+
idx_t ik = iik % (idx_t)cur_nprobe;
|
|
963
|
+
if (qres == nullptr || qres->qno != i) {
|
|
964
|
+
qres = &pres.new_result(i);
|
|
965
|
+
scanner->set_query(x + i * d);
|
|
966
|
+
}
|
|
967
|
+
scan_list_func(i, ik, *qres);
|
|
968
|
+
} catch (...) {
|
|
969
|
+
omp_capture_exception(ex);
|
|
970
|
+
}
|
|
890
971
|
}
|
|
891
|
-
|
|
972
|
+
} else {
|
|
973
|
+
FAISS_THROW_FMT(
|
|
974
|
+
"parallel_mode %d not supported\n", parallel_mode);
|
|
892
975
|
}
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
if (parallel_mode == 0) {
|
|
897
|
-
pres.finalize();
|
|
898
|
-
} else {
|
|
976
|
+
if (parallel_mode == 0) {
|
|
977
|
+
pres.finalize();
|
|
978
|
+
} else {
|
|
899
979
|
#pragma omp barrier
|
|
900
980
|
#pragma omp single
|
|
901
|
-
|
|
981
|
+
RangeSearchPartialResult::merge(all_pres, false);
|
|
902
982
|
#pragma omp barrier
|
|
983
|
+
}
|
|
984
|
+
} catch (...) {
|
|
985
|
+
omp_capture_exception(ex);
|
|
903
986
|
}
|
|
904
987
|
}
|
|
905
988
|
|
|
906
|
-
|
|
907
|
-
if (!exception_string.empty()) {
|
|
908
|
-
FAISS_THROW_FMT(
|
|
909
|
-
"search interrupted with: %s", exception_string.c_str());
|
|
910
|
-
} else {
|
|
911
|
-
FAISS_THROW_MSG("computation interrupted");
|
|
912
|
-
}
|
|
913
|
-
}
|
|
989
|
+
omp_rethrow_if_exception(ex);
|
|
914
990
|
|
|
915
991
|
if (stats == nullptr) {
|
|
916
992
|
stats = &indexIVF_stats;
|
|
@@ -920,6 +996,57 @@ void IndexIVF::range_search_preassigned(
|
|
|
920
996
|
stats->ndis += ndis;
|
|
921
997
|
}
|
|
922
998
|
|
|
999
|
+
void IndexIVF::search1(
|
|
1000
|
+
const float* x,
|
|
1001
|
+
ResultHandler& handler,
|
|
1002
|
+
SearchParameters* params_in) const {
|
|
1003
|
+
const IVFSearchParameters* params = nullptr;
|
|
1004
|
+
const SearchParameters* quantizer_params = nullptr;
|
|
1005
|
+
if (params_in) {
|
|
1006
|
+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
1007
|
+
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
1008
|
+
quantizer_params = params->quantizer_params;
|
|
1009
|
+
}
|
|
1010
|
+
const size_t cur_nprobe =
|
|
1011
|
+
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
1012
|
+
size_t nx = 1;
|
|
1013
|
+
std::unique_ptr<idx_t[]> keys(new idx_t[nx * cur_nprobe]);
|
|
1014
|
+
std::unique_ptr<float[]> coarse_dis(new float[nx * cur_nprobe]);
|
|
1015
|
+
|
|
1016
|
+
double t0 = getmillisecs();
|
|
1017
|
+
quantizer->search(
|
|
1018
|
+
nx, x, cur_nprobe, coarse_dis.get(), keys.get(), quantizer_params);
|
|
1019
|
+
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
|
1020
|
+
|
|
1021
|
+
t0 = getmillisecs();
|
|
1022
|
+
invlists->prefetch_lists(keys.get(), static_cast<int>(nx * cur_nprobe));
|
|
1023
|
+
|
|
1024
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
1025
|
+
get_InvertedListScanner(false, nullptr, params));
|
|
1026
|
+
scanner->set_query(x);
|
|
1027
|
+
|
|
1028
|
+
for (size_t i = 0; i < cur_nprobe; i++) {
|
|
1029
|
+
idx_t key = keys[i];
|
|
1030
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
1031
|
+
key < (idx_t)nlist,
|
|
1032
|
+
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
1033
|
+
key,
|
|
1034
|
+
nlist);
|
|
1035
|
+
if (key < 0 || invlists->is_empty(key)) {
|
|
1036
|
+
continue;
|
|
1037
|
+
}
|
|
1038
|
+
|
|
1039
|
+
scanner->set_list(key, coarse_dis[i]);
|
|
1040
|
+
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
1041
|
+
InvertedLists::ScopedIds ids(invlists, key);
|
|
1042
|
+
size_t list_size = invlists->list_size(key);
|
|
1043
|
+
|
|
1044
|
+
scanner->scan_codes(list_size, scodes.get(), ids.get(), handler);
|
|
1045
|
+
}
|
|
1046
|
+
|
|
1047
|
+
indexIVF_stats.search_time += getmillisecs() - t0;
|
|
1048
|
+
}
|
|
1049
|
+
|
|
923
1050
|
InvertedListScanner* IndexIVF::get_InvertedListScanner(
|
|
924
1051
|
bool /*store_pairs*/,
|
|
925
1052
|
const IDSelector* /* sel */,
|
|
@@ -935,11 +1062,11 @@ void IndexIVF::reconstruct(idx_t key, float* recons) const {
|
|
|
935
1062
|
void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
|
936
1063
|
FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
|
937
1064
|
|
|
938
|
-
for (
|
|
1065
|
+
for (size_t list_no = 0; list_no < nlist; list_no++) {
|
|
939
1066
|
size_t list_size = invlists->list_size(list_no);
|
|
940
1067
|
ScopedIds idlist(invlists, list_no);
|
|
941
1068
|
|
|
942
|
-
for (
|
|
1069
|
+
for (size_t offset = 0; offset < list_size; offset++) {
|
|
943
1070
|
idx_t id = idlist[offset];
|
|
944
1071
|
if (!(id >= i0 && id < i0 + ni)) {
|
|
945
1072
|
continue;
|
|
@@ -1000,16 +1127,16 @@ void IndexIVF::search_and_reconstruct(
|
|
|
1000
1127
|
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
1001
1128
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
1002
1129
|
}
|
|
1003
|
-
const size_t
|
|
1130
|
+
const size_t cur_nprobe =
|
|
1004
1131
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
1005
|
-
FAISS_THROW_IF_NOT(
|
|
1132
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
1006
1133
|
|
|
1007
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[n *
|
|
1008
|
-
std::unique_ptr<float[]> coarse_dis(new float[n *
|
|
1134
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
|
|
1135
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
|
|
1009
1136
|
|
|
1010
|
-
quantizer->search(n, x,
|
|
1137
|
+
quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
|
|
1011
1138
|
|
|
1012
|
-
invlists->prefetch_lists(idx.get(), n *
|
|
1139
|
+
invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
|
|
1013
1140
|
|
|
1014
1141
|
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
|
1015
1142
|
// and offset into `codes` for reconstruction
|
|
@@ -1031,8 +1158,8 @@ void IndexIVF::search_and_reconstruct(
|
|
|
1031
1158
|
// Fill with NaNs
|
|
1032
1159
|
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
|
1033
1160
|
} else {
|
|
1034
|
-
|
|
1035
|
-
|
|
1161
|
+
size_t list_no = lo_listno(key);
|
|
1162
|
+
size_t offset = lo_offset(key);
|
|
1036
1163
|
|
|
1037
1164
|
// Update label to the actual id
|
|
1038
1165
|
labels[ij] = invlists->get_single_id(list_no, offset);
|
|
@@ -1056,16 +1183,16 @@ void IndexIVF::search_and_return_codes(
|
|
|
1056
1183
|
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
1057
1184
|
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
|
1058
1185
|
}
|
|
1059
|
-
const size_t
|
|
1186
|
+
const size_t cur_nprobe =
|
|
1060
1187
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
1061
|
-
FAISS_THROW_IF_NOT(
|
|
1188
|
+
FAISS_THROW_IF_NOT(cur_nprobe > 0);
|
|
1062
1189
|
|
|
1063
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[n *
|
|
1064
|
-
std::unique_ptr<float[]> coarse_dis(new float[n *
|
|
1190
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * cur_nprobe]);
|
|
1191
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * cur_nprobe]);
|
|
1065
1192
|
|
|
1066
|
-
quantizer->search(n, x,
|
|
1193
|
+
quantizer->search(n, x, cur_nprobe, coarse_dis.get(), idx.get());
|
|
1067
1194
|
|
|
1068
|
-
invlists->prefetch_lists(idx.get(), n *
|
|
1195
|
+
invlists->prefetch_lists(idx.get(), static_cast<int>(n * cur_nprobe));
|
|
1069
1196
|
|
|
1070
1197
|
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
|
1071
1198
|
// and offset into `codes` for reconstruction
|
|
@@ -1094,8 +1221,8 @@ void IndexIVF::search_and_return_codes(
|
|
|
1094
1221
|
// Fill with 0xff
|
|
1095
1222
|
memset(code1, -1, code_size_1);
|
|
1096
1223
|
} else {
|
|
1097
|
-
|
|
1098
|
-
|
|
1224
|
+
size_t list_no = lo_listno(key);
|
|
1225
|
+
size_t offset = lo_offset(key);
|
|
1099
1226
|
const uint8_t* cc = invlists->get_single_code(list_no, offset);
|
|
1100
1227
|
|
|
1101
1228
|
labels[ij] = invlists->get_single_id(list_no, offset);
|
|
@@ -1134,7 +1261,8 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
|
|
|
1134
1261
|
IDSelectorArray sel(n, new_ids);
|
|
1135
1262
|
size_t nremove = remove_ids(sel);
|
|
1136
1263
|
FAISS_THROW_IF_NOT_MSG(
|
|
1137
|
-
nremove == n,
|
|
1264
|
+
nremove == static_cast<size_t>(n),
|
|
1265
|
+
"did not find all entries to remove");
|
|
1138
1266
|
add_with_ids(n, x, new_ids);
|
|
1139
1267
|
return;
|
|
1140
1268
|
}
|
|
@@ -1196,7 +1324,7 @@ idx_t IndexIVF::train_encoder_num_vectors() const {
|
|
|
1196
1324
|
void IndexIVF::train_encoder(
|
|
1197
1325
|
idx_t /*n*/,
|
|
1198
1326
|
const float* /*x*/,
|
|
1199
|
-
const idx_t* assign) {
|
|
1327
|
+
const idx_t* /*assign*/) {
|
|
1200
1328
|
// does nothing by default
|
|
1201
1329
|
if (verbose) {
|
|
1202
1330
|
printf("IndexIVF: no residual training\n");
|
|
@@ -1298,6 +1426,20 @@ IndexIVFStats indexIVF_stats;
|
|
|
1298
1426
|
* InvertedListScanner
|
|
1299
1427
|
*************************************************************************/
|
|
1300
1428
|
|
|
1429
|
+
// this gets expanded in expanded_scanners
|
|
1430
|
+
|
|
1431
|
+
size_t InvertedListScanner::scan_codes(
|
|
1432
|
+
size_t list_size,
|
|
1433
|
+
const uint8_t* codes,
|
|
1434
|
+
const idx_t* ids,
|
|
1435
|
+
ResultHandler& handler) const {
|
|
1436
|
+
return run_scan_codes(*this, list_size, codes, ids, handler);
|
|
1437
|
+
}
|
|
1438
|
+
|
|
1439
|
+
void InvertedListScanner::set_list(idx_t list_no_in, float /* coarse_dis */) {
|
|
1440
|
+
this->list_no = list_no_in;
|
|
1441
|
+
}
|
|
1442
|
+
|
|
1301
1443
|
size_t InvertedListScanner::scan_codes(
|
|
1302
1444
|
size_t list_size,
|
|
1303
1445
|
const uint8_t* codes,
|
|
@@ -1305,46 +1447,15 @@ size_t InvertedListScanner::scan_codes(
|
|
|
1305
1447
|
float* simi,
|
|
1306
1448
|
idx_t* idxi,
|
|
1307
1449
|
size_t k) const {
|
|
1308
|
-
size_t nup = 0;
|
|
1309
|
-
|
|
1310
1450
|
if (!keep_max) {
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
if (!sel->is_member(id)) {
|
|
1315
|
-
codes += code_size;
|
|
1316
|
-
continue;
|
|
1317
|
-
}
|
|
1318
|
-
}
|
|
1319
|
-
|
|
1320
|
-
float dis = distance_to_code(codes);
|
|
1321
|
-
if (dis < simi[0]) {
|
|
1322
|
-
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
1323
|
-
maxheap_replace_top(k, simi, idxi, dis, id);
|
|
1324
|
-
nup++;
|
|
1325
|
-
}
|
|
1326
|
-
codes += code_size;
|
|
1327
|
-
}
|
|
1451
|
+
using C = CMax<float, idx_t>;
|
|
1452
|
+
HeapResultHandler<C, false> handler(k, simi, idxi);
|
|
1453
|
+
return scan_codes(list_size, codes, ids, handler);
|
|
1328
1454
|
} else {
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
if (!sel->is_member(id)) {
|
|
1333
|
-
codes += code_size;
|
|
1334
|
-
continue;
|
|
1335
|
-
}
|
|
1336
|
-
}
|
|
1337
|
-
|
|
1338
|
-
float dis = distance_to_code(codes);
|
|
1339
|
-
if (dis > simi[0]) {
|
|
1340
|
-
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
1341
|
-
minheap_replace_top(k, simi, idxi, dis, id);
|
|
1342
|
-
nup++;
|
|
1343
|
-
}
|
|
1344
|
-
codes += code_size;
|
|
1345
|
-
}
|
|
1455
|
+
using C = CMin<float, idx_t>;
|
|
1456
|
+
HeapResultHandler<C, false> handler(k, simi, idxi);
|
|
1457
|
+
return scan_codes(list_size, codes, ids, handler);
|
|
1346
1458
|
}
|
|
1347
|
-
return nup;
|
|
1348
1459
|
}
|
|
1349
1460
|
|
|
1350
1461
|
size_t InvertedListScanner::iterate_codes(
|
|
@@ -1356,11 +1467,19 @@ size_t InvertedListScanner::iterate_codes(
|
|
|
1356
1467
|
size_t nup = 0;
|
|
1357
1468
|
list_size = 0;
|
|
1358
1469
|
|
|
1470
|
+
const bool has_cb = it->has_search_callbacks_;
|
|
1471
|
+
|
|
1359
1472
|
if (!keep_max) {
|
|
1360
1473
|
for (; it->is_available(); it->next()) {
|
|
1361
1474
|
auto id_and_codes = it->get_id_and_codes();
|
|
1362
1475
|
float dis = distance_to_code(id_and_codes.second);
|
|
1476
|
+
if (has_cb) {
|
|
1477
|
+
it->on_distance_computed(id_and_codes.first, dis);
|
|
1478
|
+
}
|
|
1363
1479
|
if (dis < simi[0]) {
|
|
1480
|
+
if (has_cb) {
|
|
1481
|
+
it->on_heap_changed(id_and_codes.first, idxi[0]);
|
|
1482
|
+
}
|
|
1364
1483
|
maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
|
|
1365
1484
|
nup++;
|
|
1366
1485
|
}
|
|
@@ -1370,7 +1489,13 @@ size_t InvertedListScanner::iterate_codes(
|
|
|
1370
1489
|
for (; it->is_available(); it->next()) {
|
|
1371
1490
|
auto id_and_codes = it->get_id_and_codes();
|
|
1372
1491
|
float dis = distance_to_code(id_and_codes.second);
|
|
1492
|
+
if (has_cb) {
|
|
1493
|
+
it->on_distance_computed(id_and_codes.first, dis);
|
|
1494
|
+
}
|
|
1373
1495
|
if (dis > simi[0]) {
|
|
1496
|
+
if (has_cb) {
|
|
1497
|
+
it->on_heap_changed(id_and_codes.first, idxi[0]);
|
|
1498
|
+
}
|
|
1374
1499
|
minheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
|
|
1375
1500
|
nup++;
|
|
1376
1501
|
}
|
|
@@ -1386,16 +1511,18 @@ void InvertedListScanner::scan_codes_range(
|
|
|
1386
1511
|
const idx_t* ids,
|
|
1387
1512
|
float radius,
|
|
1388
1513
|
RangeQueryResult& res) const {
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
codes
|
|
1514
|
+
if (!keep_max) {
|
|
1515
|
+
using C = CMax<float, idx_t>;
|
|
1516
|
+
RangeResultHandler<C, false> handler(&res, radius);
|
|
1517
|
+
scan_codes(list_size, codes, ids, handler);
|
|
1518
|
+
res.stats.scan_cnt += handler.stats.scan_cnt;
|
|
1519
|
+
res.stats.nheap_updates += handler.stats.nheap_updates;
|
|
1520
|
+
} else {
|
|
1521
|
+
using C = CMin<float, idx_t>;
|
|
1522
|
+
RangeResultHandler<C, false> handler(&res, radius);
|
|
1523
|
+
scan_codes(list_size, codes, ids, handler);
|
|
1524
|
+
res.stats.scan_cnt += handler.stats.scan_cnt;
|
|
1525
|
+
res.stats.nheap_updates += handler.stats.nheap_updates;
|
|
1399
1526
|
}
|
|
1400
1527
|
}
|
|
1401
1528
|
|
|
@@ -1404,6 +1531,7 @@ void InvertedListScanner::iterate_codes_range(
|
|
|
1404
1531
|
float radius,
|
|
1405
1532
|
RangeQueryResult& res,
|
|
1406
1533
|
size_t& list_size) const {
|
|
1534
|
+
size_t nup = 0;
|
|
1407
1535
|
list_size = 0;
|
|
1408
1536
|
for (; it->is_available(); it->next()) {
|
|
1409
1537
|
auto id_and_codes = it->get_id_and_codes();
|
|
@@ -1413,9 +1541,11 @@ void InvertedListScanner::iterate_codes_range(
|
|
|
1413
1541
|
: dis > radius; // TODO templatize to remove this test
|
|
1414
1542
|
if (keep) {
|
|
1415
1543
|
res.add(dis, id_and_codes.first);
|
|
1544
|
+
nup++;
|
|
1416
1545
|
}
|
|
1417
1546
|
list_size++;
|
|
1418
1547
|
}
|
|
1548
|
+
res.stats.nheap_updates += nup;
|
|
1419
1549
|
}
|
|
1420
1550
|
|
|
1421
1551
|
} // namespace faiss
|