faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -13,22 +13,19 @@
|
|
|
13
13
|
#include <cstddef>
|
|
14
14
|
#include <cstdio>
|
|
15
15
|
#include <cstring>
|
|
16
|
+
#include <vector>
|
|
16
17
|
|
|
17
18
|
#include <omp.h>
|
|
18
19
|
|
|
19
|
-
#ifdef __AVX2__
|
|
20
|
-
#include <immintrin.h>
|
|
21
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
22
|
-
#include <arm_sve.h>
|
|
23
|
-
#endif
|
|
24
|
-
|
|
25
20
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
26
21
|
#include <faiss/impl/FaissAssert.h>
|
|
27
22
|
#include <faiss/impl/IDSelector.h>
|
|
28
23
|
#include <faiss/impl/ResultHandler.h>
|
|
29
24
|
|
|
25
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
30
26
|
#include <faiss/utils/distances_dispatch.h>
|
|
31
27
|
#include <faiss/utils/distances_fused/distances_fused.h>
|
|
28
|
+
#include <faiss/utils/simd_impl/exhaustive_L2sqr_blas_cmax.h>
|
|
32
29
|
|
|
33
30
|
#ifndef FINTEGER
|
|
34
31
|
#define FINTEGER long
|
|
@@ -172,6 +169,30 @@ int fvec_madd_and_argmin(
|
|
|
172
169
|
return fvec_madd_and_argmin_dispatch(n, a, bf, b, c);
|
|
173
170
|
}
|
|
174
171
|
|
|
172
|
+
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
|
|
173
|
+
fvec_sub_dispatch(d, a, b, c);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
void fvec_add(size_t d, const float* a, const float* b, float* c) {
|
|
177
|
+
fvec_add_dispatch(d, a, b, c);
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
void fvec_add(size_t d, const float* a, float b, float* c) {
|
|
181
|
+
fvec_add_scalar_dispatch(d, a, b, c);
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
void compute_PQ_dis_tables_dsub2(
|
|
185
|
+
size_t d,
|
|
186
|
+
size_t ksub,
|
|
187
|
+
const float* all_centroids,
|
|
188
|
+
size_t nx,
|
|
189
|
+
const float* x,
|
|
190
|
+
bool is_inner_product,
|
|
191
|
+
float* dis_tables) {
|
|
192
|
+
compute_PQ_dis_tables_dsub2_dispatch(
|
|
193
|
+
d, ksub, all_centroids, nx, x, is_inner_product, dis_tables);
|
|
194
|
+
}
|
|
195
|
+
|
|
175
196
|
/***************************************************************************
|
|
176
197
|
* Matrix/vector ops
|
|
177
198
|
***************************************************************************/
|
|
@@ -182,10 +203,12 @@ void fvec_norms_L2(
|
|
|
182
203
|
const float* __restrict x,
|
|
183
204
|
size_t d,
|
|
184
205
|
size_t nx) {
|
|
206
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
185
207
|
#pragma omp parallel for if (nx > 10000)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
208
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
209
|
+
nr[i] = sqrtf(fvec_norm_L2sqr<SL>(x + i * d, d));
|
|
210
|
+
}
|
|
211
|
+
});
|
|
189
212
|
}
|
|
190
213
|
|
|
191
214
|
void fvec_norms_L2sqr(
|
|
@@ -193,10 +216,12 @@ void fvec_norms_L2sqr(
|
|
|
193
216
|
const float* __restrict x,
|
|
194
217
|
size_t d,
|
|
195
218
|
size_t nx) {
|
|
219
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
196
220
|
#pragma omp parallel for if (nx > 10000)
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
221
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
222
|
+
nr[i] = fvec_norm_L2sqr<SL>(x + i * d, d);
|
|
223
|
+
}
|
|
224
|
+
});
|
|
200
225
|
}
|
|
201
226
|
|
|
202
227
|
// The following is a workaround to a problem
|
|
@@ -210,29 +235,35 @@ void fvec_norms_L2sqr(
|
|
|
210
235
|
// The workaround below is explicitly branching
|
|
211
236
|
// off to a codepath without omp.
|
|
212
237
|
|
|
213
|
-
#define FVEC_RENORM_L2_IMPL \
|
|
214
|
-
float* __restrict xi = x + i * d; \
|
|
215
|
-
\
|
|
216
|
-
float nr = fvec_norm_L2sqr_dispatch(xi, d); \
|
|
217
|
-
\
|
|
218
|
-
if (nr > 0) { \
|
|
219
|
-
size_t j; \
|
|
220
|
-
const float inv_nr = 1.0 / sqrtf(nr); \
|
|
221
|
-
for (j = 0; j < d; j++) \
|
|
222
|
-
xi[j] *= inv_nr; \
|
|
223
|
-
}
|
|
224
|
-
|
|
225
238
|
void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
239
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
240
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
241
|
+
float* __restrict xi = x + i * d;
|
|
242
|
+
float nr = fvec_norm_L2sqr<SL>(xi, d);
|
|
243
|
+
if (nr > 0) {
|
|
244
|
+
const float inv_nr = 1.0 / sqrtf(nr);
|
|
245
|
+
for (size_t j = 0; j < d; j++) {
|
|
246
|
+
xi[j] *= inv_nr;
|
|
247
|
+
}
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
});
|
|
229
251
|
}
|
|
230
252
|
|
|
231
253
|
void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
|
|
254
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
232
255
|
#pragma omp parallel for if (nx > 10000)
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
256
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
257
|
+
float* __restrict xi = x + i * d;
|
|
258
|
+
float nr = fvec_norm_L2sqr<SL>(xi, d);
|
|
259
|
+
if (nr > 0) {
|
|
260
|
+
const float inv_nr = 1.0 / sqrtf(nr);
|
|
261
|
+
for (size_t j = 0; j < d; j++) {
|
|
262
|
+
xi[j] *= inv_nr;
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
});
|
|
236
267
|
}
|
|
237
268
|
|
|
238
269
|
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
@@ -265,22 +296,24 @@ void exhaustive_inner_product_seq(
|
|
|
265
296
|
#pragma omp parallel num_threads(nt)
|
|
266
297
|
{
|
|
267
298
|
SingleResultHandler resi(res);
|
|
299
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
268
300
|
#pragma omp for
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
301
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
302
|
+
const float* x_i = x + i * d;
|
|
303
|
+
const float* y_j = y;
|
|
272
304
|
|
|
273
|
-
|
|
305
|
+
resi.begin(i);
|
|
274
306
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
307
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
308
|
+
if (!res.is_in_selection(j)) {
|
|
309
|
+
continue;
|
|
310
|
+
}
|
|
311
|
+
float ip = fvec_inner_product<SL>(x_i, y_j, d);
|
|
312
|
+
resi.add_result(ip, j);
|
|
278
313
|
}
|
|
279
|
-
|
|
280
|
-
resi.add_result(ip, j);
|
|
314
|
+
resi.end();
|
|
281
315
|
}
|
|
282
|
-
|
|
283
|
-
}
|
|
316
|
+
});
|
|
284
317
|
}
|
|
285
318
|
}
|
|
286
319
|
|
|
@@ -299,20 +332,22 @@ void exhaustive_L2sqr_seq(
|
|
|
299
332
|
#pragma omp parallel num_threads(nt)
|
|
300
333
|
{
|
|
301
334
|
SingleResultHandler resi(res);
|
|
335
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
302
336
|
#pragma omp for
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
337
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
338
|
+
const float* x_i = x + i * d;
|
|
339
|
+
const float* y_j = y;
|
|
340
|
+
resi.begin(i);
|
|
341
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
342
|
+
if (!res.is_in_selection(j)) {
|
|
343
|
+
continue;
|
|
344
|
+
}
|
|
345
|
+
float disij = fvec_L2sqr<SL>(x_i, y_j, d);
|
|
346
|
+
resi.add_result(disij, j);
|
|
310
347
|
}
|
|
311
|
-
|
|
312
|
-
resi.add_result(disij, j);
|
|
348
|
+
resi.end();
|
|
313
349
|
}
|
|
314
|
-
|
|
315
|
-
}
|
|
350
|
+
});
|
|
316
351
|
}
|
|
317
352
|
}
|
|
318
353
|
|
|
@@ -438,7 +473,7 @@ void exhaustive_L2sqr_blas_default_impl(
|
|
|
438
473
|
ip_block.get(),
|
|
439
474
|
&nyi);
|
|
440
475
|
}
|
|
441
|
-
for (
|
|
476
|
+
for (size_t i = i0; i < i1; i++) {
|
|
442
477
|
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
443
478
|
|
|
444
479
|
for (size_t j = j0; j < j1; j++) {
|
|
@@ -474,396 +509,12 @@ void exhaustive_L2sqr_blas(
|
|
|
474
509
|
size_t ny,
|
|
475
510
|
BlockResultHandler& res,
|
|
476
511
|
const float* y_norms = nullptr) {
|
|
477
|
-
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
|
|
512
|
+
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
|
|
478
513
|
}
|
|
479
514
|
|
|
480
|
-
|
|
481
|
-
void exhaustive_L2sqr_blas_cmax_avx2(
|
|
482
|
-
const float* x,
|
|
483
|
-
const float* y,
|
|
484
|
-
size_t d,
|
|
485
|
-
size_t nx,
|
|
486
|
-
size_t ny,
|
|
487
|
-
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
488
|
-
const float* y_norms) {
|
|
489
|
-
// BLAS does not like empty matrices
|
|
490
|
-
if (nx == 0 || ny == 0) {
|
|
491
|
-
return;
|
|
492
|
-
}
|
|
493
|
-
|
|
494
|
-
/* block sizes */
|
|
495
|
-
const size_t bs_x = distance_compute_blas_query_bs;
|
|
496
|
-
const size_t bs_y = distance_compute_blas_database_bs;
|
|
497
|
-
// const size_t bs_x = 16, bs_y = 16;
|
|
498
|
-
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
499
|
-
std::unique_ptr<float[]> x_norms(new float[nx]);
|
|
500
|
-
std::unique_ptr<float[]> del2;
|
|
501
|
-
|
|
502
|
-
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
|
503
|
-
|
|
504
|
-
if (!y_norms) {
|
|
505
|
-
float* y_norms2 = new float[ny];
|
|
506
|
-
del2.reset(y_norms2);
|
|
507
|
-
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
|
508
|
-
y_norms = y_norms2;
|
|
509
|
-
}
|
|
510
|
-
|
|
511
|
-
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
512
|
-
size_t i1 = i0 + bs_x;
|
|
513
|
-
if (i1 > nx) {
|
|
514
|
-
i1 = nx;
|
|
515
|
-
}
|
|
516
|
-
|
|
517
|
-
res.begin_multiple(i0, i1);
|
|
518
|
-
|
|
519
|
-
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
520
|
-
size_t j1 = j0 + bs_y;
|
|
521
|
-
if (j1 > ny) {
|
|
522
|
-
j1 = ny;
|
|
523
|
-
}
|
|
524
|
-
/* compute the actual dot products */
|
|
525
|
-
{
|
|
526
|
-
float one = 1, zero = 0;
|
|
527
|
-
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
528
|
-
sgemm_("Transpose",
|
|
529
|
-
"Not transpose",
|
|
530
|
-
&nyi,
|
|
531
|
-
&nxi,
|
|
532
|
-
&di,
|
|
533
|
-
&one,
|
|
534
|
-
y + j0 * d,
|
|
535
|
-
&di,
|
|
536
|
-
x + i0 * d,
|
|
537
|
-
&di,
|
|
538
|
-
&zero,
|
|
539
|
-
ip_block.get(),
|
|
540
|
-
&nyi);
|
|
541
|
-
}
|
|
542
|
-
for (int64_t i = i0; i < i1; i++) {
|
|
543
|
-
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
544
|
-
|
|
545
|
-
_mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
|
|
546
|
-
_mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
|
|
547
|
-
|
|
548
|
-
// constant
|
|
549
|
-
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
|
550
|
-
|
|
551
|
-
// Track 8 min distances + 8 min indices.
|
|
552
|
-
// All the distances tracked do not take x_norms[i]
|
|
553
|
-
// into account in order to get rid of extra
|
|
554
|
-
// _mm256_add_ps(x_norms[i], ...) instructions
|
|
555
|
-
// is distance computations.
|
|
556
|
-
__m256 min_distances =
|
|
557
|
-
_mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
|
|
558
|
-
|
|
559
|
-
// these indices are local and are relative to j0.
|
|
560
|
-
// so, value 0 means j0.
|
|
561
|
-
__m256i min_indices = _mm256_set1_epi32(0);
|
|
562
|
-
|
|
563
|
-
__m256i current_indices =
|
|
564
|
-
_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
565
|
-
const __m256i indices_delta = _mm256_set1_epi32(8);
|
|
566
|
-
|
|
567
|
-
// current j index
|
|
568
|
-
size_t idx_j = 0;
|
|
569
|
-
size_t count = j1 - j0;
|
|
570
|
-
|
|
571
|
-
// process 16 elements per loop
|
|
572
|
-
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
|
573
|
-
_mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
|
|
574
|
-
_mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
|
|
575
|
-
|
|
576
|
-
// load values for norms
|
|
577
|
-
const __m256 y_norm_0 =
|
|
578
|
-
_mm256_loadu_ps(y_norms + idx_j + j0 + 0);
|
|
579
|
-
const __m256 y_norm_1 =
|
|
580
|
-
_mm256_loadu_ps(y_norms + idx_j + j0 + 8);
|
|
581
|
-
|
|
582
|
-
// load values for dot products
|
|
583
|
-
const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
|
|
584
|
-
const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
|
|
585
|
-
|
|
586
|
-
// compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
|
|
587
|
-
// x_norm[i] was dropped off because it is a constant for a
|
|
588
|
-
// given i. We'll deal with it later.
|
|
589
|
-
__m256 distances_0 =
|
|
590
|
-
_mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
|
|
591
|
-
__m256 distances_1 =
|
|
592
|
-
_mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
|
|
593
|
-
|
|
594
|
-
// compare the new distances to the min distances
|
|
595
|
-
// for each of the first group of 8 AVX2 components.
|
|
596
|
-
const __m256 comparison_0 = _mm256_cmp_ps(
|
|
597
|
-
min_distances, distances_0, _CMP_LE_OS);
|
|
598
|
-
|
|
599
|
-
// update min distances and indices with closest vectors if
|
|
600
|
-
// needed.
|
|
601
|
-
min_distances = _mm256_blendv_ps(
|
|
602
|
-
distances_0, min_distances, comparison_0);
|
|
603
|
-
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
604
|
-
_mm256_castsi256_ps(current_indices),
|
|
605
|
-
_mm256_castsi256_ps(min_indices),
|
|
606
|
-
comparison_0));
|
|
607
|
-
current_indices =
|
|
608
|
-
_mm256_add_epi32(current_indices, indices_delta);
|
|
609
|
-
|
|
610
|
-
// compare the new distances to the min distances
|
|
611
|
-
// for each of the second group of 8 AVX2 components.
|
|
612
|
-
const __m256 comparison_1 = _mm256_cmp_ps(
|
|
613
|
-
min_distances, distances_1, _CMP_LE_OS);
|
|
614
|
-
|
|
615
|
-
// update min distances and indices with closest vectors if
|
|
616
|
-
// needed.
|
|
617
|
-
min_distances = _mm256_blendv_ps(
|
|
618
|
-
distances_1, min_distances, comparison_1);
|
|
619
|
-
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
620
|
-
_mm256_castsi256_ps(current_indices),
|
|
621
|
-
_mm256_castsi256_ps(min_indices),
|
|
622
|
-
comparison_1));
|
|
623
|
-
current_indices =
|
|
624
|
-
_mm256_add_epi32(current_indices, indices_delta);
|
|
625
|
-
}
|
|
626
|
-
|
|
627
|
-
// dump values and find the minimum distance / minimum index
|
|
628
|
-
float min_distances_scalar[8];
|
|
629
|
-
uint32_t min_indices_scalar[8];
|
|
630
|
-
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
631
|
-
_mm256_storeu_si256(
|
|
632
|
-
(__m256i*)(min_indices_scalar), min_indices);
|
|
633
|
-
|
|
634
|
-
float current_min_distance = res.dis_tab[i];
|
|
635
|
-
uint32_t current_min_index = res.ids_tab[i];
|
|
636
|
-
|
|
637
|
-
// This unusual comparison is needed to maintain the behavior
|
|
638
|
-
// of the original implementation: if two indices are
|
|
639
|
-
// represented with equal distance values, then
|
|
640
|
-
// the index with the min value is returned.
|
|
641
|
-
for (size_t jv = 0; jv < 8; jv++) {
|
|
642
|
-
// add missing x_norms[i]
|
|
643
|
-
float distance_candidate =
|
|
644
|
-
min_distances_scalar[jv] + x_norms[i];
|
|
645
|
-
|
|
646
|
-
// negative values can occur for identical vectors
|
|
647
|
-
// due to roundoff errors.
|
|
648
|
-
if (distance_candidate < 0) {
|
|
649
|
-
distance_candidate = 0;
|
|
650
|
-
}
|
|
651
|
-
|
|
652
|
-
int64_t index_candidate = min_indices_scalar[jv] + j0;
|
|
653
|
-
|
|
654
|
-
if (current_min_distance > distance_candidate) {
|
|
655
|
-
current_min_distance = distance_candidate;
|
|
656
|
-
current_min_index = index_candidate;
|
|
657
|
-
} else if (
|
|
658
|
-
current_min_distance == distance_candidate &&
|
|
659
|
-
current_min_index > index_candidate) {
|
|
660
|
-
current_min_index = index_candidate;
|
|
661
|
-
}
|
|
662
|
-
}
|
|
663
|
-
|
|
664
|
-
// process leftovers
|
|
665
|
-
for (; idx_j < count; idx_j++, ip_line++) {
|
|
666
|
-
float ip = *ip_line;
|
|
667
|
-
float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
|
|
668
|
-
// negative values can occur for identical vectors
|
|
669
|
-
// due to roundoff errors.
|
|
670
|
-
if (dis < 0) {
|
|
671
|
-
dis = 0;
|
|
672
|
-
}
|
|
673
|
-
|
|
674
|
-
if (current_min_distance > dis) {
|
|
675
|
-
current_min_distance = dis;
|
|
676
|
-
current_min_index = idx_j + j0;
|
|
677
|
-
}
|
|
678
|
-
}
|
|
679
|
-
|
|
680
|
-
//
|
|
681
|
-
res.add_result(i, current_min_distance, current_min_index);
|
|
682
|
-
}
|
|
683
|
-
}
|
|
684
|
-
// Does nothing for SingleBestResultHandler, but
|
|
685
|
-
// keeping the call for the consistency.
|
|
686
|
-
res.end_multiple();
|
|
687
|
-
InterruptCallback::check();
|
|
688
|
-
}
|
|
689
|
-
}
|
|
690
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
691
|
-
void exhaustive_L2sqr_blas_cmax_sve(
|
|
692
|
-
const float* x,
|
|
693
|
-
const float* y,
|
|
694
|
-
size_t d,
|
|
695
|
-
size_t nx,
|
|
696
|
-
size_t ny,
|
|
697
|
-
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
698
|
-
const float* y_norms) {
|
|
699
|
-
// BLAS does not like empty matrices
|
|
700
|
-
if (nx == 0 || ny == 0)
|
|
701
|
-
return;
|
|
702
|
-
|
|
703
|
-
/* block sizes */
|
|
704
|
-
const size_t bs_x = distance_compute_blas_query_bs;
|
|
705
|
-
const size_t bs_y = distance_compute_blas_database_bs;
|
|
706
|
-
// const size_t bs_x = 16, bs_y = 16;
|
|
707
|
-
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
708
|
-
std::unique_ptr<float[]> x_norms(new float[nx]);
|
|
709
|
-
std::unique_ptr<float[]> del2;
|
|
710
|
-
|
|
711
|
-
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
|
712
|
-
|
|
713
|
-
const size_t lanes = svcntw();
|
|
714
|
-
|
|
715
|
-
if (!y_norms) {
|
|
716
|
-
float* y_norms2 = new float[ny];
|
|
717
|
-
del2.reset(y_norms2);
|
|
718
|
-
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
|
719
|
-
y_norms = y_norms2;
|
|
720
|
-
}
|
|
721
|
-
|
|
722
|
-
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
723
|
-
size_t i1 = i0 + bs_x;
|
|
724
|
-
if (i1 > nx)
|
|
725
|
-
i1 = nx;
|
|
726
|
-
|
|
727
|
-
res.begin_multiple(i0, i1);
|
|
728
|
-
|
|
729
|
-
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
730
|
-
size_t j1 = j0 + bs_y;
|
|
731
|
-
if (j1 > ny)
|
|
732
|
-
j1 = ny;
|
|
733
|
-
/* compute the actual dot products */
|
|
734
|
-
{
|
|
735
|
-
float one = 1, zero = 0;
|
|
736
|
-
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
737
|
-
sgemm_("Transpose",
|
|
738
|
-
"Not transpose",
|
|
739
|
-
&nyi,
|
|
740
|
-
&nxi,
|
|
741
|
-
&di,
|
|
742
|
-
&one,
|
|
743
|
-
y + j0 * d,
|
|
744
|
-
&di,
|
|
745
|
-
x + i0 * d,
|
|
746
|
-
&di,
|
|
747
|
-
&zero,
|
|
748
|
-
ip_block.get(),
|
|
749
|
-
&nyi);
|
|
750
|
-
}
|
|
751
|
-
for (int64_t i = i0; i < i1; i++) {
|
|
752
|
-
const size_t count = j1 - j0;
|
|
753
|
-
float* ip_line = ip_block.get() + (i - i0) * count;
|
|
754
|
-
|
|
755
|
-
svprfw(svwhilelt_b32_u64(0, count), ip_line, SV_PLDL1KEEP);
|
|
756
|
-
svprfw(svwhilelt_b32_u64(lanes, count),
|
|
757
|
-
ip_line + lanes,
|
|
758
|
-
SV_PLDL1KEEP);
|
|
759
|
-
|
|
760
|
-
// Track lanes min distances + lanes min indices.
|
|
761
|
-
// All the distances tracked do not take x_norms[i]
|
|
762
|
-
// into account in order to get rid of extra
|
|
763
|
-
// vaddq_f32(x_norms[i], ...) instructions
|
|
764
|
-
// is distance computations.
|
|
765
|
-
auto min_distances = svdup_n_f32(res.dis_tab[i] - x_norms[i]);
|
|
766
|
-
|
|
767
|
-
// these indices are local and are relative to j0.
|
|
768
|
-
// so, value 0 means j0.
|
|
769
|
-
auto min_indices = svdup_n_u32(0u);
|
|
770
|
-
|
|
771
|
-
auto current_indices = svindex_u32(0u, 1u);
|
|
772
|
-
|
|
773
|
-
// process lanes * 2 elements per loop
|
|
774
|
-
for (size_t idx_j = 0; idx_j < count;
|
|
775
|
-
idx_j += lanes * 2, ip_line += lanes * 2) {
|
|
776
|
-
svprfw(svwhilelt_b32_u64(idx_j + lanes * 2, count),
|
|
777
|
-
ip_line + lanes * 2,
|
|
778
|
-
SV_PLDL1KEEP);
|
|
779
|
-
svprfw(svwhilelt_b32_u64(idx_j + lanes * 3, count),
|
|
780
|
-
ip_line + lanes * 3,
|
|
781
|
-
SV_PLDL1KEEP);
|
|
782
|
-
|
|
783
|
-
// mask
|
|
784
|
-
const auto mask_0 = svwhilelt_b32_u64(idx_j, count);
|
|
785
|
-
const auto mask_1 = svwhilelt_b32_u64(idx_j + lanes, count);
|
|
786
|
-
|
|
787
|
-
// load values for norms
|
|
788
|
-
const auto y_norm_0 =
|
|
789
|
-
svld1_f32(mask_0, y_norms + idx_j + j0 + 0);
|
|
790
|
-
const auto y_norm_1 =
|
|
791
|
-
svld1_f32(mask_1, y_norms + idx_j + j0 + lanes);
|
|
792
|
-
|
|
793
|
-
// load values for dot products
|
|
794
|
-
const auto ip_0 = svld1_f32(mask_0, ip_line + 0);
|
|
795
|
-
const auto ip_1 = svld1_f32(mask_1, ip_line + lanes);
|
|
796
|
-
|
|
797
|
-
// compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
|
|
798
|
-
// x_norm[i] was dropped off because it is a constant for a
|
|
799
|
-
// given i. We'll deal with it later.
|
|
800
|
-
const auto distances_0 =
|
|
801
|
-
svmla_n_f32_z(mask_0, y_norm_0, ip_0, -2.f);
|
|
802
|
-
const auto distances_1 =
|
|
803
|
-
svmla_n_f32_z(mask_1, y_norm_1, ip_1, -2.f);
|
|
804
|
-
|
|
805
|
-
// compare the new distances to the min distances
|
|
806
|
-
// for each of the first group of 4 ARM SIMD components.
|
|
807
|
-
auto comparison =
|
|
808
|
-
svcmpgt_f32(mask_0, min_distances, distances_0);
|
|
809
|
-
|
|
810
|
-
// update min distances and indices with closest vectors if
|
|
811
|
-
// needed.
|
|
812
|
-
min_distances =
|
|
813
|
-
svsel_f32(comparison, distances_0, min_distances);
|
|
814
|
-
min_indices =
|
|
815
|
-
svsel_u32(comparison, current_indices, min_indices);
|
|
816
|
-
current_indices = svadd_n_u32_x(
|
|
817
|
-
mask_0,
|
|
818
|
-
current_indices,
|
|
819
|
-
static_cast<uint32_t>(lanes));
|
|
820
|
-
|
|
821
|
-
// compare the new distances to the min distances
|
|
822
|
-
// for each of the second group of 4 ARM SIMD components.
|
|
823
|
-
comparison =
|
|
824
|
-
svcmpgt_f32(mask_1, min_distances, distances_1);
|
|
825
|
-
|
|
826
|
-
// update min distances and indices with closest vectors if
|
|
827
|
-
// needed.
|
|
828
|
-
min_distances =
|
|
829
|
-
svsel_f32(comparison, distances_1, min_distances);
|
|
830
|
-
min_indices =
|
|
831
|
-
svsel_u32(comparison, current_indices, min_indices);
|
|
832
|
-
current_indices = svadd_n_u32_x(
|
|
833
|
-
mask_1,
|
|
834
|
-
current_indices,
|
|
835
|
-
static_cast<uint32_t>(lanes));
|
|
836
|
-
}
|
|
515
|
+
} // anonymous namespace
|
|
837
516
|
|
|
838
|
-
|
|
839
|
-
// negative values can occur for identical vectors
|
|
840
|
-
// due to roundoff errors.
|
|
841
|
-
auto mask = svwhilelt_b32_u64(0, count);
|
|
842
|
-
min_distances = svadd_n_f32_z(
|
|
843
|
-
svcmpge_n_f32(mask, min_distances, -x_norms[i]),
|
|
844
|
-
min_distances,
|
|
845
|
-
x_norms[i]);
|
|
846
|
-
min_indices = svadd_n_u32_x(
|
|
847
|
-
mask, min_indices, static_cast<uint32_t>(j0));
|
|
848
|
-
mask = svcmple_n_f32(mask, min_distances, res.dis_tab[i]);
|
|
849
|
-
if (svcntp_b32(svptrue_b32(), mask) == 0)
|
|
850
|
-
res.add_result(i, res.dis_tab[i], res.ids_tab[i]);
|
|
851
|
-
else {
|
|
852
|
-
const auto min_distance = svminv_f32(mask, min_distances);
|
|
853
|
-
const auto min_index = svminv_u32(
|
|
854
|
-
svcmpeq_n_f32(mask, min_distances, min_distance),
|
|
855
|
-
min_indices);
|
|
856
|
-
res.add_result(i, min_distance, min_index);
|
|
857
|
-
}
|
|
858
|
-
}
|
|
859
|
-
}
|
|
860
|
-
// Does nothing for SingleBestResultHandler, but
|
|
861
|
-
// keeping the call for the consistency.
|
|
862
|
-
res.end_multiple();
|
|
863
|
-
InterruptCallback::check();
|
|
864
|
-
}
|
|
865
|
-
}
|
|
866
|
-
#endif
|
|
517
|
+
namespace {
|
|
867
518
|
|
|
868
519
|
// an override if only a single closest point is needed
|
|
869
520
|
template <>
|
|
@@ -875,43 +526,20 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
|
875
526
|
size_t ny,
|
|
876
527
|
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
877
528
|
const float* y_norms) {
|
|
878
|
-
#if defined(__AVX2__)
|
|
879
529
|
// use a faster fused kernel if available
|
|
880
530
|
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
|
|
881
|
-
// the kernel is available and it is complete, we're done.
|
|
882
531
|
return;
|
|
883
532
|
}
|
|
884
533
|
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
}
|
|
894
|
-
|
|
895
|
-
// run the specialized SVE implementation
|
|
896
|
-
exhaustive_L2sqr_blas_cmax_sve(x, y, d, nx, ny, res, y_norms);
|
|
897
|
-
|
|
898
|
-
#elif defined(__aarch64__)
|
|
899
|
-
// use a faster fused kernel if available
|
|
900
|
-
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
|
|
901
|
-
// the kernel is available and it is complete, we're done.
|
|
902
|
-
return;
|
|
903
|
-
}
|
|
904
|
-
|
|
905
|
-
// run the default implementation
|
|
906
|
-
exhaustive_L2sqr_blas_default_impl<
|
|
907
|
-
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
908
|
-
x, y, d, nx, ny, res, y_norms);
|
|
909
|
-
#else
|
|
910
|
-
// run the default implementation
|
|
911
|
-
exhaustive_L2sqr_blas_default_impl<
|
|
912
|
-
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
913
|
-
x, y, d, nx, ny, res, y_norms);
|
|
914
|
-
#endif
|
|
534
|
+
with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A2>([&]<SIMDLevel SL>() {
|
|
535
|
+
if constexpr (SL == SIMDLevel::AVX2 || SL == SIMDLevel::ARM_SVE) {
|
|
536
|
+
exhaustive_L2sqr_blas_cmax<SL>(x, y, d, nx, ny, res, y_norms);
|
|
537
|
+
} else {
|
|
538
|
+
exhaustive_L2sqr_blas_default_impl<
|
|
539
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
540
|
+
x, y, d, nx, ny, res, y_norms);
|
|
541
|
+
}
|
|
542
|
+
});
|
|
915
543
|
}
|
|
916
544
|
|
|
917
545
|
struct Run_search_inner_product {
|
|
@@ -923,7 +551,8 @@ struct Run_search_inner_product {
|
|
|
923
551
|
size_t d,
|
|
924
552
|
size_t nx,
|
|
925
553
|
size_t ny) {
|
|
926
|
-
if (res.sel ||
|
|
554
|
+
if (res.sel ||
|
|
555
|
+
nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
|
|
927
556
|
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
928
557
|
} else {
|
|
929
558
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
@@ -941,7 +570,8 @@ struct Run_search_L2sqr {
|
|
|
941
570
|
size_t nx,
|
|
942
571
|
size_t ny,
|
|
943
572
|
const float* y_norm2) {
|
|
944
|
-
if (res.sel ||
|
|
573
|
+
if (res.sel ||
|
|
574
|
+
nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
|
|
945
575
|
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
946
576
|
} else {
|
|
947
577
|
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
@@ -955,11 +585,174 @@ struct Run_search_L2sqr {
|
|
|
955
585
|
* KNN driver functions
|
|
956
586
|
*******************************************************/
|
|
957
587
|
|
|
958
|
-
int distance_compute_blas_threshold =
|
|
588
|
+
int distance_compute_blas_threshold = 128000;
|
|
959
589
|
int distance_compute_blas_query_bs = 4096;
|
|
960
590
|
int distance_compute_blas_database_bs = 1024;
|
|
961
591
|
int distance_compute_min_k_reservoir = 100;
|
|
962
592
|
|
|
593
|
+
// Database-parallel KNN: parallelizes over database segments instead of
|
|
594
|
+
// queries, for the case where nx < nthreads and the database is large.
|
|
595
|
+
static constexpr size_t kDbParallelMinVectors = 10000;
|
|
596
|
+
|
|
597
|
+
template <class C>
|
|
598
|
+
static void knn_db_parallel_impl(
|
|
599
|
+
const float* x,
|
|
600
|
+
const float* y,
|
|
601
|
+
size_t d,
|
|
602
|
+
size_t nx,
|
|
603
|
+
size_t ny,
|
|
604
|
+
size_t k,
|
|
605
|
+
float* vals,
|
|
606
|
+
int64_t* ids,
|
|
607
|
+
const float* y_norms) {
|
|
608
|
+
using T = typename C::T;
|
|
609
|
+
using TI = typename C::TI;
|
|
610
|
+
|
|
611
|
+
int nt = omp_get_max_threads();
|
|
612
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
|
613
|
+
|
|
614
|
+
// Per-thread result heaps: nt threads x nx queries x k results
|
|
615
|
+
std::vector<T> all_dis(static_cast<size_t>(nt) * nx * k);
|
|
616
|
+
std::vector<TI> all_ids(static_cast<size_t>(nt) * nx * k);
|
|
617
|
+
|
|
618
|
+
std::unique_ptr<float[]> x_norms_storage;
|
|
619
|
+
std::unique_ptr<float[]> y_norms_storage;
|
|
620
|
+
const float* x_norms = nullptr;
|
|
621
|
+
// C::is_max corresponds to L2 (CMax), not IP (CMin)
|
|
622
|
+
if constexpr (C::is_max) {
|
|
623
|
+
x_norms_storage.reset(new float[nx]);
|
|
624
|
+
fvec_norms_L2sqr(x_norms_storage.get(), x, d, nx);
|
|
625
|
+
x_norms = x_norms_storage.get();
|
|
626
|
+
|
|
627
|
+
if (!y_norms) {
|
|
628
|
+
y_norms_storage.reset(new float[ny]);
|
|
629
|
+
y_norms = y_norms_storage.get();
|
|
630
|
+
}
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
#pragma omp parallel num_threads(nt)
|
|
634
|
+
{
|
|
635
|
+
int tid = omp_get_thread_num();
|
|
636
|
+
size_t j_begin = static_cast<size_t>(tid) * ny / nt;
|
|
637
|
+
size_t j_end = static_cast<size_t>(tid + 1) * ny / nt;
|
|
638
|
+
size_t local_ny = j_end - j_begin;
|
|
639
|
+
|
|
640
|
+
// Compute y_norms for this thread's segment (cache locality)
|
|
641
|
+
if constexpr (C::is_max) {
|
|
642
|
+
if (y_norms_storage && local_ny > 0) {
|
|
643
|
+
fvec_norms_L2sqr(
|
|
644
|
+
y_norms_storage.get() + j_begin,
|
|
645
|
+
y + j_begin * d,
|
|
646
|
+
d,
|
|
647
|
+
local_ny);
|
|
648
|
+
}
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
T* my_dis = all_dis.data() + tid * nx * k;
|
|
652
|
+
TI* my_ids = all_ids.data() + tid * nx * k;
|
|
653
|
+
|
|
654
|
+
// Each thread initializes its own heaps
|
|
655
|
+
for (size_t i = 0; i < nx; i++) {
|
|
656
|
+
heap_heapify<C>(k, my_dis + i * k, my_ids + i * k);
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
if (local_ny > 0) {
|
|
660
|
+
size_t max_block = std::min(bs_y, local_ny);
|
|
661
|
+
std::unique_ptr<float[]> ip_block(new float[nx * max_block]);
|
|
662
|
+
|
|
663
|
+
for (size_t jj0 = 0; jj0 < local_ny; jj0 += bs_y) {
|
|
664
|
+
size_t jj1 = std::min(jj0 + bs_y, local_ny);
|
|
665
|
+
size_t block_ny = jj1 - jj0;
|
|
666
|
+
|
|
667
|
+
{
|
|
668
|
+
float one = 1, zero = 0;
|
|
669
|
+
FINTEGER nyi = static_cast<FINTEGER>(block_ny);
|
|
670
|
+
FINTEGER nxi = static_cast<FINTEGER>(nx);
|
|
671
|
+
FINTEGER di = static_cast<FINTEGER>(d);
|
|
672
|
+
sgemm_("Transpose",
|
|
673
|
+
"Not transpose",
|
|
674
|
+
&nyi,
|
|
675
|
+
&nxi,
|
|
676
|
+
&di,
|
|
677
|
+
&one,
|
|
678
|
+
y + (j_begin + jj0) * d,
|
|
679
|
+
&di,
|
|
680
|
+
x,
|
|
681
|
+
&di,
|
|
682
|
+
&zero,
|
|
683
|
+
ip_block.get(),
|
|
684
|
+
&nyi);
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
for (size_t i = 0; i < nx; i++) {
|
|
688
|
+
T* heap_dis = my_dis + i * k;
|
|
689
|
+
TI* heap_ids = my_ids + i * k;
|
|
690
|
+
const float* ip_line = ip_block.get() + i * block_ny;
|
|
691
|
+
T thresh = heap_dis[0];
|
|
692
|
+
|
|
693
|
+
for (size_t jj = 0; jj < block_ny; jj++) {
|
|
694
|
+
size_t global_j = j_begin + jj0 + jj;
|
|
695
|
+
float ip = ip_line[jj];
|
|
696
|
+
T dis;
|
|
697
|
+
|
|
698
|
+
if constexpr (C::is_max) {
|
|
699
|
+
dis = x_norms[i] + y_norms[global_j] - 2 * ip;
|
|
700
|
+
if (dis < 0) {
|
|
701
|
+
dis = 0;
|
|
702
|
+
}
|
|
703
|
+
} else {
|
|
704
|
+
dis = ip;
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
if (C::cmp(thresh, dis)) {
|
|
708
|
+
heap_replace_top<C>(
|
|
709
|
+
k, heap_dis, heap_ids, dis, global_j);
|
|
710
|
+
thresh = heap_dis[0];
|
|
711
|
+
}
|
|
712
|
+
}
|
|
713
|
+
}
|
|
714
|
+
}
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
|
|
718
|
+
// Merge per-thread heaps into output, parallelized over queries
|
|
719
|
+
#pragma omp parallel for
|
|
720
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
721
|
+
heap_heapify<C>(k, vals + i * k, ids + i * k);
|
|
722
|
+
|
|
723
|
+
for (int t = 0; t < nt; t++) {
|
|
724
|
+
T* t_dis = all_dis.data() + (t * nx + i) * k;
|
|
725
|
+
TI* t_ids = all_ids.data() + (t * nx + i) * k;
|
|
726
|
+
T* out_dis = vals + i * k;
|
|
727
|
+
TI* out_ids = ids + i * k;
|
|
728
|
+
|
|
729
|
+
for (size_t j = 0; j < k; j++) {
|
|
730
|
+
if (t_ids[j] >= 0 && C::cmp(out_dis[0], t_dis[j])) {
|
|
731
|
+
heap_replace_top<C>(
|
|
732
|
+
k, out_dis, out_ids, t_dis[j], t_ids[j]);
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
heap_reorder<C>(k, vals + i * k, ids + i * k);
|
|
738
|
+
}
|
|
739
|
+
}
|
|
740
|
+
|
|
741
|
+
static bool should_use_db_parallel(
|
|
742
|
+
size_t nx,
|
|
743
|
+
size_t ny,
|
|
744
|
+
const IDSelector* sel) {
|
|
745
|
+
if (sel) {
|
|
746
|
+
return false;
|
|
747
|
+
}
|
|
748
|
+
int nt = omp_get_max_threads();
|
|
749
|
+
size_t min_ny = std::max(
|
|
750
|
+
kDbParallelMinVectors,
|
|
751
|
+
static_cast<size_t>(nt) *
|
|
752
|
+
static_cast<size_t>(distance_compute_blas_database_bs));
|
|
753
|
+
return nt > 1 && nx < static_cast<size_t>(nt) && ny >= min_ny;
|
|
754
|
+
}
|
|
755
|
+
|
|
963
756
|
void knn_inner_product(
|
|
964
757
|
const float* x,
|
|
965
758
|
const float* y,
|
|
@@ -984,9 +777,26 @@ void knn_inner_product(
|
|
|
984
777
|
return;
|
|
985
778
|
}
|
|
986
779
|
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
780
|
+
if (should_use_db_parallel(nx, ny, sel)) {
|
|
781
|
+
knn_db_parallel_impl<CMin<float, int64_t>>(
|
|
782
|
+
x, y, d, nx, ny, k, vals, ids, nullptr);
|
|
783
|
+
} else {
|
|
784
|
+
Run_search_inner_product r;
|
|
785
|
+
// @lint-ignore CLANGTIDY facebook-hte-NullableDereference
|
|
786
|
+
dispatch_knn_ResultHandler(
|
|
787
|
+
nx,
|
|
788
|
+
vals,
|
|
789
|
+
ids,
|
|
790
|
+
k,
|
|
791
|
+
METRIC_INNER_PRODUCT,
|
|
792
|
+
sel,
|
|
793
|
+
r,
|
|
794
|
+
x,
|
|
795
|
+
y,
|
|
796
|
+
d,
|
|
797
|
+
nx,
|
|
798
|
+
ny);
|
|
799
|
+
}
|
|
990
800
|
|
|
991
801
|
if (imin != 0) {
|
|
992
802
|
for (size_t i = 0; i < nx * k; i++) {
|
|
@@ -1033,9 +843,15 @@ void knn_L2sqr(
|
|
|
1033
843
|
return;
|
|
1034
844
|
}
|
|
1035
845
|
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
846
|
+
if (should_use_db_parallel(nx, ny, sel)) {
|
|
847
|
+
knn_db_parallel_impl<CMax<float, int64_t>>(
|
|
848
|
+
x, y, d, nx, ny, k, vals, ids, y_norm2);
|
|
849
|
+
} else {
|
|
850
|
+
Run_search_L2sqr r;
|
|
851
|
+
// @lint-ignore CLANGTIDY facebook-hte-NullableDereference
|
|
852
|
+
dispatch_knn_ResultHandler(
|
|
853
|
+
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
|
|
854
|
+
}
|
|
1039
855
|
|
|
1040
856
|
if (imin != 0) {
|
|
1041
857
|
for (size_t i = 0; i < nx * k; i++) {
|
|
@@ -1106,19 +922,21 @@ void fvec_inner_products_by_idx(
|
|
|
1106
922
|
size_t d,
|
|
1107
923
|
size_t nx,
|
|
1108
924
|
size_t ny) {
|
|
925
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1109
926
|
#pragma omp parallel for
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
927
|
+
for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
|
|
928
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
|
929
|
+
const float* xj = x + j * d;
|
|
930
|
+
float* __restrict ipj = ip + j * ny;
|
|
931
|
+
for (size_t i = 0; i < ny; i++) {
|
|
932
|
+
if (idsj[i] < 0) {
|
|
933
|
+
ipj[i] = -INFINITY;
|
|
934
|
+
} else {
|
|
935
|
+
ipj[i] = fvec_inner_product<SL>(xj, y + d * idsj[i], d);
|
|
936
|
+
}
|
|
1119
937
|
}
|
|
1120
938
|
}
|
|
1121
|
-
}
|
|
939
|
+
});
|
|
1122
940
|
}
|
|
1123
941
|
|
|
1124
942
|
/* compute the inner product between x and a subset y of ny vectors,
|
|
@@ -1131,19 +949,21 @@ void fvec_L2sqr_by_idx(
|
|
|
1131
949
|
size_t d,
|
|
1132
950
|
size_t nx,
|
|
1133
951
|
size_t ny) {
|
|
952
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1134
953
|
#pragma omp parallel for
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
954
|
+
for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
|
|
955
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
|
956
|
+
const float* xj = x + j * d;
|
|
957
|
+
float* __restrict disj = dis + j * ny;
|
|
958
|
+
for (size_t i = 0; i < ny; i++) {
|
|
959
|
+
if (idsj[i] < 0) {
|
|
960
|
+
disj[i] = INFINITY;
|
|
961
|
+
} else {
|
|
962
|
+
disj[i] = fvec_L2sqr<SL>(xj, y + d * idsj[i], d);
|
|
963
|
+
}
|
|
1144
964
|
}
|
|
1145
965
|
}
|
|
1146
|
-
}
|
|
966
|
+
});
|
|
1147
967
|
}
|
|
1148
968
|
|
|
1149
969
|
void pairwise_indexed_L2sqr(
|
|
@@ -1154,14 +974,16 @@ void pairwise_indexed_L2sqr(
|
|
|
1154
974
|
const float* y,
|
|
1155
975
|
const int64_t* iy,
|
|
1156
976
|
float* dis) {
|
|
977
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1157
978
|
#pragma omp parallel for if (n > 1)
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
979
|
+
for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
|
|
980
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
981
|
+
dis[j] = fvec_L2sqr<SL>(x + d * ix[j], y + d * iy[j], d);
|
|
982
|
+
} else {
|
|
983
|
+
dis[j] = INFINITY;
|
|
984
|
+
}
|
|
1163
985
|
}
|
|
1164
|
-
}
|
|
986
|
+
});
|
|
1165
987
|
}
|
|
1166
988
|
|
|
1167
989
|
void pairwise_indexed_inner_product(
|
|
@@ -1172,15 +994,17 @@ void pairwise_indexed_inner_product(
|
|
|
1172
994
|
const float* y,
|
|
1173
995
|
const int64_t* iy,
|
|
1174
996
|
float* dis) {
|
|
997
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1175
998
|
#pragma omp parallel for if (n > 1)
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
999
|
+
for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
|
|
1000
|
+
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
1001
|
+
dis[j] =
|
|
1002
|
+
fvec_inner_product<SL>(x + d * ix[j], y + d * iy[j], d);
|
|
1003
|
+
} else {
|
|
1004
|
+
dis[j] = -INFINITY;
|
|
1005
|
+
}
|
|
1182
1006
|
}
|
|
1183
|
-
}
|
|
1007
|
+
});
|
|
1184
1008
|
}
|
|
1185
1009
|
|
|
1186
1010
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
|
@@ -1201,27 +1025,29 @@ void knn_inner_products_by_idx(
|
|
|
1201
1025
|
ld_ids = ny;
|
|
1202
1026
|
}
|
|
1203
1027
|
|
|
1028
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1204
1029
|
#pragma omp parallel for if (nx > 100)
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1030
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
1031
|
+
const float* x_ = x + i * d;
|
|
1032
|
+
const int64_t* idsi = ids + i * ld_ids;
|
|
1033
|
+
size_t j;
|
|
1034
|
+
float* __restrict simi = res_vals + i * k;
|
|
1035
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
1036
|
+
minheap_heapify(k, simi, idxi);
|
|
1037
|
+
|
|
1038
|
+
for (j = 0; j < nsubset; j++) {
|
|
1039
|
+
if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
|
|
1040
|
+
break;
|
|
1041
|
+
}
|
|
1042
|
+
float ip = fvec_inner_product<SL>(x_, y + d * idsi[j], d);
|
|
1218
1043
|
|
|
1219
|
-
|
|
1220
|
-
|
|
1044
|
+
if (ip > simi[0]) {
|
|
1045
|
+
minheap_replace_top(k, simi, idxi, ip, idsi[j]);
|
|
1046
|
+
}
|
|
1221
1047
|
}
|
|
1048
|
+
minheap_reorder(k, simi, idxi);
|
|
1222
1049
|
}
|
|
1223
|
-
|
|
1224
|
-
}
|
|
1050
|
+
});
|
|
1225
1051
|
}
|
|
1226
1052
|
|
|
1227
1053
|
void knn_L2sqr_by_idx(
|
|
@@ -1239,25 +1065,27 @@ void knn_L2sqr_by_idx(
|
|
|
1239
1065
|
if (ld_ids < 0) {
|
|
1240
1066
|
ld_ids = ny;
|
|
1241
1067
|
}
|
|
1068
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1242
1069
|
#pragma omp parallel for if (nx > 100)
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1070
|
+
for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
|
|
1071
|
+
const float* x_ = x + i * d;
|
|
1072
|
+
const int64_t* __restrict idsi = ids + i * ld_ids;
|
|
1073
|
+
float* __restrict simi = res_vals + i * k;
|
|
1074
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
1075
|
+
maxheap_heapify(k, simi, idxi);
|
|
1076
|
+
for (size_t j = 0; j < nsubset; j++) {
|
|
1077
|
+
if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
|
|
1078
|
+
break;
|
|
1079
|
+
}
|
|
1080
|
+
float disij = fvec_L2sqr<SL>(x_, y + d * idsi[j], d);
|
|
1254
1081
|
|
|
1255
|
-
|
|
1256
|
-
|
|
1082
|
+
if (disij < simi[0]) {
|
|
1083
|
+
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
|
1084
|
+
}
|
|
1257
1085
|
}
|
|
1086
|
+
maxheap_reorder(k, simi, idxi);
|
|
1258
1087
|
}
|
|
1259
|
-
|
|
1260
|
-
}
|
|
1088
|
+
});
|
|
1261
1089
|
}
|
|
1262
1090
|
|
|
1263
1091
|
void pairwise_L2sqr(
|
|
@@ -1286,25 +1114,27 @@ void pairwise_L2sqr(
|
|
|
1286
1114
|
// store in beginning of distance matrix to avoid malloc
|
|
1287
1115
|
float* b_norms = dis;
|
|
1288
1116
|
|
|
1117
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
1289
1118
|
#pragma omp parallel for if (nb > 1)
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1119
|
+
for (int64_t i = 0; i < nb; i++) {
|
|
1120
|
+
b_norms[i] = fvec_norm_L2sqr<SL>(xb + i * ldb, d);
|
|
1121
|
+
}
|
|
1293
1122
|
|
|
1294
1123
|
#pragma omp parallel for
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1124
|
+
for (int64_t i = 1; i < nq; i++) {
|
|
1125
|
+
float q_norm = fvec_norm_L2sqr<SL>(xq + i * ldq, d);
|
|
1126
|
+
for (int64_t j = 0; j < nb; j++) {
|
|
1127
|
+
dis[i * ldd + j] = q_norm + b_norms[j];
|
|
1128
|
+
}
|
|
1299
1129
|
}
|
|
1300
|
-
}
|
|
1301
1130
|
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1131
|
+
{
|
|
1132
|
+
float q_norm = fvec_norm_L2sqr<SL>(xq, d);
|
|
1133
|
+
for (int64_t j = 0; j < nb; j++) {
|
|
1134
|
+
dis[j] += q_norm;
|
|
1135
|
+
}
|
|
1306
1136
|
}
|
|
1307
|
-
}
|
|
1137
|
+
});
|
|
1308
1138
|
|
|
1309
1139
|
{
|
|
1310
1140
|
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
|
@@ -1333,7 +1163,7 @@ void inner_product_to_L2sqr(
|
|
|
1333
1163
|
size_t n1,
|
|
1334
1164
|
size_t n2) {
|
|
1335
1165
|
#pragma omp parallel for
|
|
1336
|
-
for (int64_t j = 0; j < n1; j++) {
|
|
1166
|
+
for (int64_t j = 0; j < static_cast<int64_t>(n1); j++) {
|
|
1337
1167
|
float* disj = dis + j * n2;
|
|
1338
1168
|
for (size_t i = 0; i < n2; i++) {
|
|
1339
1169
|
disj[i] = nr1[j] + nr2[i] - 2 * disj[i];
|