faiss 0.5.3 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -10,94 +10,40 @@
|
|
|
10
10
|
#include <faiss/utils/distances.h>
|
|
11
11
|
|
|
12
12
|
#include <algorithm>
|
|
13
|
-
#include <cassert>
|
|
14
13
|
#include <cmath>
|
|
15
14
|
#include <cstdio>
|
|
16
15
|
#include <cstring>
|
|
17
16
|
|
|
18
17
|
#include <faiss/impl/FaissAssert.h>
|
|
19
|
-
#include <faiss/impl/
|
|
20
|
-
#include <faiss/utils/simdlib.h>
|
|
18
|
+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
|
|
21
19
|
|
|
22
|
-
#
|
|
23
|
-
|
|
24
|
-
#
|
|
25
|
-
|
|
26
|
-
#if defined(__AVX512F__)
|
|
27
|
-
#include <faiss/utils/transpose/transpose-avx512-inl.h>
|
|
28
|
-
#elif defined(__AVX2__)
|
|
29
|
-
#include <faiss/utils/transpose/transpose-avx2-inl.h>
|
|
30
|
-
#endif
|
|
31
|
-
|
|
32
|
-
#ifdef __ARM_FEATURE_SVE
|
|
33
|
-
#include <arm_sve.h>
|
|
34
|
-
#endif
|
|
20
|
+
#define THE_SIMD_LEVEL SIMDLevel::NONE
|
|
21
|
+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
|
|
22
|
+
#include <faiss/utils/simd_impl/distances_autovec-inl.h>
|
|
35
23
|
|
|
36
|
-
|
|
37
|
-
#include <
|
|
38
|
-
#endif
|
|
24
|
+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
|
|
25
|
+
#include <faiss/utils/simd_impl/distances_simdlib256.h>
|
|
39
26
|
|
|
40
27
|
namespace faiss {
|
|
41
28
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
#endif
|
|
45
|
-
|
|
46
|
-
/*********************************************************
|
|
47
|
-
* Optimized distance computations
|
|
48
|
-
*********************************************************/
|
|
49
|
-
|
|
50
|
-
/* Functions to compute:
|
|
51
|
-
- L2 distance between 2 vectors
|
|
52
|
-
- inner product between 2 vectors
|
|
53
|
-
- L2 norm of a vector
|
|
54
|
-
|
|
55
|
-
The functions should probably not be invoked when a large number of
|
|
56
|
-
vectors are be processed in batch (in which case Matrix multiply
|
|
57
|
-
is faster), but may be useful for comparing vectors isolated in
|
|
58
|
-
memory.
|
|
59
|
-
|
|
60
|
-
Works with any vectors of any dimension, even unaligned (in which
|
|
61
|
-
case they are slower).
|
|
62
|
-
|
|
29
|
+
/*******
|
|
30
|
+
Functions with SIMDLevel::NONE
|
|
63
31
|
*/
|
|
64
32
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
float
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
for (i = 0; i <
|
|
73
|
-
|
|
74
|
-
res += fabs(tmp);
|
|
75
|
-
}
|
|
76
|
-
return res;
|
|
77
|
-
}
|
|
78
|
-
|
|
79
|
-
float fvec_Linf_ref(const float* x, const float* y, size_t d) {
|
|
80
|
-
size_t i;
|
|
81
|
-
float res = 0;
|
|
82
|
-
for (i = 0; i < d; i++) {
|
|
83
|
-
res = fmax(res, fabs(x[i] - y[i]));
|
|
84
|
-
}
|
|
85
|
-
return res;
|
|
86
|
-
}
|
|
87
|
-
|
|
88
|
-
void fvec_L2sqr_ny_ref(
|
|
89
|
-
float* dis,
|
|
90
|
-
const float* x,
|
|
91
|
-
const float* y,
|
|
92
|
-
size_t d,
|
|
93
|
-
size_t ny) {
|
|
94
|
-
for (size_t i = 0; i < ny; i++) {
|
|
95
|
-
dis[i] = fvec_L2sqr(x, y, d);
|
|
96
|
-
y += d;
|
|
33
|
+
template <>
|
|
34
|
+
void fvec_madd<SIMDLevel::NONE>(
|
|
35
|
+
size_t n,
|
|
36
|
+
const float* a,
|
|
37
|
+
float bf,
|
|
38
|
+
const float* b,
|
|
39
|
+
float* c) {
|
|
40
|
+
for (size_t i = 0; i < n; i++) {
|
|
41
|
+
c[i] = a[i] + bf * b[i];
|
|
97
42
|
}
|
|
98
43
|
}
|
|
99
44
|
|
|
100
|
-
|
|
45
|
+
template <>
|
|
46
|
+
void fvec_L2sqr_ny_transposed<SIMDLevel::NONE>(
|
|
101
47
|
float* dis,
|
|
102
48
|
const float* x,
|
|
103
49
|
const float* y,
|
|
@@ -120,66 +66,22 @@ void fvec_L2sqr_ny_y_transposed_ref(
|
|
|
120
66
|
}
|
|
121
67
|
}
|
|
122
68
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
const float* x,
|
|
126
|
-
const float* y,
|
|
127
|
-
size_t d,
|
|
128
|
-
size_t ny) {
|
|
129
|
-
fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
|
|
130
|
-
|
|
131
|
-
size_t nearest_idx = 0;
|
|
132
|
-
float min_dis = HUGE_VALF;
|
|
133
|
-
|
|
134
|
-
for (size_t i = 0; i < ny; i++) {
|
|
135
|
-
if (distances_tmp_buffer[i] < min_dis) {
|
|
136
|
-
min_dis = distances_tmp_buffer[i];
|
|
137
|
-
nearest_idx = i;
|
|
138
|
-
}
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
return nearest_idx;
|
|
142
|
-
}
|
|
143
|
-
|
|
144
|
-
size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
145
|
-
float* distances_tmp_buffer,
|
|
146
|
-
const float* x,
|
|
147
|
-
const float* y,
|
|
148
|
-
const float* y_sqlen,
|
|
149
|
-
size_t d,
|
|
150
|
-
size_t d_offset,
|
|
151
|
-
size_t ny) {
|
|
152
|
-
fvec_L2sqr_ny_y_transposed_ref(
|
|
153
|
-
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
154
|
-
|
|
155
|
-
size_t nearest_idx = 0;
|
|
156
|
-
float min_dis = HUGE_VALF;
|
|
157
|
-
|
|
158
|
-
for (size_t i = 0; i < ny; i++) {
|
|
159
|
-
if (distances_tmp_buffer[i] < min_dis) {
|
|
160
|
-
min_dis = distances_tmp_buffer[i];
|
|
161
|
-
nearest_idx = i;
|
|
162
|
-
}
|
|
163
|
-
}
|
|
164
|
-
|
|
165
|
-
return nearest_idx;
|
|
166
|
-
}
|
|
167
|
-
|
|
168
|
-
void fvec_inner_products_ny_ref(
|
|
69
|
+
template <>
|
|
70
|
+
void fvec_inner_products_ny<SIMDLevel::NONE>(
|
|
169
71
|
float* ip,
|
|
170
72
|
const float* x,
|
|
171
73
|
const float* y,
|
|
172
74
|
size_t d,
|
|
173
75
|
size_t ny) {
|
|
174
|
-
|
|
76
|
+
// BLAS slower for the use cases here
|
|
175
77
|
#if 0
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
78
|
+
{
|
|
79
|
+
FINTEGER di = d;
|
|
80
|
+
FINTEGER nyi = ny;
|
|
81
|
+
float one = 1.0, zero = 0.0;
|
|
82
|
+
FINTEGER onei = 1;
|
|
83
|
+
sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
|
|
84
|
+
}
|
|
183
85
|
#endif
|
|
184
86
|
for (size_t i = 0; i < ny; i++) {
|
|
185
87
|
ip[i] = fvec_inner_product(x, y, d);
|
|
@@ -187,3595 +89,84 @@ void fvec_inner_products_ny_ref(
|
|
|
187
89
|
}
|
|
188
90
|
}
|
|
189
91
|
|
|
190
|
-
/*********************************************************
|
|
191
|
-
* Autovectorized implementations
|
|
192
|
-
*/
|
|
193
|
-
|
|
194
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
195
|
-
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
196
|
-
float res = 0.F;
|
|
197
|
-
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
198
|
-
for (size_t i = 0; i != d; ++i) {
|
|
199
|
-
res += x[i] * y[i];
|
|
200
|
-
}
|
|
201
|
-
return res;
|
|
202
|
-
}
|
|
203
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
204
|
-
|
|
205
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
206
|
-
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
207
|
-
// the double in the _ref is suspected to be a typo. Some of the manual
|
|
208
|
-
// implementations this replaces used float.
|
|
209
|
-
float res = 0;
|
|
210
|
-
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
211
|
-
for (size_t i = 0; i != d; ++i) {
|
|
212
|
-
res += x[i] * x[i];
|
|
213
|
-
}
|
|
214
|
-
|
|
215
|
-
return res;
|
|
216
|
-
}
|
|
217
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
218
|
-
|
|
219
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
220
|
-
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
221
|
-
size_t i;
|
|
222
|
-
float res = 0;
|
|
223
|
-
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
224
|
-
for (i = 0; i < d; i++) {
|
|
225
|
-
const float tmp = x[i] - y[i];
|
|
226
|
-
res += tmp * tmp;
|
|
227
|
-
}
|
|
228
|
-
return res;
|
|
229
|
-
}
|
|
230
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
231
|
-
|
|
232
|
-
/// Special version of inner product that computes 4 distances
|
|
233
|
-
/// between x and yi
|
|
234
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
235
|
-
void fvec_inner_product_batch_4(
|
|
236
|
-
const float* __restrict x,
|
|
237
|
-
const float* __restrict y0,
|
|
238
|
-
const float* __restrict y1,
|
|
239
|
-
const float* __restrict y2,
|
|
240
|
-
const float* __restrict y3,
|
|
241
|
-
const size_t d,
|
|
242
|
-
float& dis0,
|
|
243
|
-
float& dis1,
|
|
244
|
-
float& dis2,
|
|
245
|
-
float& dis3) {
|
|
246
|
-
float d0 = 0;
|
|
247
|
-
float d1 = 0;
|
|
248
|
-
float d2 = 0;
|
|
249
|
-
float d3 = 0;
|
|
250
|
-
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
251
|
-
for (size_t i = 0; i < d; ++i) {
|
|
252
|
-
d0 += x[i] * y0[i];
|
|
253
|
-
d1 += x[i] * y1[i];
|
|
254
|
-
d2 += x[i] * y2[i];
|
|
255
|
-
d3 += x[i] * y3[i];
|
|
256
|
-
}
|
|
257
|
-
|
|
258
|
-
dis0 = d0;
|
|
259
|
-
dis1 = d1;
|
|
260
|
-
dis2 = d2;
|
|
261
|
-
dis3 = d3;
|
|
262
|
-
}
|
|
263
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
264
|
-
|
|
265
|
-
/// Special version of L2sqr that computes 4 distances
|
|
266
|
-
/// between x and yi, which is performance oriented.
|
|
267
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
268
|
-
void fvec_L2sqr_batch_4(
|
|
269
|
-
const float* x,
|
|
270
|
-
const float* y0,
|
|
271
|
-
const float* y1,
|
|
272
|
-
const float* y2,
|
|
273
|
-
const float* y3,
|
|
274
|
-
const size_t d,
|
|
275
|
-
float& dis0,
|
|
276
|
-
float& dis1,
|
|
277
|
-
float& dis2,
|
|
278
|
-
float& dis3) {
|
|
279
|
-
float d0 = 0;
|
|
280
|
-
float d1 = 0;
|
|
281
|
-
float d2 = 0;
|
|
282
|
-
float d3 = 0;
|
|
283
|
-
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
284
|
-
for (size_t i = 0; i < d; ++i) {
|
|
285
|
-
const float q0 = x[i] - y0[i];
|
|
286
|
-
const float q1 = x[i] - y1[i];
|
|
287
|
-
const float q2 = x[i] - y2[i];
|
|
288
|
-
const float q3 = x[i] - y3[i];
|
|
289
|
-
d0 += q0 * q0;
|
|
290
|
-
d1 += q1 * q1;
|
|
291
|
-
d2 += q2 * q2;
|
|
292
|
-
d3 += q3 * q3;
|
|
293
|
-
}
|
|
294
|
-
|
|
295
|
-
dis0 = d0;
|
|
296
|
-
dis1 = d1;
|
|
297
|
-
dis2 = d2;
|
|
298
|
-
dis3 = d3;
|
|
299
|
-
}
|
|
300
|
-
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
301
|
-
|
|
302
|
-
/*********************************************************
|
|
303
|
-
* SSE and AVX implementations
|
|
304
|
-
*/
|
|
305
|
-
|
|
306
|
-
#ifdef __SSE3__
|
|
307
|
-
|
|
308
|
-
// reads 0 <= d < 4 floats as __m128
|
|
309
|
-
static inline __m128 masked_read(int d, const float* x) {
|
|
310
|
-
assert(0 <= d && d < 4);
|
|
311
|
-
ALIGNED(16) float buf[4] = {0, 0, 0, 0};
|
|
312
|
-
switch (d) {
|
|
313
|
-
case 3:
|
|
314
|
-
buf[2] = x[2];
|
|
315
|
-
[[fallthrough]];
|
|
316
|
-
case 2:
|
|
317
|
-
buf[1] = x[1];
|
|
318
|
-
[[fallthrough]];
|
|
319
|
-
case 1:
|
|
320
|
-
buf[0] = x[0];
|
|
321
|
-
}
|
|
322
|
-
return _mm_load_ps(buf);
|
|
323
|
-
// cannot use AVX2 _mm_mask_set1_epi32
|
|
324
|
-
}
|
|
325
|
-
|
|
326
|
-
namespace {
|
|
327
|
-
|
|
328
|
-
/// helper function
|
|
329
|
-
inline float horizontal_sum(const __m128 v) {
|
|
330
|
-
// say, v is [x0, x1, x2, x3]
|
|
331
|
-
|
|
332
|
-
// v0 is [x2, x3, ..., ...]
|
|
333
|
-
const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
|
|
334
|
-
// v1 is [x0 + x2, x1 + x3, ..., ...]
|
|
335
|
-
const __m128 v1 = _mm_add_ps(v, v0);
|
|
336
|
-
// v2 is [x1 + x3, ..., .... ,...]
|
|
337
|
-
__m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
|
|
338
|
-
// v3 is [x0 + x1 + x2 + x3, ..., ..., ...]
|
|
339
|
-
const __m128 v3 = _mm_add_ps(v1, v2);
|
|
340
|
-
// return v3[0]
|
|
341
|
-
return _mm_cvtss_f32(v3);
|
|
342
|
-
}
|
|
343
|
-
|
|
344
|
-
#ifdef __AVX2__
|
|
345
|
-
/// helper function for AVX2
|
|
346
|
-
inline float horizontal_sum(const __m256 v) {
|
|
347
|
-
// add high and low parts
|
|
348
|
-
const __m128 v0 =
|
|
349
|
-
_mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
|
|
350
|
-
// perform horizontal sum on v0
|
|
351
|
-
return horizontal_sum(v0);
|
|
352
|
-
}
|
|
353
|
-
#endif
|
|
354
|
-
|
|
355
|
-
#ifdef __AVX512F__
|
|
356
|
-
/// helper function for AVX512
|
|
357
|
-
inline float horizontal_sum(const __m512 v) {
|
|
358
|
-
// performs better than adding the high and low parts
|
|
359
|
-
return _mm512_reduce_add_ps(v);
|
|
360
|
-
}
|
|
361
|
-
#endif
|
|
362
|
-
|
|
363
|
-
/// Function that does a component-wise operation between x and y
|
|
364
|
-
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
|
|
365
|
-
/// functions below
|
|
366
|
-
struct ElementOpL2 {
|
|
367
|
-
static float op(float x, float y) {
|
|
368
|
-
float tmp = x - y;
|
|
369
|
-
return tmp * tmp;
|
|
370
|
-
}
|
|
371
|
-
|
|
372
|
-
static __m128 op(__m128 x, __m128 y) {
|
|
373
|
-
__m128 tmp = _mm_sub_ps(x, y);
|
|
374
|
-
return _mm_mul_ps(tmp, tmp);
|
|
375
|
-
}
|
|
376
|
-
|
|
377
|
-
#ifdef __AVX2__
|
|
378
|
-
static __m256 op(__m256 x, __m256 y) {
|
|
379
|
-
__m256 tmp = _mm256_sub_ps(x, y);
|
|
380
|
-
return _mm256_mul_ps(tmp, tmp);
|
|
381
|
-
}
|
|
382
|
-
#endif
|
|
383
|
-
|
|
384
|
-
#ifdef __AVX512F__
|
|
385
|
-
static __m512 op(__m512 x, __m512 y) {
|
|
386
|
-
__m512 tmp = _mm512_sub_ps(x, y);
|
|
387
|
-
return _mm512_mul_ps(tmp, tmp);
|
|
388
|
-
}
|
|
389
|
-
#endif
|
|
390
|
-
};
|
|
391
|
-
|
|
392
|
-
/// Function that does a component-wise operation between x and y
|
|
393
|
-
/// to compute inner products
|
|
394
|
-
struct ElementOpIP {
|
|
395
|
-
static float op(float x, float y) {
|
|
396
|
-
return x * y;
|
|
397
|
-
}
|
|
398
|
-
|
|
399
|
-
static __m128 op(__m128 x, __m128 y) {
|
|
400
|
-
return _mm_mul_ps(x, y);
|
|
401
|
-
}
|
|
402
|
-
|
|
403
|
-
#ifdef __AVX2__
|
|
404
|
-
static __m256 op(__m256 x, __m256 y) {
|
|
405
|
-
return _mm256_mul_ps(x, y);
|
|
406
|
-
}
|
|
407
|
-
#endif
|
|
408
|
-
|
|
409
|
-
#ifdef __AVX512F__
|
|
410
|
-
static __m512 op(__m512 x, __m512 y) {
|
|
411
|
-
return _mm512_mul_ps(x, y);
|
|
412
|
-
}
|
|
413
|
-
#endif
|
|
414
|
-
};
|
|
415
|
-
|
|
416
|
-
template <class ElementOp>
|
|
417
|
-
void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
|
|
418
|
-
float x0s = x[0];
|
|
419
|
-
__m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s);
|
|
420
|
-
|
|
421
|
-
size_t i;
|
|
422
|
-
for (i = 0; i + 3 < ny; i += 4) {
|
|
423
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
424
|
-
y += 4;
|
|
425
|
-
dis[i] = _mm_cvtss_f32(accu);
|
|
426
|
-
__m128 tmp = _mm_shuffle_ps(accu, accu, 1);
|
|
427
|
-
dis[i + 1] = _mm_cvtss_f32(tmp);
|
|
428
|
-
tmp = _mm_shuffle_ps(accu, accu, 2);
|
|
429
|
-
dis[i + 2] = _mm_cvtss_f32(tmp);
|
|
430
|
-
tmp = _mm_shuffle_ps(accu, accu, 3);
|
|
431
|
-
dis[i + 3] = _mm_cvtss_f32(tmp);
|
|
432
|
-
}
|
|
433
|
-
while (i < ny) { // handle non-multiple-of-4 case
|
|
434
|
-
dis[i++] = ElementOp::op(x0s, *y++);
|
|
435
|
-
}
|
|
436
|
-
}
|
|
437
|
-
|
|
438
|
-
template <class ElementOp>
|
|
439
|
-
void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
|
|
440
|
-
__m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]);
|
|
441
|
-
|
|
442
|
-
size_t i;
|
|
443
|
-
for (i = 0; i + 1 < ny; i += 2) {
|
|
444
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
445
|
-
y += 4;
|
|
446
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
447
|
-
dis[i] = _mm_cvtss_f32(accu);
|
|
448
|
-
accu = _mm_shuffle_ps(accu, accu, 3);
|
|
449
|
-
dis[i + 1] = _mm_cvtss_f32(accu);
|
|
450
|
-
}
|
|
451
|
-
if (i < ny) { // handle odd case
|
|
452
|
-
dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
|
|
453
|
-
}
|
|
454
|
-
}
|
|
455
|
-
|
|
456
|
-
#if defined(__AVX512F__)
|
|
457
|
-
|
|
458
92
|
template <>
|
|
459
|
-
void
|
|
93
|
+
void fvec_L2sqr_ny<SIMDLevel::NONE>(
|
|
460
94
|
float* dis,
|
|
461
95
|
const float* x,
|
|
462
96
|
const float* y,
|
|
97
|
+
size_t d,
|
|
463
98
|
size_t ny) {
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
if (ny16 > 0) {
|
|
468
|
-
// process 16 D2-vectors per loop.
|
|
469
|
-
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
470
|
-
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
471
|
-
|
|
472
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
473
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
474
|
-
|
|
475
|
-
for (i = 0; i < ny16 * 16; i += 16) {
|
|
476
|
-
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
|
477
|
-
|
|
478
|
-
// load 16x2 matrix and transpose it in registers.
|
|
479
|
-
// the typical bottleneck is memory access, so
|
|
480
|
-
// let's trade instructions for the bandwidth.
|
|
481
|
-
|
|
482
|
-
__m512 v0;
|
|
483
|
-
__m512 v1;
|
|
484
|
-
|
|
485
|
-
transpose_16x2(
|
|
486
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
487
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
488
|
-
v0,
|
|
489
|
-
v1);
|
|
490
|
-
|
|
491
|
-
// compute distances (dot product)
|
|
492
|
-
__m512 distances = _mm512_mul_ps(m0, v0);
|
|
493
|
-
distances = _mm512_fmadd_ps(m1, v1, distances);
|
|
494
|
-
|
|
495
|
-
// store
|
|
496
|
-
_mm512_storeu_ps(dis + i, distances);
|
|
497
|
-
|
|
498
|
-
y += 32; // move to the next set of 16x2 elements
|
|
499
|
-
}
|
|
500
|
-
}
|
|
501
|
-
|
|
502
|
-
if (i < ny) {
|
|
503
|
-
// process leftovers
|
|
504
|
-
float x0 = x[0];
|
|
505
|
-
float x1 = x[1];
|
|
506
|
-
|
|
507
|
-
for (; i < ny; i++) {
|
|
508
|
-
float distance = x0 * y[0] + x1 * y[1];
|
|
509
|
-
y += 2;
|
|
510
|
-
dis[i] = distance;
|
|
511
|
-
}
|
|
99
|
+
for (size_t i = 0; i < ny; i++) {
|
|
100
|
+
dis[i] = fvec_L2sqr(x, y, d);
|
|
101
|
+
y += d;
|
|
512
102
|
}
|
|
513
103
|
}
|
|
514
104
|
|
|
515
105
|
template <>
|
|
516
|
-
|
|
517
|
-
float*
|
|
106
|
+
size_t fvec_L2sqr_ny_nearest<SIMDLevel::NONE>(
|
|
107
|
+
float* distances_tmp_buffer,
|
|
518
108
|
const float* x,
|
|
519
109
|
const float* y,
|
|
110
|
+
size_t d,
|
|
520
111
|
size_t ny) {
|
|
521
|
-
|
|
522
|
-
size_t i = 0;
|
|
523
|
-
|
|
524
|
-
if (ny16 > 0) {
|
|
525
|
-
// process 16 D2-vectors per loop.
|
|
526
|
-
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
527
|
-
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
528
|
-
|
|
529
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
530
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
531
|
-
|
|
532
|
-
for (i = 0; i < ny16 * 16; i += 16) {
|
|
533
|
-
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
|
534
|
-
|
|
535
|
-
// load 16x2 matrix and transpose it in registers.
|
|
536
|
-
// the typical bottleneck is memory access, so
|
|
537
|
-
// let's trade instructions for the bandwidth.
|
|
538
|
-
|
|
539
|
-
__m512 v0;
|
|
540
|
-
__m512 v1;
|
|
541
|
-
|
|
542
|
-
transpose_16x2(
|
|
543
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
544
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
545
|
-
v0,
|
|
546
|
-
v1);
|
|
112
|
+
fvec_L2sqr_ny<SIMDLevel::NONE>(distances_tmp_buffer, x, y, d, ny);
|
|
547
113
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
551
|
-
|
|
552
|
-
// compute squares of differences
|
|
553
|
-
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
554
|
-
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
555
|
-
|
|
556
|
-
// store
|
|
557
|
-
_mm512_storeu_ps(dis + i, distances);
|
|
114
|
+
size_t nearest_idx = 0;
|
|
115
|
+
float min_dis = HUGE_VALF;
|
|
558
116
|
|
|
559
|
-
|
|
117
|
+
for (size_t i = 0; i < ny; i++) {
|
|
118
|
+
if (distances_tmp_buffer[i] < min_dis) {
|
|
119
|
+
min_dis = distances_tmp_buffer[i];
|
|
120
|
+
nearest_idx = i;
|
|
560
121
|
}
|
|
561
122
|
}
|
|
562
123
|
|
|
563
|
-
|
|
564
|
-
// process leftovers
|
|
565
|
-
float x0 = x[0];
|
|
566
|
-
float x1 = x[1];
|
|
567
|
-
|
|
568
|
-
for (; i < ny; i++) {
|
|
569
|
-
float sub0 = x0 - y[0];
|
|
570
|
-
float sub1 = x1 - y[1];
|
|
571
|
-
float distance = sub0 * sub0 + sub1 * sub1;
|
|
572
|
-
|
|
573
|
-
y += 2;
|
|
574
|
-
dis[i] = distance;
|
|
575
|
-
}
|
|
576
|
-
}
|
|
124
|
+
return nearest_idx;
|
|
577
125
|
}
|
|
578
126
|
|
|
579
|
-
#elif defined(__AVX2__)
|
|
580
|
-
|
|
581
127
|
template <>
|
|
582
|
-
|
|
583
|
-
float*
|
|
128
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::NONE>(
|
|
129
|
+
float* distances_tmp_buffer,
|
|
584
130
|
const float* x,
|
|
585
131
|
const float* y,
|
|
132
|
+
const float* y_sqlen,
|
|
133
|
+
size_t d,
|
|
134
|
+
size_t d_offset,
|
|
586
135
|
size_t ny) {
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
if (ny8 > 0) {
|
|
591
|
-
// process 8 D2-vectors per loop.
|
|
592
|
-
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
593
|
-
_mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
|
|
594
|
-
|
|
595
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
596
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
597
|
-
|
|
598
|
-
for (i = 0; i < ny8 * 8; i += 8) {
|
|
599
|
-
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
600
|
-
|
|
601
|
-
// load 8x2 matrix and transpose it in registers.
|
|
602
|
-
// the typical bottleneck is memory access, so
|
|
603
|
-
// let's trade instructions for the bandwidth.
|
|
604
|
-
|
|
605
|
-
__m256 v0;
|
|
606
|
-
__m256 v1;
|
|
607
|
-
|
|
608
|
-
transpose_8x2(
|
|
609
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
610
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
611
|
-
v0,
|
|
612
|
-
v1);
|
|
613
|
-
|
|
614
|
-
// compute distances
|
|
615
|
-
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
616
|
-
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
136
|
+
fvec_L2sqr_ny_transposed<SIMDLevel::NONE>(
|
|
137
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
617
138
|
|
|
618
|
-
|
|
619
|
-
|
|
139
|
+
size_t nearest_idx = 0;
|
|
140
|
+
float min_dis = HUGE_VALF;
|
|
620
141
|
|
|
621
|
-
|
|
142
|
+
for (size_t i = 0; i < ny; i++) {
|
|
143
|
+
if (distances_tmp_buffer[i] < min_dis) {
|
|
144
|
+
min_dis = distances_tmp_buffer[i];
|
|
145
|
+
nearest_idx = i;
|
|
622
146
|
}
|
|
623
147
|
}
|
|
624
148
|
|
|
625
|
-
|
|
626
|
-
// process leftovers
|
|
627
|
-
float x0 = x[0];
|
|
628
|
-
float x1 = x[1];
|
|
629
|
-
|
|
630
|
-
for (; i < ny; i++) {
|
|
631
|
-
float distance = x0 * y[0] + x1 * y[1];
|
|
632
|
-
y += 2;
|
|
633
|
-
dis[i] = distance;
|
|
634
|
-
}
|
|
635
|
-
}
|
|
149
|
+
return nearest_idx;
|
|
636
150
|
}
|
|
637
151
|
|
|
638
152
|
template <>
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
const float*
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
if (ny8 > 0) {
|
|
648
|
-
// process 8 D2-vectors per loop.
|
|
649
|
-
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
650
|
-
_mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
|
|
651
|
-
|
|
652
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
653
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
654
|
-
|
|
655
|
-
for (i = 0; i < ny8 * 8; i += 8) {
|
|
656
|
-
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
657
|
-
|
|
658
|
-
// load 8x2 matrix and transpose it in registers.
|
|
659
|
-
// the typical bottleneck is memory access, so
|
|
660
|
-
// let's trade instructions for the bandwidth.
|
|
661
|
-
|
|
662
|
-
__m256 v0;
|
|
663
|
-
__m256 v1;
|
|
664
|
-
|
|
665
|
-
transpose_8x2(
|
|
666
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
667
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
668
|
-
v0,
|
|
669
|
-
v1);
|
|
670
|
-
|
|
671
|
-
// compute differences
|
|
672
|
-
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
673
|
-
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
674
|
-
|
|
675
|
-
// compute squares of differences
|
|
676
|
-
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
677
|
-
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
678
|
-
|
|
679
|
-
// store
|
|
680
|
-
_mm256_storeu_ps(dis + i, distances);
|
|
681
|
-
|
|
682
|
-
y += 16;
|
|
683
|
-
}
|
|
684
|
-
}
|
|
685
|
-
|
|
686
|
-
if (i < ny) {
|
|
687
|
-
// process leftovers
|
|
688
|
-
float x0 = x[0];
|
|
689
|
-
float x1 = x[1];
|
|
690
|
-
|
|
691
|
-
for (; i < ny; i++) {
|
|
692
|
-
float sub0 = x0 - y[0];
|
|
693
|
-
float sub1 = x1 - y[1];
|
|
694
|
-
float distance = sub0 * sub0 + sub1 * sub1;
|
|
153
|
+
int fvec_madd_and_argmin<SIMDLevel::NONE>(
|
|
154
|
+
size_t n,
|
|
155
|
+
const float* a,
|
|
156
|
+
float bf,
|
|
157
|
+
const float* b,
|
|
158
|
+
float* c) {
|
|
159
|
+
float vmin = 1e20;
|
|
160
|
+
int imin = -1;
|
|
695
161
|
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
#endif
|
|
703
|
-
|
|
704
|
-
template <class ElementOp>
|
|
705
|
-
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
706
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
707
|
-
|
|
708
|
-
for (size_t i = 0; i < ny; i++) {
|
|
709
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
710
|
-
y += 4;
|
|
711
|
-
dis[i] = horizontal_sum(accu);
|
|
712
|
-
}
|
|
713
|
-
}
|
|
714
|
-
|
|
715
|
-
#if defined(__AVX512F__)
|
|
716
|
-
|
|
717
|
-
template <>
|
|
718
|
-
void fvec_op_ny_D4<ElementOpIP>(
|
|
719
|
-
float* dis,
|
|
720
|
-
const float* x,
|
|
721
|
-
const float* y,
|
|
722
|
-
size_t ny) {
|
|
723
|
-
const size_t ny16 = ny / 16;
|
|
724
|
-
size_t i = 0;
|
|
725
|
-
|
|
726
|
-
if (ny16 > 0) {
|
|
727
|
-
// process 16 D4-vectors per loop.
|
|
728
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
729
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
730
|
-
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
731
|
-
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
732
|
-
|
|
733
|
-
for (i = 0; i < ny16 * 16; i += 16) {
|
|
734
|
-
// load 16x4 matrix and transpose it in registers.
|
|
735
|
-
// the typical bottleneck is memory access, so
|
|
736
|
-
// let's trade instructions for the bandwidth.
|
|
737
|
-
|
|
738
|
-
__m512 v0;
|
|
739
|
-
__m512 v1;
|
|
740
|
-
__m512 v2;
|
|
741
|
-
__m512 v3;
|
|
742
|
-
|
|
743
|
-
transpose_16x4(
|
|
744
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
745
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
746
|
-
_mm512_loadu_ps(y + 2 * 16),
|
|
747
|
-
_mm512_loadu_ps(y + 3 * 16),
|
|
748
|
-
v0,
|
|
749
|
-
v1,
|
|
750
|
-
v2,
|
|
751
|
-
v3);
|
|
752
|
-
|
|
753
|
-
// compute distances
|
|
754
|
-
__m512 distances = _mm512_mul_ps(m0, v0);
|
|
755
|
-
distances = _mm512_fmadd_ps(m1, v1, distances);
|
|
756
|
-
distances = _mm512_fmadd_ps(m2, v2, distances);
|
|
757
|
-
distances = _mm512_fmadd_ps(m3, v3, distances);
|
|
758
|
-
|
|
759
|
-
// store
|
|
760
|
-
_mm512_storeu_ps(dis + i, distances);
|
|
761
|
-
|
|
762
|
-
y += 64; // move to the next set of 16x4 elements
|
|
763
|
-
}
|
|
764
|
-
}
|
|
765
|
-
|
|
766
|
-
if (i < ny) {
|
|
767
|
-
// process leftovers
|
|
768
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
769
|
-
|
|
770
|
-
for (; i < ny; i++) {
|
|
771
|
-
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
|
772
|
-
y += 4;
|
|
773
|
-
dis[i] = horizontal_sum(accu);
|
|
774
|
-
}
|
|
775
|
-
}
|
|
776
|
-
}
|
|
777
|
-
|
|
778
|
-
template <>
|
|
779
|
-
void fvec_op_ny_D4<ElementOpL2>(
|
|
780
|
-
float* dis,
|
|
781
|
-
const float* x,
|
|
782
|
-
const float* y,
|
|
783
|
-
size_t ny) {
|
|
784
|
-
const size_t ny16 = ny / 16;
|
|
785
|
-
size_t i = 0;
|
|
786
|
-
|
|
787
|
-
if (ny16 > 0) {
|
|
788
|
-
// process 16 D4-vectors per loop.
|
|
789
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
790
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
791
|
-
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
792
|
-
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
793
|
-
|
|
794
|
-
for (i = 0; i < ny16 * 16; i += 16) {
|
|
795
|
-
// load 16x4 matrix and transpose it in registers.
|
|
796
|
-
// the typical bottleneck is memory access, so
|
|
797
|
-
// let's trade instructions for the bandwidth.
|
|
798
|
-
|
|
799
|
-
__m512 v0;
|
|
800
|
-
__m512 v1;
|
|
801
|
-
__m512 v2;
|
|
802
|
-
__m512 v3;
|
|
803
|
-
|
|
804
|
-
transpose_16x4(
|
|
805
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
806
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
807
|
-
_mm512_loadu_ps(y + 2 * 16),
|
|
808
|
-
_mm512_loadu_ps(y + 3 * 16),
|
|
809
|
-
v0,
|
|
810
|
-
v1,
|
|
811
|
-
v2,
|
|
812
|
-
v3);
|
|
813
|
-
|
|
814
|
-
// compute differences
|
|
815
|
-
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
816
|
-
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
817
|
-
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
|
818
|
-
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
|
819
|
-
|
|
820
|
-
// compute squares of differences
|
|
821
|
-
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
822
|
-
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
823
|
-
distances = _mm512_fmadd_ps(d2, d2, distances);
|
|
824
|
-
distances = _mm512_fmadd_ps(d3, d3, distances);
|
|
825
|
-
|
|
826
|
-
// store
|
|
827
|
-
_mm512_storeu_ps(dis + i, distances);
|
|
828
|
-
|
|
829
|
-
y += 64; // move to the next set of 16x4 elements
|
|
830
|
-
}
|
|
831
|
-
}
|
|
832
|
-
|
|
833
|
-
if (i < ny) {
|
|
834
|
-
// process leftovers
|
|
835
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
836
|
-
|
|
837
|
-
for (; i < ny; i++) {
|
|
838
|
-
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
839
|
-
y += 4;
|
|
840
|
-
dis[i] = horizontal_sum(accu);
|
|
841
|
-
}
|
|
842
|
-
}
|
|
843
|
-
}
|
|
844
|
-
|
|
845
|
-
#elif defined(__AVX2__)
|
|
846
|
-
|
|
847
|
-
template <>
|
|
848
|
-
void fvec_op_ny_D4<ElementOpIP>(
|
|
849
|
-
float* dis,
|
|
850
|
-
const float* x,
|
|
851
|
-
const float* y,
|
|
852
|
-
size_t ny) {
|
|
853
|
-
const size_t ny8 = ny / 8;
|
|
854
|
-
size_t i = 0;
|
|
855
|
-
|
|
856
|
-
if (ny8 > 0) {
|
|
857
|
-
// process 8 D4-vectors per loop.
|
|
858
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
859
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
860
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
861
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
862
|
-
|
|
863
|
-
for (i = 0; i < ny8 * 8; i += 8) {
|
|
864
|
-
// load 8x4 matrix and transpose it in registers.
|
|
865
|
-
// the typical bottleneck is memory access, so
|
|
866
|
-
// let's trade instructions for the bandwidth.
|
|
867
|
-
|
|
868
|
-
__m256 v0;
|
|
869
|
-
__m256 v1;
|
|
870
|
-
__m256 v2;
|
|
871
|
-
__m256 v3;
|
|
872
|
-
|
|
873
|
-
transpose_8x4(
|
|
874
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
875
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
876
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
877
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
878
|
-
v0,
|
|
879
|
-
v1,
|
|
880
|
-
v2,
|
|
881
|
-
v3);
|
|
882
|
-
|
|
883
|
-
// compute distances
|
|
884
|
-
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
885
|
-
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
886
|
-
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
887
|
-
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
888
|
-
|
|
889
|
-
// store
|
|
890
|
-
_mm256_storeu_ps(dis + i, distances);
|
|
891
|
-
|
|
892
|
-
y += 32;
|
|
893
|
-
}
|
|
894
|
-
}
|
|
895
|
-
|
|
896
|
-
if (i < ny) {
|
|
897
|
-
// process leftovers
|
|
898
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
899
|
-
|
|
900
|
-
for (; i < ny; i++) {
|
|
901
|
-
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
|
902
|
-
y += 4;
|
|
903
|
-
dis[i] = horizontal_sum(accu);
|
|
904
|
-
}
|
|
905
|
-
}
|
|
906
|
-
}
|
|
907
|
-
|
|
908
|
-
template <>
|
|
909
|
-
void fvec_op_ny_D4<ElementOpL2>(
|
|
910
|
-
float* dis,
|
|
911
|
-
const float* x,
|
|
912
|
-
const float* y,
|
|
913
|
-
size_t ny) {
|
|
914
|
-
const size_t ny8 = ny / 8;
|
|
915
|
-
size_t i = 0;
|
|
916
|
-
|
|
917
|
-
if (ny8 > 0) {
|
|
918
|
-
// process 8 D4-vectors per loop.
|
|
919
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
920
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
921
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
922
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
923
|
-
|
|
924
|
-
for (i = 0; i < ny8 * 8; i += 8) {
|
|
925
|
-
// load 8x4 matrix and transpose it in registers.
|
|
926
|
-
// the typical bottleneck is memory access, so
|
|
927
|
-
// let's trade instructions for the bandwidth.
|
|
928
|
-
|
|
929
|
-
__m256 v0;
|
|
930
|
-
__m256 v1;
|
|
931
|
-
__m256 v2;
|
|
932
|
-
__m256 v3;
|
|
933
|
-
|
|
934
|
-
transpose_8x4(
|
|
935
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
936
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
937
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
938
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
939
|
-
v0,
|
|
940
|
-
v1,
|
|
941
|
-
v2,
|
|
942
|
-
v3);
|
|
943
|
-
|
|
944
|
-
// compute differences
|
|
945
|
-
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
946
|
-
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
947
|
-
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
948
|
-
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
949
|
-
|
|
950
|
-
// compute squares of differences
|
|
951
|
-
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
952
|
-
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
953
|
-
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
954
|
-
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
955
|
-
|
|
956
|
-
// store
|
|
957
|
-
_mm256_storeu_ps(dis + i, distances);
|
|
958
|
-
|
|
959
|
-
y += 32;
|
|
960
|
-
}
|
|
961
|
-
}
|
|
962
|
-
|
|
963
|
-
if (i < ny) {
|
|
964
|
-
// process leftovers
|
|
965
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
966
|
-
|
|
967
|
-
for (; i < ny; i++) {
|
|
968
|
-
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
969
|
-
y += 4;
|
|
970
|
-
dis[i] = horizontal_sum(accu);
|
|
971
|
-
}
|
|
972
|
-
}
|
|
973
|
-
}
|
|
974
|
-
|
|
975
|
-
#endif
|
|
976
|
-
|
|
977
|
-
template <class ElementOp>
|
|
978
|
-
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
979
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
980
|
-
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
981
|
-
|
|
982
|
-
for (size_t i = 0; i < ny; i++) {
|
|
983
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
984
|
-
y += 4;
|
|
985
|
-
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
986
|
-
y += 4;
|
|
987
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
988
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
989
|
-
dis[i] = _mm_cvtss_f32(accu);
|
|
990
|
-
}
|
|
991
|
-
}
|
|
992
|
-
|
|
993
|
-
#if defined(__AVX512F__)
|
|
994
|
-
|
|
995
|
-
template <>
|
|
996
|
-
void fvec_op_ny_D8<ElementOpIP>(
|
|
997
|
-
float* dis,
|
|
998
|
-
const float* x,
|
|
999
|
-
const float* y,
|
|
1000
|
-
size_t ny) {
|
|
1001
|
-
const size_t ny16 = ny / 16;
|
|
1002
|
-
size_t i = 0;
|
|
1003
|
-
|
|
1004
|
-
if (ny16 > 0) {
|
|
1005
|
-
// process 16 D16-vectors per loop.
|
|
1006
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1007
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1008
|
-
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
1009
|
-
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
1010
|
-
const __m512 m4 = _mm512_set1_ps(x[4]);
|
|
1011
|
-
const __m512 m5 = _mm512_set1_ps(x[5]);
|
|
1012
|
-
const __m512 m6 = _mm512_set1_ps(x[6]);
|
|
1013
|
-
const __m512 m7 = _mm512_set1_ps(x[7]);
|
|
1014
|
-
|
|
1015
|
-
for (i = 0; i < ny16 * 16; i += 16) {
|
|
1016
|
-
// load 16x8 matrix and transpose it in registers.
|
|
1017
|
-
// the typical bottleneck is memory access, so
|
|
1018
|
-
// let's trade instructions for the bandwidth.
|
|
1019
|
-
|
|
1020
|
-
__m512 v0;
|
|
1021
|
-
__m512 v1;
|
|
1022
|
-
__m512 v2;
|
|
1023
|
-
__m512 v3;
|
|
1024
|
-
__m512 v4;
|
|
1025
|
-
__m512 v5;
|
|
1026
|
-
__m512 v6;
|
|
1027
|
-
__m512 v7;
|
|
1028
|
-
|
|
1029
|
-
transpose_16x8(
|
|
1030
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
1031
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
1032
|
-
_mm512_loadu_ps(y + 2 * 16),
|
|
1033
|
-
_mm512_loadu_ps(y + 3 * 16),
|
|
1034
|
-
_mm512_loadu_ps(y + 4 * 16),
|
|
1035
|
-
_mm512_loadu_ps(y + 5 * 16),
|
|
1036
|
-
_mm512_loadu_ps(y + 6 * 16),
|
|
1037
|
-
_mm512_loadu_ps(y + 7 * 16),
|
|
1038
|
-
v0,
|
|
1039
|
-
v1,
|
|
1040
|
-
v2,
|
|
1041
|
-
v3,
|
|
1042
|
-
v4,
|
|
1043
|
-
v5,
|
|
1044
|
-
v6,
|
|
1045
|
-
v7);
|
|
1046
|
-
|
|
1047
|
-
// compute distances
|
|
1048
|
-
__m512 distances = _mm512_mul_ps(m0, v0);
|
|
1049
|
-
distances = _mm512_fmadd_ps(m1, v1, distances);
|
|
1050
|
-
distances = _mm512_fmadd_ps(m2, v2, distances);
|
|
1051
|
-
distances = _mm512_fmadd_ps(m3, v3, distances);
|
|
1052
|
-
distances = _mm512_fmadd_ps(m4, v4, distances);
|
|
1053
|
-
distances = _mm512_fmadd_ps(m5, v5, distances);
|
|
1054
|
-
distances = _mm512_fmadd_ps(m6, v6, distances);
|
|
1055
|
-
distances = _mm512_fmadd_ps(m7, v7, distances);
|
|
1056
|
-
|
|
1057
|
-
// store
|
|
1058
|
-
_mm512_storeu_ps(dis + i, distances);
|
|
1059
|
-
|
|
1060
|
-
y += 128; // 16 floats * 8 rows
|
|
1061
|
-
}
|
|
1062
|
-
}
|
|
1063
|
-
|
|
1064
|
-
if (i < ny) {
|
|
1065
|
-
// process leftovers
|
|
1066
|
-
__m256 x0 = _mm256_loadu_ps(x);
|
|
1067
|
-
|
|
1068
|
-
for (; i < ny; i++) {
|
|
1069
|
-
__m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
|
|
1070
|
-
y += 8;
|
|
1071
|
-
dis[i] = horizontal_sum(accu);
|
|
1072
|
-
}
|
|
1073
|
-
}
|
|
1074
|
-
}
|
|
1075
|
-
|
|
1076
|
-
template <>
|
|
1077
|
-
void fvec_op_ny_D8<ElementOpL2>(
|
|
1078
|
-
float* dis,
|
|
1079
|
-
const float* x,
|
|
1080
|
-
const float* y,
|
|
1081
|
-
size_t ny) {
|
|
1082
|
-
const size_t ny16 = ny / 16;
|
|
1083
|
-
size_t i = 0;
|
|
1084
|
-
|
|
1085
|
-
if (ny16 > 0) {
|
|
1086
|
-
// process 16 D16-vectors per loop.
|
|
1087
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1088
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1089
|
-
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
1090
|
-
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
1091
|
-
const __m512 m4 = _mm512_set1_ps(x[4]);
|
|
1092
|
-
const __m512 m5 = _mm512_set1_ps(x[5]);
|
|
1093
|
-
const __m512 m6 = _mm512_set1_ps(x[6]);
|
|
1094
|
-
const __m512 m7 = _mm512_set1_ps(x[7]);
|
|
1095
|
-
|
|
1096
|
-
for (i = 0; i < ny16 * 16; i += 16) {
|
|
1097
|
-
// load 16x8 matrix and transpose it in registers.
|
|
1098
|
-
// the typical bottleneck is memory access, so
|
|
1099
|
-
// let's trade instructions for the bandwidth.
|
|
1100
|
-
|
|
1101
|
-
__m512 v0;
|
|
1102
|
-
__m512 v1;
|
|
1103
|
-
__m512 v2;
|
|
1104
|
-
__m512 v3;
|
|
1105
|
-
__m512 v4;
|
|
1106
|
-
__m512 v5;
|
|
1107
|
-
__m512 v6;
|
|
1108
|
-
__m512 v7;
|
|
1109
|
-
|
|
1110
|
-
transpose_16x8(
|
|
1111
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
1112
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
1113
|
-
_mm512_loadu_ps(y + 2 * 16),
|
|
1114
|
-
_mm512_loadu_ps(y + 3 * 16),
|
|
1115
|
-
_mm512_loadu_ps(y + 4 * 16),
|
|
1116
|
-
_mm512_loadu_ps(y + 5 * 16),
|
|
1117
|
-
_mm512_loadu_ps(y + 6 * 16),
|
|
1118
|
-
_mm512_loadu_ps(y + 7 * 16),
|
|
1119
|
-
v0,
|
|
1120
|
-
v1,
|
|
1121
|
-
v2,
|
|
1122
|
-
v3,
|
|
1123
|
-
v4,
|
|
1124
|
-
v5,
|
|
1125
|
-
v6,
|
|
1126
|
-
v7);
|
|
1127
|
-
|
|
1128
|
-
// compute differences
|
|
1129
|
-
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
1130
|
-
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
1131
|
-
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
|
1132
|
-
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
|
1133
|
-
const __m512 d4 = _mm512_sub_ps(m4, v4);
|
|
1134
|
-
const __m512 d5 = _mm512_sub_ps(m5, v5);
|
|
1135
|
-
const __m512 d6 = _mm512_sub_ps(m6, v6);
|
|
1136
|
-
const __m512 d7 = _mm512_sub_ps(m7, v7);
|
|
1137
|
-
|
|
1138
|
-
// compute squares of differences
|
|
1139
|
-
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
1140
|
-
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
1141
|
-
distances = _mm512_fmadd_ps(d2, d2, distances);
|
|
1142
|
-
distances = _mm512_fmadd_ps(d3, d3, distances);
|
|
1143
|
-
distances = _mm512_fmadd_ps(d4, d4, distances);
|
|
1144
|
-
distances = _mm512_fmadd_ps(d5, d5, distances);
|
|
1145
|
-
distances = _mm512_fmadd_ps(d6, d6, distances);
|
|
1146
|
-
distances = _mm512_fmadd_ps(d7, d7, distances);
|
|
1147
|
-
|
|
1148
|
-
// store
|
|
1149
|
-
_mm512_storeu_ps(dis + i, distances);
|
|
1150
|
-
|
|
1151
|
-
y += 128; // 16 floats * 8 rows
|
|
1152
|
-
}
|
|
1153
|
-
}
|
|
1154
|
-
|
|
1155
|
-
if (i < ny) {
|
|
1156
|
-
// process leftovers
|
|
1157
|
-
__m256 x0 = _mm256_loadu_ps(x);
|
|
1158
|
-
|
|
1159
|
-
for (; i < ny; i++) {
|
|
1160
|
-
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
|
1161
|
-
y += 8;
|
|
1162
|
-
dis[i] = horizontal_sum(accu);
|
|
1163
|
-
}
|
|
1164
|
-
}
|
|
1165
|
-
}
|
|
1166
|
-
|
|
1167
|
-
#elif defined(__AVX2__)
|
|
1168
|
-
|
|
1169
|
-
template <>
|
|
1170
|
-
void fvec_op_ny_D8<ElementOpIP>(
|
|
1171
|
-
float* dis,
|
|
1172
|
-
const float* x,
|
|
1173
|
-
const float* y,
|
|
1174
|
-
size_t ny) {
|
|
1175
|
-
const size_t ny8 = ny / 8;
|
|
1176
|
-
size_t i = 0;
|
|
1177
|
-
|
|
1178
|
-
if (ny8 > 0) {
|
|
1179
|
-
// process 8 D8-vectors per loop.
|
|
1180
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
1181
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
1182
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
1183
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
1184
|
-
const __m256 m4 = _mm256_set1_ps(x[4]);
|
|
1185
|
-
const __m256 m5 = _mm256_set1_ps(x[5]);
|
|
1186
|
-
const __m256 m6 = _mm256_set1_ps(x[6]);
|
|
1187
|
-
const __m256 m7 = _mm256_set1_ps(x[7]);
|
|
1188
|
-
|
|
1189
|
-
for (i = 0; i < ny8 * 8; i += 8) {
|
|
1190
|
-
// load 8x8 matrix and transpose it in registers.
|
|
1191
|
-
// the typical bottleneck is memory access, so
|
|
1192
|
-
// let's trade instructions for the bandwidth.
|
|
1193
|
-
|
|
1194
|
-
__m256 v0;
|
|
1195
|
-
__m256 v1;
|
|
1196
|
-
__m256 v2;
|
|
1197
|
-
__m256 v3;
|
|
1198
|
-
__m256 v4;
|
|
1199
|
-
__m256 v5;
|
|
1200
|
-
__m256 v6;
|
|
1201
|
-
__m256 v7;
|
|
1202
|
-
|
|
1203
|
-
transpose_8x8(
|
|
1204
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
1205
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
1206
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
1207
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
1208
|
-
_mm256_loadu_ps(y + 4 * 8),
|
|
1209
|
-
_mm256_loadu_ps(y + 5 * 8),
|
|
1210
|
-
_mm256_loadu_ps(y + 6 * 8),
|
|
1211
|
-
_mm256_loadu_ps(y + 7 * 8),
|
|
1212
|
-
v0,
|
|
1213
|
-
v1,
|
|
1214
|
-
v2,
|
|
1215
|
-
v3,
|
|
1216
|
-
v4,
|
|
1217
|
-
v5,
|
|
1218
|
-
v6,
|
|
1219
|
-
v7);
|
|
1220
|
-
|
|
1221
|
-
// compute distances
|
|
1222
|
-
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
1223
|
-
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
1224
|
-
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
1225
|
-
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
1226
|
-
distances = _mm256_fmadd_ps(m4, v4, distances);
|
|
1227
|
-
distances = _mm256_fmadd_ps(m5, v5, distances);
|
|
1228
|
-
distances = _mm256_fmadd_ps(m6, v6, distances);
|
|
1229
|
-
distances = _mm256_fmadd_ps(m7, v7, distances);
|
|
1230
|
-
|
|
1231
|
-
// store
|
|
1232
|
-
_mm256_storeu_ps(dis + i, distances);
|
|
1233
|
-
|
|
1234
|
-
y += 64;
|
|
1235
|
-
}
|
|
1236
|
-
}
|
|
1237
|
-
|
|
1238
|
-
if (i < ny) {
|
|
1239
|
-
// process leftovers
|
|
1240
|
-
__m256 x0 = _mm256_loadu_ps(x);
|
|
1241
|
-
|
|
1242
|
-
for (; i < ny; i++) {
|
|
1243
|
-
__m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
|
|
1244
|
-
y += 8;
|
|
1245
|
-
dis[i] = horizontal_sum(accu);
|
|
1246
|
-
}
|
|
1247
|
-
}
|
|
1248
|
-
}
|
|
1249
|
-
|
|
1250
|
-
template <>
|
|
1251
|
-
void fvec_op_ny_D8<ElementOpL2>(
|
|
1252
|
-
float* dis,
|
|
1253
|
-
const float* x,
|
|
1254
|
-
const float* y,
|
|
1255
|
-
size_t ny) {
|
|
1256
|
-
const size_t ny8 = ny / 8;
|
|
1257
|
-
size_t i = 0;
|
|
1258
|
-
|
|
1259
|
-
if (ny8 > 0) {
|
|
1260
|
-
// process 8 D8-vectors per loop.
|
|
1261
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
1262
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
1263
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
1264
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
1265
|
-
const __m256 m4 = _mm256_set1_ps(x[4]);
|
|
1266
|
-
const __m256 m5 = _mm256_set1_ps(x[5]);
|
|
1267
|
-
const __m256 m6 = _mm256_set1_ps(x[6]);
|
|
1268
|
-
const __m256 m7 = _mm256_set1_ps(x[7]);
|
|
1269
|
-
|
|
1270
|
-
for (i = 0; i < ny8 * 8; i += 8) {
|
|
1271
|
-
// load 8x8 matrix and transpose it in registers.
|
|
1272
|
-
// the typical bottleneck is memory access, so
|
|
1273
|
-
// let's trade instructions for the bandwidth.
|
|
1274
|
-
|
|
1275
|
-
__m256 v0;
|
|
1276
|
-
__m256 v1;
|
|
1277
|
-
__m256 v2;
|
|
1278
|
-
__m256 v3;
|
|
1279
|
-
__m256 v4;
|
|
1280
|
-
__m256 v5;
|
|
1281
|
-
__m256 v6;
|
|
1282
|
-
__m256 v7;
|
|
1283
|
-
|
|
1284
|
-
transpose_8x8(
|
|
1285
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
1286
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
1287
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
1288
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
1289
|
-
_mm256_loadu_ps(y + 4 * 8),
|
|
1290
|
-
_mm256_loadu_ps(y + 5 * 8),
|
|
1291
|
-
_mm256_loadu_ps(y + 6 * 8),
|
|
1292
|
-
_mm256_loadu_ps(y + 7 * 8),
|
|
1293
|
-
v0,
|
|
1294
|
-
v1,
|
|
1295
|
-
v2,
|
|
1296
|
-
v3,
|
|
1297
|
-
v4,
|
|
1298
|
-
v5,
|
|
1299
|
-
v6,
|
|
1300
|
-
v7);
|
|
1301
|
-
|
|
1302
|
-
// compute differences
|
|
1303
|
-
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
1304
|
-
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
1305
|
-
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
1306
|
-
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
1307
|
-
const __m256 d4 = _mm256_sub_ps(m4, v4);
|
|
1308
|
-
const __m256 d5 = _mm256_sub_ps(m5, v5);
|
|
1309
|
-
const __m256 d6 = _mm256_sub_ps(m6, v6);
|
|
1310
|
-
const __m256 d7 = _mm256_sub_ps(m7, v7);
|
|
1311
|
-
|
|
1312
|
-
// compute squares of differences
|
|
1313
|
-
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
1314
|
-
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
1315
|
-
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
1316
|
-
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
1317
|
-
distances = _mm256_fmadd_ps(d4, d4, distances);
|
|
1318
|
-
distances = _mm256_fmadd_ps(d5, d5, distances);
|
|
1319
|
-
distances = _mm256_fmadd_ps(d6, d6, distances);
|
|
1320
|
-
distances = _mm256_fmadd_ps(d7, d7, distances);
|
|
1321
|
-
|
|
1322
|
-
// store
|
|
1323
|
-
_mm256_storeu_ps(dis + i, distances);
|
|
1324
|
-
|
|
1325
|
-
y += 64;
|
|
1326
|
-
}
|
|
1327
|
-
}
|
|
1328
|
-
|
|
1329
|
-
if (i < ny) {
|
|
1330
|
-
// process leftovers
|
|
1331
|
-
__m256 x0 = _mm256_loadu_ps(x);
|
|
1332
|
-
|
|
1333
|
-
for (; i < ny; i++) {
|
|
1334
|
-
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
|
1335
|
-
y += 8;
|
|
1336
|
-
dis[i] = horizontal_sum(accu);
|
|
1337
|
-
}
|
|
1338
|
-
}
|
|
1339
|
-
}
|
|
1340
|
-
|
|
1341
|
-
#endif
|
|
1342
|
-
|
|
1343
|
-
template <class ElementOp>
|
|
1344
|
-
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
|
|
1345
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
1346
|
-
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
1347
|
-
__m128 x2 = _mm_loadu_ps(x + 8);
|
|
1348
|
-
|
|
1349
|
-
for (size_t i = 0; i < ny; i++) {
|
|
1350
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
1351
|
-
y += 4;
|
|
1352
|
-
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
1353
|
-
y += 4;
|
|
1354
|
-
accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
|
|
1355
|
-
y += 4;
|
|
1356
|
-
dis[i] = horizontal_sum(accu);
|
|
1357
|
-
}
|
|
1358
|
-
}
|
|
1359
|
-
|
|
1360
|
-
} // anonymous namespace
|
|
1361
|
-
|
|
1362
|
-
void fvec_L2sqr_ny(
|
|
1363
|
-
float* dis,
|
|
1364
|
-
const float* x,
|
|
1365
|
-
const float* y,
|
|
1366
|
-
size_t d,
|
|
1367
|
-
size_t ny) {
|
|
1368
|
-
// optimized for a few special cases
|
|
1369
|
-
|
|
1370
|
-
#define DISPATCH(dval) \
|
|
1371
|
-
case dval: \
|
|
1372
|
-
fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
|
|
1373
|
-
return;
|
|
1374
|
-
|
|
1375
|
-
switch (d) {
|
|
1376
|
-
DISPATCH(1)
|
|
1377
|
-
DISPATCH(2)
|
|
1378
|
-
DISPATCH(4)
|
|
1379
|
-
DISPATCH(8)
|
|
1380
|
-
DISPATCH(12)
|
|
1381
|
-
default:
|
|
1382
|
-
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
1383
|
-
return;
|
|
1384
|
-
}
|
|
1385
|
-
#undef DISPATCH
|
|
1386
|
-
}
|
|
1387
|
-
|
|
1388
|
-
void fvec_inner_products_ny(
|
|
1389
|
-
float* dis,
|
|
1390
|
-
const float* x,
|
|
1391
|
-
const float* y,
|
|
1392
|
-
size_t d,
|
|
1393
|
-
size_t ny) {
|
|
1394
|
-
#define DISPATCH(dval) \
|
|
1395
|
-
case dval: \
|
|
1396
|
-
fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
|
|
1397
|
-
return;
|
|
1398
|
-
|
|
1399
|
-
switch (d) {
|
|
1400
|
-
DISPATCH(1)
|
|
1401
|
-
DISPATCH(2)
|
|
1402
|
-
DISPATCH(4)
|
|
1403
|
-
DISPATCH(8)
|
|
1404
|
-
DISPATCH(12)
|
|
1405
|
-
default:
|
|
1406
|
-
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
1407
|
-
return;
|
|
1408
|
-
}
|
|
1409
|
-
#undef DISPATCH
|
|
1410
|
-
}
|
|
1411
|
-
|
|
1412
|
-
#if defined(__AVX512F__)
|
|
1413
|
-
|
|
1414
|
-
template <size_t DIM>
void fvec_L2sqr_ny_y_transposed_D(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // AVX-512 kernel. Writes distances[k] = ||x||^2 - 2 <x, y_k> + y_sqlen[k]
    // for k in [0, ny), where component `dim` of vector k is read from
    // y[dim * d_offset + k] (y is stored transposed) and y_sqlen[k] is
    // supplied by the caller.
    float x_norm2 = 0;
    for (size_t dim = 0; dim < DIM; ++dim) {
        x_norm2 += x[dim] * x[dim];
    }

    const size_t n_simd = ny - ny % 16; // vectors handled 16 at a time
    size_t k = 0;

    if (n_simd > 0) {
        // broadcast 2 * x[dim] once per dimension
        __m512 twice_x[DIM];
        for (size_t dim = 0; dim < DIM; ++dim) {
            const __m512 xd = _mm512_set1_ps(x[dim]);
            twice_x[dim] = _mm512_add_ps(xd, xd);
        }
        const __m512 x_norm2_v = _mm512_set1_ps(x_norm2);

        for (; k < n_simd; k += 16) {
            // start from ||x||^2 and subtract 2 * x[dim] * y[dim][k..k+15]
            __m512 acc = x_norm2_v;
            for (size_t dim = 0; dim < DIM; ++dim) {
                acc = _mm512_fnmadd_ps(
                        twice_x[dim],
                        _mm512_loadu_ps(y + dim * d_offset + k),
                        acc);
            }
            _mm512_storeu_ps(
                    distances + k,
                    _mm512_add_ps(_mm512_loadu_ps(y_sqlen + k), acc));
        }
    }

    // remaining ny % 16 vectors (full distance, including ||x||^2)
    for (; k < ny; ++k) {
        float ip = 0;
        for (size_t dim = 0; dim < DIM; ++dim) {
            ip += x[dim] * y[dim * d_offset + k];
        }
        distances[k] = y_sqlen[k] - 2 * ip + x_norm2;
    }
}
|
|
1486
|
-
|
|
1487
|
-
#elif defined(__AVX2__)
|
|
1488
|
-
|
|
1489
|
-
template <size_t DIM>
void fvec_L2sqr_ny_y_transposed_D(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // AVX2 kernel. Writes distances[k] = ||x||^2 - 2 <x, y_k> + y_sqlen[k]
    // for k in [0, ny), where component `dim` of vector k is read from
    // y[dim * d_offset + k] (y is stored transposed) and y_sqlen[k] is
    // supplied by the caller.
    float x_norm2 = 0;
    for (size_t dim = 0; dim < DIM; ++dim) {
        x_norm2 += x[dim] * x[dim];
    }

    const size_t n_simd = ny - ny % 8; // vectors handled 8 at a time
    size_t k = 0;

    if (n_simd > 0) {
        // broadcast 2 * x[dim] once per dimension
        __m256 twice_x[DIM];
        for (size_t dim = 0; dim < DIM; ++dim) {
            const __m256 xd = _mm256_set1_ps(x[dim]);
            twice_x[dim] = _mm256_add_ps(xd, xd);
        }
        const __m256 x_norm2_v = _mm256_set1_ps(x_norm2);

        for (; k < n_simd; k += 8) {
            // start from ||x||^2 and subtract 2 * x[dim] * y[dim][k..k+7]
            __m256 acc = x_norm2_v;
            for (size_t dim = 0; dim < DIM; ++dim) {
                acc = _mm256_fnmadd_ps(
                        twice_x[dim],
                        _mm256_loadu_ps(y + dim * d_offset + k),
                        acc);
            }
            _mm256_storeu_ps(
                    distances + k,
                    _mm256_add_ps(_mm256_loadu_ps(y_sqlen + k), acc));
        }
    }

    // remaining ny % 8 vectors (full distance, including ||x||^2)
    for (; k < ny; ++k) {
        float ip = 0;
        for (size_t dim = 0; dim < DIM; ++dim) {
            ip += x[dim] * y[dim * d_offset + k];
        }
        distances[k] = y_sqlen[k] - 2 * ip + x_norm2;
    }
}
|
|
1564
|
-
|
|
1565
|
-
#endif
|
|
1566
|
-
|
|
1567
|
-
void fvec_L2sqr_ny_transposed(
|
|
1568
|
-
float* dis,
|
|
1569
|
-
const float* x,
|
|
1570
|
-
const float* y,
|
|
1571
|
-
const float* y_sqlen,
|
|
1572
|
-
size_t d,
|
|
1573
|
-
size_t d_offset,
|
|
1574
|
-
size_t ny) {
|
|
1575
|
-
// optimized for a few special cases
|
|
1576
|
-
|
|
1577
|
-
#ifdef __AVX2__
|
|
1578
|
-
#define DISPATCH(dval) \
|
|
1579
|
-
case dval: \
|
|
1580
|
-
return fvec_L2sqr_ny_y_transposed_D<dval>( \
|
|
1581
|
-
dis, x, y, y_sqlen, d_offset, ny);
|
|
1582
|
-
|
|
1583
|
-
switch (d) {
|
|
1584
|
-
DISPATCH(1)
|
|
1585
|
-
DISPATCH(2)
|
|
1586
|
-
DISPATCH(4)
|
|
1587
|
-
DISPATCH(8)
|
|
1588
|
-
default:
|
|
1589
|
-
return fvec_L2sqr_ny_y_transposed_ref(
|
|
1590
|
-
dis, x, y, y_sqlen, d, d_offset, ny);
|
|
1591
|
-
}
|
|
1592
|
-
#undef DISPATCH
|
|
1593
|
-
#else
|
|
1594
|
-
// non-AVX2 case
|
|
1595
|
-
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
1596
|
-
#endif
|
|
1597
|
-
}
|
|
1598
|
-
|
|
1599
|
-
#if defined(__AVX512F__)
|
|
1600
|
-
|
|
1601
|
-
size_t fvec_L2sqr_ny_nearest_D2(
|
|
1602
|
-
float* distances_tmp_buffer,
|
|
1603
|
-
const float* x,
|
|
1604
|
-
const float* y,
|
|
1605
|
-
size_t ny) {
|
|
1606
|
-
// this implementation does not use distances_tmp_buffer.
|
|
1607
|
-
|
|
1608
|
-
size_t i = 0;
|
|
1609
|
-
float current_min_distance = HUGE_VALF;
|
|
1610
|
-
size_t current_min_index = 0;
|
|
1611
|
-
|
|
1612
|
-
const size_t ny16 = ny / 16;
|
|
1613
|
-
if (ny16 > 0) {
|
|
1614
|
-
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
1615
|
-
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
1616
|
-
|
|
1617
|
-
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
|
1618
|
-
__m512i min_indices = _mm512_set1_epi32(0);
|
|
1619
|
-
|
|
1620
|
-
__m512i current_indices = _mm512_setr_epi32(
|
|
1621
|
-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
|
1622
|
-
const __m512i indices_increment = _mm512_set1_epi32(16);
|
|
1623
|
-
|
|
1624
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1625
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1626
|
-
|
|
1627
|
-
for (; i < ny16 * 16; i += 16) {
|
|
1628
|
-
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
|
1629
|
-
|
|
1630
|
-
__m512 v0;
|
|
1631
|
-
__m512 v1;
|
|
1632
|
-
|
|
1633
|
-
transpose_16x2(
|
|
1634
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
1635
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
1636
|
-
v0,
|
|
1637
|
-
v1);
|
|
1638
|
-
|
|
1639
|
-
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
1640
|
-
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
1641
|
-
|
|
1642
|
-
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
1643
|
-
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
1644
|
-
|
|
1645
|
-
__mmask16 comparison =
|
|
1646
|
-
_mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
|
|
1647
|
-
|
|
1648
|
-
min_distances = _mm512_min_ps(distances, min_distances);
|
|
1649
|
-
min_indices = _mm512_mask_blend_epi32(
|
|
1650
|
-
comparison, min_indices, current_indices);
|
|
1651
|
-
|
|
1652
|
-
current_indices =
|
|
1653
|
-
_mm512_add_epi32(current_indices, indices_increment);
|
|
1654
|
-
|
|
1655
|
-
y += 32;
|
|
1656
|
-
}
|
|
1657
|
-
|
|
1658
|
-
alignas(64) float min_distances_scalar[16];
|
|
1659
|
-
alignas(64) uint32_t min_indices_scalar[16];
|
|
1660
|
-
_mm512_store_ps(min_distances_scalar, min_distances);
|
|
1661
|
-
_mm512_store_epi32(min_indices_scalar, min_indices);
|
|
1662
|
-
|
|
1663
|
-
for (size_t j = 0; j < 16; j++) {
|
|
1664
|
-
if (current_min_distance > min_distances_scalar[j]) {
|
|
1665
|
-
current_min_distance = min_distances_scalar[j];
|
|
1666
|
-
current_min_index = min_indices_scalar[j];
|
|
1667
|
-
}
|
|
1668
|
-
}
|
|
1669
|
-
}
|
|
1670
|
-
|
|
1671
|
-
if (i < ny) {
|
|
1672
|
-
float x0 = x[0];
|
|
1673
|
-
float x1 = x[1];
|
|
1674
|
-
|
|
1675
|
-
for (; i < ny; i++) {
|
|
1676
|
-
float sub0 = x0 - y[0];
|
|
1677
|
-
float sub1 = x1 - y[1];
|
|
1678
|
-
float distance = sub0 * sub0 + sub1 * sub1;
|
|
1679
|
-
|
|
1680
|
-
y += 2;
|
|
1681
|
-
|
|
1682
|
-
if (current_min_distance > distance) {
|
|
1683
|
-
current_min_distance = distance;
|
|
1684
|
-
current_min_index = i;
|
|
1685
|
-
}
|
|
1686
|
-
}
|
|
1687
|
-
}
|
|
1688
|
-
|
|
1689
|
-
return current_min_index;
|
|
1690
|
-
}
|
|
1691
|
-
|
|
1692
|
-
size_t fvec_L2sqr_ny_nearest_D4(
|
|
1693
|
-
float* distances_tmp_buffer,
|
|
1694
|
-
const float* x,
|
|
1695
|
-
const float* y,
|
|
1696
|
-
size_t ny) {
|
|
1697
|
-
// this implementation does not use distances_tmp_buffer.
|
|
1698
|
-
|
|
1699
|
-
size_t i = 0;
|
|
1700
|
-
float current_min_distance = HUGE_VALF;
|
|
1701
|
-
size_t current_min_index = 0;
|
|
1702
|
-
|
|
1703
|
-
const size_t ny16 = ny / 16;
|
|
1704
|
-
|
|
1705
|
-
if (ny16 > 0) {
|
|
1706
|
-
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
|
1707
|
-
__m512i min_indices = _mm512_set1_epi32(0);
|
|
1708
|
-
|
|
1709
|
-
__m512i current_indices = _mm512_setr_epi32(
|
|
1710
|
-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
|
1711
|
-
const __m512i indices_increment = _mm512_set1_epi32(16);
|
|
1712
|
-
|
|
1713
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1714
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1715
|
-
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
1716
|
-
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
1717
|
-
|
|
1718
|
-
for (; i < ny16 * 16; i += 16) {
|
|
1719
|
-
__m512 v0;
|
|
1720
|
-
__m512 v1;
|
|
1721
|
-
__m512 v2;
|
|
1722
|
-
__m512 v3;
|
|
1723
|
-
|
|
1724
|
-
transpose_16x4(
|
|
1725
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
1726
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
1727
|
-
_mm512_loadu_ps(y + 2 * 16),
|
|
1728
|
-
_mm512_loadu_ps(y + 3 * 16),
|
|
1729
|
-
v0,
|
|
1730
|
-
v1,
|
|
1731
|
-
v2,
|
|
1732
|
-
v3);
|
|
1733
|
-
|
|
1734
|
-
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
1735
|
-
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
1736
|
-
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
|
1737
|
-
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
|
1738
|
-
|
|
1739
|
-
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
1740
|
-
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
1741
|
-
distances = _mm512_fmadd_ps(d2, d2, distances);
|
|
1742
|
-
distances = _mm512_fmadd_ps(d3, d3, distances);
|
|
1743
|
-
|
|
1744
|
-
__mmask16 comparison =
|
|
1745
|
-
_mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
|
|
1746
|
-
|
|
1747
|
-
min_distances = _mm512_min_ps(distances, min_distances);
|
|
1748
|
-
min_indices = _mm512_mask_blend_epi32(
|
|
1749
|
-
comparison, min_indices, current_indices);
|
|
1750
|
-
|
|
1751
|
-
current_indices =
|
|
1752
|
-
_mm512_add_epi32(current_indices, indices_increment);
|
|
1753
|
-
|
|
1754
|
-
y += 64;
|
|
1755
|
-
}
|
|
1756
|
-
|
|
1757
|
-
alignas(64) float min_distances_scalar[16];
|
|
1758
|
-
alignas(64) uint32_t min_indices_scalar[16];
|
|
1759
|
-
_mm512_store_ps(min_distances_scalar, min_distances);
|
|
1760
|
-
_mm512_store_epi32(min_indices_scalar, min_indices);
|
|
1761
|
-
|
|
1762
|
-
for (size_t j = 0; j < 16; j++) {
|
|
1763
|
-
if (current_min_distance > min_distances_scalar[j]) {
|
|
1764
|
-
current_min_distance = min_distances_scalar[j];
|
|
1765
|
-
current_min_index = min_indices_scalar[j];
|
|
1766
|
-
}
|
|
1767
|
-
}
|
|
1768
|
-
}
|
|
1769
|
-
|
|
1770
|
-
if (i < ny) {
|
|
1771
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
1772
|
-
|
|
1773
|
-
for (; i < ny; i++) {
|
|
1774
|
-
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
1775
|
-
y += 4;
|
|
1776
|
-
const float distance = horizontal_sum(accu);
|
|
1777
|
-
|
|
1778
|
-
if (current_min_distance > distance) {
|
|
1779
|
-
current_min_distance = distance;
|
|
1780
|
-
current_min_index = i;
|
|
1781
|
-
}
|
|
1782
|
-
}
|
|
1783
|
-
}
|
|
1784
|
-
|
|
1785
|
-
return current_min_index;
|
|
1786
|
-
}
|
|
1787
|
-
|
|
1788
|
-
size_t fvec_L2sqr_ny_nearest_D8(
|
|
1789
|
-
float* distances_tmp_buffer,
|
|
1790
|
-
const float* x,
|
|
1791
|
-
const float* y,
|
|
1792
|
-
size_t ny) {
|
|
1793
|
-
// this implementation does not use distances_tmp_buffer.
|
|
1794
|
-
|
|
1795
|
-
size_t i = 0;
|
|
1796
|
-
float current_min_distance = HUGE_VALF;
|
|
1797
|
-
size_t current_min_index = 0;
|
|
1798
|
-
|
|
1799
|
-
const size_t ny16 = ny / 16;
|
|
1800
|
-
if (ny16 > 0) {
|
|
1801
|
-
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
|
1802
|
-
__m512i min_indices = _mm512_set1_epi32(0);
|
|
1803
|
-
|
|
1804
|
-
__m512i current_indices = _mm512_setr_epi32(
|
|
1805
|
-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
|
1806
|
-
const __m512i indices_increment = _mm512_set1_epi32(16);
|
|
1807
|
-
|
|
1808
|
-
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1809
|
-
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1810
|
-
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
1811
|
-
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
1812
|
-
|
|
1813
|
-
const __m512 m4 = _mm512_set1_ps(x[4]);
|
|
1814
|
-
const __m512 m5 = _mm512_set1_ps(x[5]);
|
|
1815
|
-
const __m512 m6 = _mm512_set1_ps(x[6]);
|
|
1816
|
-
const __m512 m7 = _mm512_set1_ps(x[7]);
|
|
1817
|
-
|
|
1818
|
-
for (; i < ny16 * 16; i += 16) {
|
|
1819
|
-
__m512 v0;
|
|
1820
|
-
__m512 v1;
|
|
1821
|
-
__m512 v2;
|
|
1822
|
-
__m512 v3;
|
|
1823
|
-
__m512 v4;
|
|
1824
|
-
__m512 v5;
|
|
1825
|
-
__m512 v6;
|
|
1826
|
-
__m512 v7;
|
|
1827
|
-
|
|
1828
|
-
transpose_16x8(
|
|
1829
|
-
_mm512_loadu_ps(y + 0 * 16),
|
|
1830
|
-
_mm512_loadu_ps(y + 1 * 16),
|
|
1831
|
-
_mm512_loadu_ps(y + 2 * 16),
|
|
1832
|
-
_mm512_loadu_ps(y + 3 * 16),
|
|
1833
|
-
_mm512_loadu_ps(y + 4 * 16),
|
|
1834
|
-
_mm512_loadu_ps(y + 5 * 16),
|
|
1835
|
-
_mm512_loadu_ps(y + 6 * 16),
|
|
1836
|
-
_mm512_loadu_ps(y + 7 * 16),
|
|
1837
|
-
v0,
|
|
1838
|
-
v1,
|
|
1839
|
-
v2,
|
|
1840
|
-
v3,
|
|
1841
|
-
v4,
|
|
1842
|
-
v5,
|
|
1843
|
-
v6,
|
|
1844
|
-
v7);
|
|
1845
|
-
|
|
1846
|
-
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
1847
|
-
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
1848
|
-
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
|
1849
|
-
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
|
1850
|
-
const __m512 d4 = _mm512_sub_ps(m4, v4);
|
|
1851
|
-
const __m512 d5 = _mm512_sub_ps(m5, v5);
|
|
1852
|
-
const __m512 d6 = _mm512_sub_ps(m6, v6);
|
|
1853
|
-
const __m512 d7 = _mm512_sub_ps(m7, v7);
|
|
1854
|
-
|
|
1855
|
-
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
1856
|
-
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
1857
|
-
distances = _mm512_fmadd_ps(d2, d2, distances);
|
|
1858
|
-
distances = _mm512_fmadd_ps(d3, d3, distances);
|
|
1859
|
-
distances = _mm512_fmadd_ps(d4, d4, distances);
|
|
1860
|
-
distances = _mm512_fmadd_ps(d5, d5, distances);
|
|
1861
|
-
distances = _mm512_fmadd_ps(d6, d6, distances);
|
|
1862
|
-
distances = _mm512_fmadd_ps(d7, d7, distances);
|
|
1863
|
-
|
|
1864
|
-
__mmask16 comparison =
|
|
1865
|
-
_mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
|
|
1866
|
-
|
|
1867
|
-
min_distances = _mm512_min_ps(distances, min_distances);
|
|
1868
|
-
min_indices = _mm512_mask_blend_epi32(
|
|
1869
|
-
comparison, min_indices, current_indices);
|
|
1870
|
-
|
|
1871
|
-
current_indices =
|
|
1872
|
-
_mm512_add_epi32(current_indices, indices_increment);
|
|
1873
|
-
|
|
1874
|
-
y += 128;
|
|
1875
|
-
}
|
|
1876
|
-
|
|
1877
|
-
alignas(64) float min_distances_scalar[16];
|
|
1878
|
-
alignas(64) uint32_t min_indices_scalar[16];
|
|
1879
|
-
_mm512_store_ps(min_distances_scalar, min_distances);
|
|
1880
|
-
_mm512_store_epi32(min_indices_scalar, min_indices);
|
|
1881
|
-
|
|
1882
|
-
for (size_t j = 0; j < 16; j++) {
|
|
1883
|
-
if (current_min_distance > min_distances_scalar[j]) {
|
|
1884
|
-
current_min_distance = min_distances_scalar[j];
|
|
1885
|
-
current_min_index = min_indices_scalar[j];
|
|
1886
|
-
}
|
|
1887
|
-
}
|
|
1888
|
-
}
|
|
1889
|
-
|
|
1890
|
-
if (i < ny) {
|
|
1891
|
-
__m256 x0 = _mm256_loadu_ps(x);
|
|
1892
|
-
|
|
1893
|
-
for (; i < ny; i++) {
|
|
1894
|
-
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
|
1895
|
-
y += 8;
|
|
1896
|
-
const float distance = horizontal_sum(accu);
|
|
1897
|
-
|
|
1898
|
-
if (current_min_distance > distance) {
|
|
1899
|
-
current_min_distance = distance;
|
|
1900
|
-
current_min_index = i;
|
|
1901
|
-
}
|
|
1902
|
-
}
|
|
1903
|
-
}
|
|
1904
|
-
|
|
1905
|
-
return current_min_index;
|
|
1906
|
-
}
|
|
1907
|
-
|
|
1908
|
-
#elif defined(__AVX2__)
|
|
1909
|
-
|
|
1910
|
-
size_t fvec_L2sqr_ny_nearest_D2(
|
|
1911
|
-
float* distances_tmp_buffer,
|
|
1912
|
-
const float* x,
|
|
1913
|
-
const float* y,
|
|
1914
|
-
size_t ny) {
|
|
1915
|
-
// this implementation does not use distances_tmp_buffer.
|
|
1916
|
-
|
|
1917
|
-
// current index being processed
|
|
1918
|
-
size_t i = 0;
|
|
1919
|
-
|
|
1920
|
-
// min distance and the index of the closest vector so far
|
|
1921
|
-
float current_min_distance = HUGE_VALF;
|
|
1922
|
-
size_t current_min_index = 0;
|
|
1923
|
-
|
|
1924
|
-
// process 8 D2-vectors per loop.
|
|
1925
|
-
const size_t ny8 = ny / 8;
|
|
1926
|
-
if (ny8 > 0) {
|
|
1927
|
-
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
1928
|
-
_mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
|
|
1929
|
-
|
|
1930
|
-
// track min distance and the closest vector independently
|
|
1931
|
-
// for each of 8 AVX2 components.
|
|
1932
|
-
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
1933
|
-
__m256i min_indices = _mm256_set1_epi32(0);
|
|
1934
|
-
|
|
1935
|
-
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
1936
|
-
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
1937
|
-
|
|
1938
|
-
// 1 value per register
|
|
1939
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
1940
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
1941
|
-
|
|
1942
|
-
for (; i < ny8 * 8; i += 8) {
|
|
1943
|
-
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
1944
|
-
|
|
1945
|
-
__m256 v0;
|
|
1946
|
-
__m256 v1;
|
|
1947
|
-
|
|
1948
|
-
transpose_8x2(
|
|
1949
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
1950
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
1951
|
-
v0,
|
|
1952
|
-
v1);
|
|
1953
|
-
|
|
1954
|
-
// compute differences
|
|
1955
|
-
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
1956
|
-
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
1957
|
-
|
|
1958
|
-
// compute squares of differences
|
|
1959
|
-
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
1960
|
-
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
1961
|
-
|
|
1962
|
-
// compare the new distances to the min distances
|
|
1963
|
-
// for each of 8 AVX2 components.
|
|
1964
|
-
__m256 comparison =
|
|
1965
|
-
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
1966
|
-
|
|
1967
|
-
// update min distances and indices with closest vectors if needed.
|
|
1968
|
-
min_distances = _mm256_min_ps(distances, min_distances);
|
|
1969
|
-
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
1970
|
-
_mm256_castsi256_ps(current_indices),
|
|
1971
|
-
_mm256_castsi256_ps(min_indices),
|
|
1972
|
-
comparison));
|
|
1973
|
-
|
|
1974
|
-
// update current indices values. Basically, +8 to each of the
|
|
1975
|
-
// 8 AVX2 components.
|
|
1976
|
-
current_indices =
|
|
1977
|
-
_mm256_add_epi32(current_indices, indices_increment);
|
|
1978
|
-
|
|
1979
|
-
// scroll y forward (8 vectors 2 DIM each).
|
|
1980
|
-
y += 16;
|
|
1981
|
-
}
|
|
1982
|
-
|
|
1983
|
-
// dump values and find the minimum distance / minimum index
|
|
1984
|
-
float min_distances_scalar[8];
|
|
1985
|
-
uint32_t min_indices_scalar[8];
|
|
1986
|
-
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
1987
|
-
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
1988
|
-
|
|
1989
|
-
for (size_t j = 0; j < 8; j++) {
|
|
1990
|
-
if (current_min_distance > min_distances_scalar[j]) {
|
|
1991
|
-
current_min_distance = min_distances_scalar[j];
|
|
1992
|
-
current_min_index = min_indices_scalar[j];
|
|
1993
|
-
}
|
|
1994
|
-
}
|
|
1995
|
-
}
|
|
1996
|
-
|
|
1997
|
-
if (i < ny) {
|
|
1998
|
-
// process leftovers.
|
|
1999
|
-
// the following code is not optimal, but it is rarely invoked.
|
|
2000
|
-
float x0 = x[0];
|
|
2001
|
-
float x1 = x[1];
|
|
2002
|
-
|
|
2003
|
-
for (; i < ny; i++) {
|
|
2004
|
-
float sub0 = x0 - y[0];
|
|
2005
|
-
float sub1 = x1 - y[1];
|
|
2006
|
-
float distance = sub0 * sub0 + sub1 * sub1;
|
|
2007
|
-
|
|
2008
|
-
y += 2;
|
|
2009
|
-
|
|
2010
|
-
if (current_min_distance > distance) {
|
|
2011
|
-
current_min_distance = distance;
|
|
2012
|
-
current_min_index = i;
|
|
2013
|
-
}
|
|
2014
|
-
}
|
|
2015
|
-
}
|
|
2016
|
-
|
|
2017
|
-
return current_min_index;
|
|
2018
|
-
}
|
|
2019
|
-
|
|
2020
|
-
size_t fvec_L2sqr_ny_nearest_D4(
|
|
2021
|
-
float* distances_tmp_buffer,
|
|
2022
|
-
const float* x,
|
|
2023
|
-
const float* y,
|
|
2024
|
-
size_t ny) {
|
|
2025
|
-
// this implementation does not use distances_tmp_buffer.
|
|
2026
|
-
|
|
2027
|
-
// current index being processed
|
|
2028
|
-
size_t i = 0;
|
|
2029
|
-
|
|
2030
|
-
// min distance and the index of the closest vector so far
|
|
2031
|
-
float current_min_distance = HUGE_VALF;
|
|
2032
|
-
size_t current_min_index = 0;
|
|
2033
|
-
|
|
2034
|
-
// process 8 D4-vectors per loop.
|
|
2035
|
-
const size_t ny8 = ny / 8;
|
|
2036
|
-
|
|
2037
|
-
if (ny8 > 0) {
|
|
2038
|
-
// track min distance and the closest vector independently
|
|
2039
|
-
// for each of 8 AVX2 components.
|
|
2040
|
-
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
2041
|
-
__m256i min_indices = _mm256_set1_epi32(0);
|
|
2042
|
-
|
|
2043
|
-
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
2044
|
-
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
2045
|
-
|
|
2046
|
-
// 1 value per register
|
|
2047
|
-
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
2048
|
-
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
2049
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
2050
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
2051
|
-
|
|
2052
|
-
for (; i < ny8 * 8; i += 8) {
|
|
2053
|
-
__m256 v0;
|
|
2054
|
-
__m256 v1;
|
|
2055
|
-
__m256 v2;
|
|
2056
|
-
__m256 v3;
|
|
2057
|
-
|
|
2058
|
-
transpose_8x4(
|
|
2059
|
-
_mm256_loadu_ps(y + 0 * 8),
|
|
2060
|
-
_mm256_loadu_ps(y + 1 * 8),
|
|
2061
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
2062
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
2063
|
-
v0,
|
|
2064
|
-
v1,
|
|
2065
|
-
v2,
|
|
2066
|
-
v3);
|
|
2067
|
-
|
|
2068
|
-
// compute differences
|
|
2069
|
-
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
2070
|
-
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
2071
|
-
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
2072
|
-
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
2073
|
-
|
|
2074
|
-
// compute squares of differences
|
|
2075
|
-
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
2076
|
-
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
2077
|
-
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
2078
|
-
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
2079
|
-
|
|
2080
|
-
// compare the new distances to the min distances
|
|
2081
|
-
// for each of 8 AVX2 components.
|
|
2082
|
-
__m256 comparison =
|
|
2083
|
-
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
2084
|
-
|
|
2085
|
-
// update min distances and indices with closest vectors if needed.
|
|
2086
|
-
min_distances = _mm256_min_ps(distances, min_distances);
|
|
2087
|
-
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
2088
|
-
_mm256_castsi256_ps(current_indices),
|
|
2089
|
-
_mm256_castsi256_ps(min_indices),
|
|
2090
|
-
comparison));
|
|
2091
|
-
|
|
2092
|
-
// update current indices values. Basically, +8 to each of the
|
|
2093
|
-
// 8 AVX2 components.
|
|
2094
|
-
current_indices =
|
|
2095
|
-
_mm256_add_epi32(current_indices, indices_increment);
|
|
2096
|
-
|
|
2097
|
-
// scroll y forward (8 vectors 4 DIM each).
|
|
2098
|
-
y += 32;
|
|
2099
|
-
}
|
|
2100
|
-
|
|
2101
|
-
// dump values and find the minimum distance / minimum index
|
|
2102
|
-
float min_distances_scalar[8];
|
|
2103
|
-
uint32_t min_indices_scalar[8];
|
|
2104
|
-
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
2105
|
-
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
2106
|
-
|
|
2107
|
-
for (size_t j = 0; j < 8; j++) {
|
|
2108
|
-
if (current_min_distance > min_distances_scalar[j]) {
|
|
2109
|
-
current_min_distance = min_distances_scalar[j];
|
|
2110
|
-
current_min_index = min_indices_scalar[j];
|
|
2111
|
-
}
|
|
2112
|
-
}
|
|
2113
|
-
}
|
|
2114
|
-
|
|
2115
|
-
if (i < ny) {
|
|
2116
|
-
// process leftovers
|
|
2117
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
2118
|
-
|
|
2119
|
-
for (; i < ny; i++) {
|
|
2120
|
-
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
2121
|
-
y += 4;
|
|
2122
|
-
const float distance = horizontal_sum(accu);
|
|
2123
|
-
|
|
2124
|
-
if (current_min_distance > distance) {
|
|
2125
|
-
current_min_distance = distance;
|
|
2126
|
-
current_min_index = i;
|
|
2127
|
-
}
|
|
2128
|
-
}
|
|
2129
|
-
}
|
|
2130
|
-
|
|
2131
|
-
return current_min_index;
|
|
2132
|
-
}
|
|
2133
|
-
|
|
2134
|
-
// AVX2 kernel: returns the index in [0, ny) of the 8-dimensional vector in
// `y` that is closest to the 8-dimensional query `x` in squared L2 distance.
// Returns 0 when ny == 0 (neither loop runs).
//
// Layout: `y` holds ny vectors stored contiguously, 8 floats each (the main
// loop advances y by 64 floats per 8 vectors, the tail by 8 per vector);
// `x` must hold 8 readable floats (the tail does an unaligned 8-float load).
// `distances_tmp_buffer` is accepted only for signature compatibility with
// the other nearest-search kernels and is not used.
//
// NOTE(review): tie-breaking is not strictly "first index wins". Inside the
// SIMD loop a lane keeps its previous index only when the previous minimum
// is strictly smaller (_CMP_LT_OS), so on an exact tie the later index of
// that lane is recorded. The cross-lane reduction and the scalar tail use a
// strict '>' and keep the candidate seen first.
// NOTE(review): candidate indices live in 32-bit integer lanes
// (_mm256_*_epi32 / uint32_t), so this assumes ny fits in 32 bits.
size_t fvec_L2sqr_ny_nearest_D8(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 D8-vectors per loop.
    const size_t ny8 = ny / 8;
    if (ny8 > 0) {
        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        // lane k starts at candidate index k; every iteration adds 8.
        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // 1 value per register: m_k broadcasts query component x[k].
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);

        const __m256 m4 = _mm256_set1_ps(x[4]);
        const __m256 m5 = _mm256_set1_ps(x[5]);
        const __m256 m6 = _mm256_set1_ps(x[6]);
        const __m256 m7 = _mm256_set1_ps(x[7]);

        for (; i < ny8 * 8; i += 8) {
            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;
            __m256 v4;
            __m256 v5;
            __m256 v6;
            __m256 v7;

            // Load the next 8 vectors (one per row) and transpose them so
            // that, as the m_k - v_k pairing below requires, v_k is expected
            // to hold component k of those 8 vectors, one vector per lane.
            transpose_8x8(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    _mm256_loadu_ps(y + 4 * 8),
                    _mm256_loadu_ps(y + 5 * 8),
                    _mm256_loadu_ps(y + 6 * 8),
                    _mm256_loadu_ps(y + 7 * 8),
                    v0,
                    v1,
                    v2,
                    v3,
                    v4,
                    v5,
                    v6,
                    v7);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);
            const __m256 d2 = _mm256_sub_ps(m2, v2);
            const __m256 d3 = _mm256_sub_ps(m3, v3);
            const __m256 d4 = _mm256_sub_ps(m4, v4);
            const __m256 d5 = _mm256_sub_ps(m5, v5);
            const __m256 d6 = _mm256_sub_ps(m6, v6);
            const __m256 d7 = _mm256_sub_ps(m7, v7);

            // compute squares of differences (sum over the 8 components,
            // accumulated with fused multiply-add)
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            distances = _mm256_fmadd_ps(d2, d2, distances);
            distances = _mm256_fmadd_ps(d3, d3, distances);
            distances = _mm256_fmadd_ps(d4, d4, distances);
            distances = _mm256_fmadd_ps(d5, d5, distances);
            distances = _mm256_fmadd_ps(d6, d6, distances);
            distances = _mm256_fmadd_ps(d7, d7, distances);

            // compare the new distances to the min distances
            // for each of 8 AVX2 components.
            // comparison lane is all-ones where the stored min is strictly
            // smaller than the new distance (i.e. keep the old candidate).
            __m256 comparison =
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed.
            // blendv picks its second operand (the old min index) where the
            // comparison mask is set, otherwise the current index.
            min_distances = _mm256_min_ps(distances, min_distances);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(current_indices),
                    _mm256_castsi256_ps(min_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y forward (8 vectors 8 DIM each).
            y += 64;
        }

        // dump values and find the minimum distance / minimum index
        // across the 8 independent per-lane minima.
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // process leftovers (ny % 8 vectors), one vector per iteration.
        __m256 x0 = _mm256_loadu_ps(x);

        for (; i < ny; i++) {
            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
            y += 8;
            const float distance = horizontal_sum(accu);

            // strict '>' keeps the earlier index on an exact tie.
            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }
        }
    }

    return current_min_index;
}
|
|
2271
|
-
|
|
2272
|
-
#else
|
|
2273
|
-
size_t fvec_L2sqr_ny_nearest_D2(
|
|
2274
|
-
float* distances_tmp_buffer,
|
|
2275
|
-
const float* x,
|
|
2276
|
-
const float* y,
|
|
2277
|
-
size_t ny) {
|
|
2278
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny);
|
|
2279
|
-
}
|
|
2280
|
-
|
|
2281
|
-
size_t fvec_L2sqr_ny_nearest_D4(
|
|
2282
|
-
float* distances_tmp_buffer,
|
|
2283
|
-
const float* x,
|
|
2284
|
-
const float* y,
|
|
2285
|
-
size_t ny) {
|
|
2286
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
|
|
2287
|
-
}
|
|
2288
|
-
|
|
2289
|
-
size_t fvec_L2sqr_ny_nearest_D8(
|
|
2290
|
-
float* distances_tmp_buffer,
|
|
2291
|
-
const float* x,
|
|
2292
|
-
const float* y,
|
|
2293
|
-
size_t ny) {
|
|
2294
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny);
|
|
2295
|
-
}
|
|
2296
|
-
#endif
|
|
2297
|
-
|
|
2298
|
-
size_t fvec_L2sqr_ny_nearest(
|
|
2299
|
-
float* distances_tmp_buffer,
|
|
2300
|
-
const float* x,
|
|
2301
|
-
const float* y,
|
|
2302
|
-
size_t d,
|
|
2303
|
-
size_t ny) {
|
|
2304
|
-
// optimized for a few special cases
|
|
2305
|
-
#define DISPATCH(dval) \
|
|
2306
|
-
case dval: \
|
|
2307
|
-
return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
|
|
2308
|
-
|
|
2309
|
-
switch (d) {
|
|
2310
|
-
DISPATCH(2)
|
|
2311
|
-
DISPATCH(4)
|
|
2312
|
-
DISPATCH(8)
|
|
2313
|
-
default:
|
|
2314
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
2315
|
-
}
|
|
2316
|
-
#undef DISPATCH
|
|
2317
|
-
}
|
|
2318
|
-
|
|
2319
|
-
#if defined(__AVX512F__)
|
|
2320
|
-
|
|
2321
|
-
// AVX-512 kernel: returns the index in [0, ny) of the candidate nearest to
// the DIM-dimensional query `x`, where candidates are stored transposed.
// Returns 0 when ny == 0.
//
// Layout: component j of candidate k is read at y[j * d_offset + k], so each
// of the DIM rows must hold at least ny floats at stride `d_offset`.
// `y_sqlen[k]` is subtracted-from per candidate. The original comments call it
// y^2 (presumably ||y_k||^2; confirm with callers).
// The score minimized is y_sqlen[k] - 2 * <x, y_k>. This omits the constant
// ||x||^2, so it ranks candidates like the true squared L2 distance but is
// not itself a distance.
// `distances_tmp_buffer` is not used.
//
// NOTE(review): the SIMD path multiplies by a pre-doubled x (m[j] = 2*x[j]),
// while the scalar tail doubles the finished dot product. Scores may
// therefore differ in the last bits between the two paths.
// NOTE(review): on an exact tie within a lane the later index is recorded
// (_CMP_LT_OS keeps the old one only when strictly smaller). Indices are
// 32-bit lanes, so this assumes ny fits in 32 bits.
template <size_t DIM>
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // This implementation does not use distances_tmp_buffer.

    // Current index being processed
    size_t i = 0;

    // Min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // Process 16 vectors per loop
    const size_t ny16 = ny / 16;

    if (ny16 > 0) {
        // Track min distance and the closest vector independently
        // for each of 16 AVX-512 components.
        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
        __m512i min_indices = _mm512_set1_epi32(0);

        __m512i current_indices = _mm512_setr_epi32(
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        const __m512i indices_increment = _mm512_set1_epi32(16);

        // m[i] = (2 * x[i], ... 2 * x[i])
        __m512 m[DIM];
        for (size_t j = 0; j < DIM; j++) {
            m[j] = _mm512_set1_ps(x[j]);
            m[j] = _mm512_add_ps(m[j], m[j]);
        }

        for (; i < ny16 * 16; i += 16) {
            // Compute dot products (already scaled by 2 via m[]) for the
            // next 16 candidates; row j holds component j of each of them.
            const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
            __m512 dp = _mm512_mul_ps(m[0], v0);
            for (size_t j = 1; j < DIM; j++) {
                const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
                dp = _mm512_fmadd_ps(m[j], vj, dp);
            }

            // Compute y^2 - (2 * x, y), which is sufficient for looking for the
            // lowest distance.
            // x^2 is the constant that can be avoided.
            const __m512 distances =
                    _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);

            // Compare the new distances to the min distances
            // (bit set where the stored min is strictly smaller).
            __mmask16 comparison =
                    _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);

            // Update min distances and indices with closest vectors if needed.
            // mask_blend returns its last operand where the mask bit is set,
            // i.e. keeps the stored min/index there.
            min_distances =
                    _mm512_mask_blend_ps(comparison, distances, min_distances);
            min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
                    comparison,
                    _mm512_castsi512_ps(current_indices),
                    _mm512_castsi512_ps(min_indices)));

            // Update current indices values. Basically, +16 to each of the 16
            // AVX-512 components.
            current_indices =
                    _mm512_add_epi32(current_indices, indices_increment);

            // Scroll y and y_sqlen forward.
            y += 16;
            y_sqlen += 16;
        }

        // Dump values and find the minimum distance / minimum index
        float min_distances_scalar[16];
        uint32_t min_indices_scalar[16];
        _mm512_storeu_ps(min_distances_scalar, min_distances);
        _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 16; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // Process leftovers (ny % 16 candidates), one per iteration.
        for (; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < DIM; j++) {
                dp += x[j] * y[j * d_offset];
            }

            // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
            // lowest distance.
            const float distance = y_sqlen[0] - 2 * dp;

            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }

            // Advance to the next column (next candidate) of the transposed
            // layout.
            y += 1;
            y_sqlen += 1;
        }
    }

    return current_min_index;
}
|
|
2433
|
-
|
|
2434
|
-
#elif defined(__AVX2__)
|
|
2435
|
-
|
|
2436
|
-
// AVX2 kernel: returns the index in [0, ny) of the candidate nearest to the
// DIM-dimensional query `x`, where candidates are stored transposed.
// Returns 0 when ny == 0.
//
// Layout: component j of candidate k is read at y[j * d_offset + k], so each
// of the DIM rows must hold at least ny floats at stride `d_offset`.
// `y_sqlen[k]` is subtracted-from per candidate. The original comments call it
// y^2 (presumably ||y_k||^2; confirm with callers).
// The score minimized is y_sqlen[k] - 2 * <x, y_k>. This omits the constant
// ||x||^2, so it ranks candidates like the true squared L2 distance.
// `distances_tmp_buffer` is not used.
//
// NOTE(review): the SIMD path multiplies by a pre-doubled x (m[j] = 2*x[j]),
// while the scalar tail doubles the finished dot product. Scores may differ
// in the last bits between the two paths.
// On an exact tie within a lane the later index is recorded. Indices are
// 32-bit lanes, so this assumes ny fits in 32 bits.
template <size_t DIM>
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 vectors per loop.
    const size_t ny8 = ny / 8;

    if (ny8 > 0) {
        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // m[i] = (2 * x[i], ... 2 * x[i])
        __m256 m[DIM];
        for (size_t j = 0; j < DIM; j++) {
            m[j] = _mm256_set1_ps(x[j]);
            m[j] = _mm256_add_ps(m[j], m[j]);
        }

        for (; i < ny8 * 8; i += 8) {
            // collect dim 0 of the next 8 candidates (row 0 of the
            // transposed layout).
            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
            // compute dot products
            __m256 dp = _mm256_mul_ps(m[0], v0);

            for (size_t j = 1; j < DIM; j++) {
                // collect dim j of the same 8 candidates (row j).
                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
                dp = _mm256_fmadd_ps(m[j], vj, dp);
            }

            // compute y^2 - (2 * x, y), which is sufficient for looking for the
            // lowest distance.
            // x^2 is the constant that can be avoided.
            const __m256 distances =
                    _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);

            // compare the new distances to the min distances
            // for each of 8 AVX2 components.
            // (lane all-ones where the stored min is strictly smaller)
            const __m256 comparison =
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed.
            // blendv returns its second operand where the mask lane is set,
            // i.e. keeps the stored min/index there.
            min_distances =
                    _mm256_blendv_ps(distances, min_distances, comparison);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(current_indices),
                    _mm256_castsi256_ps(min_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y and y_sqlen forward.
            y += 8;
            y_sqlen += 8;
        }

        // dump values and find the minimum distance / minimum index
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // process leftovers (ny % 8 candidates), one per iteration.
        for (; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < DIM; j++) {
                dp += x[j] * y[j * d_offset];
            }

            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
            // lowest distance.
            const float distance = y_sqlen[0] - 2 * dp;

            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }

            // advance to the next column (next candidate).
            y += 1;
            y_sqlen += 1;
        }
    }

    return current_min_index;
}
|
|
2551
|
-
|
|
2552
|
-
#endif
|
|
2553
|
-
|
|
2554
|
-
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
2555
|
-
float* distances_tmp_buffer,
|
|
2556
|
-
const float* x,
|
|
2557
|
-
const float* y,
|
|
2558
|
-
const float* y_sqlen,
|
|
2559
|
-
size_t d,
|
|
2560
|
-
size_t d_offset,
|
|
2561
|
-
size_t ny) {
|
|
2562
|
-
// optimized for a few special cases
|
|
2563
|
-
#ifdef __AVX2__
|
|
2564
|
-
#define DISPATCH(dval) \
|
|
2565
|
-
case dval: \
|
|
2566
|
-
return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
|
|
2567
|
-
distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
|
|
2568
|
-
|
|
2569
|
-
switch (d) {
|
|
2570
|
-
DISPATCH(1)
|
|
2571
|
-
DISPATCH(2)
|
|
2572
|
-
DISPATCH(4)
|
|
2573
|
-
DISPATCH(8)
|
|
2574
|
-
default:
|
|
2575
|
-
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
2576
|
-
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
2577
|
-
}
|
|
2578
|
-
#undef DISPATCH
|
|
2579
|
-
#else
|
|
2580
|
-
// non-AVX2 case
|
|
2581
|
-
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
2582
|
-
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
2583
|
-
#endif
|
|
2584
|
-
}
|
|
2585
|
-
|
|
2586
|
-
#endif
|
|
2587
|
-
|
|
2588
|
-
#ifdef USE_AVX
|
|
2589
|
-
|
|
2590
|
-
// AVX version of the L1 (Manhattan) distance: sum over i < d of |x[i] - y[i]|.
// Processing order: 8 floats at a time, then one 4-float step, then up to 3
// leftovers through masked_read.
// NOTE(review): masked_read presumably zero-fills the unread lanes. If it
// does, those lanes contribute |0 - 0| = 0. Confirm against its definition.
// The float accumulation order is part of the result, so it must not be
// reordered.
float fvec_L1(const float* x, const float* y, size_t d) {
    __m256 msum1 = _mm256_setzero_ps();
    // signmask used for absolute value: 0x7fffffff clears the IEEE-754 sign
    // bit, so (signmask & v) == |v|.
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps(x);
        x += 8;
        __m256 my = _mm256_loadu_ps(y);
        y += 8;
        // subtract
        const __m256 a_m_b = _mm256_sub_ps(mx, my);
        // find sum of absolute value of distances (manhattan distance)
        msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
        d -= 8;
    }

    // fold the 8-lane accumulator into 4 lanes (high half + low half).
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
    __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps(x);
        x += 4;
        __m128 my = _mm_loadu_ps(y);
        y += 4;
        const __m128 a_m_b = _mm_sub_ps(mx, my);
        msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
        d -= 4;
    }

    // remaining 1..3 elements.
    if (d > 0) {
        __m128 mx = masked_read(d, x);
        __m128 my = masked_read(d, y);
        __m128 a_m_b = _mm_sub_ps(mx, my);
        msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
    }

    // two horizontal adds leave the total of all 4 lanes in lane 0.
    msum2 = _mm_hadd_ps(msum2, msum2);
    msum2 = _mm_hadd_ps(msum2, msum2);
    return _mm_cvtss_f32(msum2);
}
|
|
2632
|
-
|
|
2633
|
-
// AVX version of the L-infinity (Chebyshev) distance: max over i < d of
// |x[i] - y[i]|.
// Returns 0 for d == 0 because the accumulator starts at zero.
// Processing order: 8 floats at a time, then one 4-float step, then up to 3
// leftovers through masked_read.
// NOTE(review): masked_read presumably zero-fills the unread lanes, which then
// contribute 0 to the max. Confirm against its definition.
// max_ps returns its second operand when either input is NaN, so the operand
// order below affects NaN propagation and must be kept.
float fvec_Linf(const float* x, const float* y, size_t d) {
    __m256 msum1 = _mm256_setzero_ps();
    // signmask used for absolute value: 0x7fffffff clears the sign bit.
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps(x);
        x += 8;
        __m256 my = _mm256_loadu_ps(y);
        y += 8;
        // subtract
        const __m256 a_m_b = _mm256_sub_ps(mx, my);
        // find max of absolute value of distances (chebyshev distance)
        msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
        d -= 8;
    }

    // fold the 8-lane running max into 4 lanes.
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
    __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps(x);
        x += 4;
        __m128 my = _mm_loadu_ps(y);
        y += 4;
        const __m128 a_m_b = _mm_sub_ps(mx, my);
        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
        d -= 4;
    }

    // remaining 1..3 elements.
    if (d > 0) {
        __m128 mx = masked_read(d, x);
        __m128 my = masked_read(d, y);
        __m128 a_m_b = _mm_sub_ps(mx, my);
        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
    }

    // horizontal max:
    // - movehl brings lanes 2,3 down, so lanes 0/1 hold max(0,2) and max(1,3);
    // - the shuffle with imm 1 moves lane 1 into lane 0, so lane 0 ends up
    //   holding the max of all four lanes.
    msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
    msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
    return _mm_cvtss_f32(msum2);
}
|
|
2675
|
-
|
|
2676
|
-
#elif defined(__SSE3__) // But not AVX
|
|
2677
|
-
|
|
2678
|
-
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
2679
|
-
return fvec_L1_ref(x, y, d);
|
|
2680
|
-
}
|
|
2681
|
-
|
|
2682
|
-
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
2683
|
-
return fvec_Linf_ref(x, y, d);
|
|
2684
|
-
}
|
|
2685
|
-
|
|
2686
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
2687
|
-
|
|
2688
|
-
struct ElementOpIP {
|
|
2689
|
-
static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) {
|
|
2690
|
-
return svmul_f32_x(pg, x, y);
|
|
2691
|
-
}
|
|
2692
|
-
static svfloat32_t merge(
|
|
2693
|
-
svbool_t pg,
|
|
2694
|
-
svfloat32_t z,
|
|
2695
|
-
svfloat32_t x,
|
|
2696
|
-
svfloat32_t y) {
|
|
2697
|
-
return svmla_f32_x(pg, z, x, y);
|
|
2698
|
-
}
|
|
2699
|
-
};
|
|
2700
|
-
|
|
2701
|
-
// SVE kernel for 1-dimensional data: dis[k] = ElementOp::op(x[0], y[k]) for
// k in [0, ny). With ElementOpIP that is x[0] * y[k].
// The main loop handles 4 full vector registers per iteration. It stops while
// at least one element is still left (strict '<'), so the predicated tail
// below always handles the final 1..4*lanes elements. For ny == 0 every tail
// predicate is all-false and nothing is read or written.
template <typename ElementOp>
void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) {
    // lanes = number of 32-bit elements per SVE vector (hardware-dependent).
    const size_t lanes = svcntw();
    const size_t lanes2 = lanes * 2;
    const size_t lanes3 = lanes * 3;
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    size_t i = 0;
    for (; i + lanes4 < ny; i += lanes4) {
        svfloat32_t y0 = svld1_f32(pg, y);
        svfloat32_t y1 = svld1_f32(pg, y + lanes);
        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
        y0 = ElementOp::op(pg, x0, y0);
        y1 = ElementOp::op(pg, x0, y1);
        y2 = ElementOp::op(pg, x0, y2);
        y3 = ElementOp::op(pg, x0, y3);
        svst1_f32(pg, dis, y0);
        svst1_f32(pg, dis + lanes, y1);
        svst1_f32(pg, dis + lanes2, y2);
        svst1_f32(pg, dis + lanes3, y3);
        y += lanes4;
        dis += lanes4;
    }
    // Tail: pgK activates only the lanes whose element index is < ny, so the
    // predicated loads/stores never touch memory past the last element.
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
    const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny);
    const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny);
    svfloat32_t y0 = svld1_f32(pg0, y);
    svfloat32_t y1 = svld1_f32(pg1, y + lanes);
    svfloat32_t y2 = svld1_f32(pg2, y + lanes2);
    svfloat32_t y3 = svld1_f32(pg3, y + lanes3);
    y0 = ElementOp::op(pg0, x0, y0);
    y1 = ElementOp::op(pg1, x0, y1);
    y2 = ElementOp::op(pg2, x0, y2);
    y3 = ElementOp::op(pg3, x0, y3);
    svst1_f32(pg0, dis, y0);
    svst1_f32(pg1, dis + lanes, y1);
    svst1_f32(pg2, dis + lanes2, y2);
    svst1_f32(pg3, dis + lanes3, y3);
}
|
|
2743
|
-
|
|
2744
|
-
// SVE kernel for 2-dimensional data: for each of the ny vectors y_k stored
// contiguously (2 floats each), computes
//   dis[k] = merge(op(x[0], y_k[0]), x[1], y_k[1]).
// With ElementOpIP this is the dot product <x, y_k>.
// svld2 de-interleaves pairs: component 0 of the tuple gets the even floats
// (dimension 0), component 1 the odd floats (dimension 1).
// The main loop handles 2*lanes vectors per iteration; the predicated tail
// handles the final 1..2*lanes (nothing when ny == 0).
template <typename ElementOp>
void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes2 = lanes * 2;
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    const svfloat32_t x1 = svdup_n_f32(x[1]);
    size_t i = 0;
    for (; i + lanes2 < ny; i += lanes2) {
        // y0 covers vectors [i, i + lanes), y1 covers [i + lanes, i + 2*lanes).
        const svfloat32x2_t y0 = svld2_f32(pg, y);
        const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2);
        svfloat32_t y00 = svget2_f32(y0, 0);
        const svfloat32_t y01 = svget2_f32(y0, 1);
        svfloat32_t y10 = svget2_f32(y1, 0);
        const svfloat32_t y11 = svget2_f32(y1, 1);
        y00 = ElementOp::op(pg, x0, y00);
        y10 = ElementOp::op(pg, x0, y10);
        y00 = ElementOp::merge(pg, y00, x1, y01);
        y10 = ElementOp::merge(pg, y10, x1, y11);
        svst1_f32(pg, dis, y00);
        svst1_f32(pg, dis + lanes, y10);
        // 2*lanes vectors of 2 floats consumed; 2*lanes results written.
        y += lanes4;
        dis += lanes2;
    }
    // Tail: the predicates count vectors (one structure per lane), so only
    // vectors with index < ny are loaded and stored.
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
    const svfloat32x2_t y0 = svld2_f32(pg0, y);
    const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2);
    svfloat32_t y00 = svget2_f32(y0, 0);
    const svfloat32_t y01 = svget2_f32(y0, 1);
    svfloat32_t y10 = svget2_f32(y1, 0);
    const svfloat32_t y11 = svget2_f32(y1, 1);
    y00 = ElementOp::op(pg0, x0, y00);
    y10 = ElementOp::op(pg1, x0, y10);
    y00 = ElementOp::merge(pg0, y00, x1, y01);
    y10 = ElementOp::merge(pg1, y10, x1, y11);
    svst1_f32(pg0, dis, y00);
    svst1_f32(pg1, dis + lanes, y10);
}
|
|
2784
|
-
|
|
2785
|
-
// SVE kernel for 4-dimensional data: for each of the ny vectors y_k stored
// contiguously (4 floats each), computes
//   dis[k] = merge(op(x0, y_k[0]), x1, y_k[1]) + merge(op(x2, y_k[2]), x3, y_k[3]).
// With ElementOpIP this is the dot product <x, y_k>, summed as two partial
// sums.
// svld4 de-interleaves, so tuple component c holds dimension c of `lanes`
// consecutive vectors.
// The main loop handles `lanes` vectors per iteration; the predicated tail
// handles the final 1..lanes (nothing when ny == 0).
template <typename ElementOp>
void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    const svfloat32_t x1 = svdup_n_f32(x[1]);
    const svfloat32_t x2 = svdup_n_f32(x[2]);
    const svfloat32_t x3 = svdup_n_f32(x[3]);
    size_t i = 0;
    for (; i + lanes < ny; i += lanes) {
        const svfloat32x4_t y0 = svld4_f32(pg, y);
        svfloat32_t y00 = svget4_f32(y0, 0);
        const svfloat32_t y01 = svget4_f32(y0, 1);
        svfloat32_t y02 = svget4_f32(y0, 2);
        const svfloat32_t y03 = svget4_f32(y0, 3);
        // two independent partial sums (dims 0+1 and dims 2+3), then added.
        y00 = ElementOp::op(pg, x0, y00);
        y02 = ElementOp::op(pg, x2, y02);
        y00 = ElementOp::merge(pg, y00, x1, y01);
        y02 = ElementOp::merge(pg, y02, x3, y03);
        y00 = svadd_f32_x(pg, y00, y02);
        svst1_f32(pg, dis, y00);
        y += lanes4;
        dis += lanes;
    }
    // Tail: one structure (vector) per lane; lanes with index >= ny are
    // inactive for both the load and the store.
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svfloat32x4_t y0 = svld4_f32(pg0, y);
    svfloat32_t y00 = svget4_f32(y0, 0);
    const svfloat32_t y01 = svget4_f32(y0, 1);
    svfloat32_t y02 = svget4_f32(y0, 2);
    const svfloat32_t y03 = svget4_f32(y0, 3);
    y00 = ElementOp::op(pg0, x0, y00);
    y02 = ElementOp::op(pg0, x2, y02);
    y00 = ElementOp::merge(pg0, y00, x1, y01);
    y02 = ElementOp::merge(pg0, y02, x3, y03);
    y00 = svadd_f32_x(pg0, y00, y02);
    svst1_f32(pg0, dis, y00);
}
|
|
2823
|
-
|
|
2824
|
-
// SVE kernel for 8-dimensional data: for each of the ny vectors y_k stored
// contiguously (8 floats each), dis[k] combines x[j] and y_k[j] for j < 8
// via ElementOp::op / ElementOp::merge. The result is four pairwise partial
// sums added together; with ElementOpIP this is the dot product <x, y_k>.
//
// De-interleave trick:
// - svld4 loads `lanes` groups of 4 floats, with tuple component c holding
//   float 4g + c of group g.
// - An 8-dim vector spans two consecutive groups. Even groups therefore carry
//   dims 0..3 and odd groups carry dims 4..7.
// - svuzp1 / svuzp2 concatenate the even / odd elements of (ya_c, yb_c).
//   Together this yields dimension c (resp. c + 4) of `lanes` consecutive
//   vectors.
// The main loop handles `lanes` vectors per iteration; the predicated tail
// handles the final 1..lanes (nothing when ny == 0).
template <typename ElementOp>
void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes4 = lanes * 4;
    const size_t lanes8 = lanes * 8;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    const svfloat32_t x1 = svdup_n_f32(x[1]);
    const svfloat32_t x2 = svdup_n_f32(x[2]);
    const svfloat32_t x3 = svdup_n_f32(x[3]);
    const svfloat32_t x4 = svdup_n_f32(x[4]);
    const svfloat32_t x5 = svdup_n_f32(x[5]);
    const svfloat32_t x6 = svdup_n_f32(x[6]);
    const svfloat32_t x7 = svdup_n_f32(x[7]);
    size_t i = 0;
    for (; i + lanes < ny; i += lanes) {
        // ya: vectors [i, i + lanes/2); yb: vectors [i + lanes/2, i + lanes).
        const svfloat32x4_t ya = svld4_f32(pg, y);
        const svfloat32x4_t yb = svld4_f32(pg, y + lanes4);
        const svfloat32_t ya0 = svget4_f32(ya, 0);
        const svfloat32_t ya1 = svget4_f32(ya, 1);
        const svfloat32_t ya2 = svget4_f32(ya, 2);
        const svfloat32_t ya3 = svget4_f32(ya, 3);
        const svfloat32_t yb0 = svget4_f32(yb, 0);
        const svfloat32_t yb1 = svget4_f32(yb, 1);
        const svfloat32_t yb2 = svget4_f32(yb, 2);
        const svfloat32_t yb3 = svget4_f32(yb, 3);
        // yj now holds dimension j of the `lanes` vectors, one per lane.
        svfloat32_t y0 = svuzp1(ya0, yb0);
        const svfloat32_t y1 = svuzp1(ya1, yb1);
        svfloat32_t y2 = svuzp1(ya2, yb2);
        const svfloat32_t y3 = svuzp1(ya3, yb3);
        svfloat32_t y4 = svuzp2(ya0, yb0);
        const svfloat32_t y5 = svuzp2(ya1, yb1);
        svfloat32_t y6 = svuzp2(ya2, yb2);
        const svfloat32_t y7 = svuzp2(ya3, yb3);
        y0 = ElementOp::op(pg, x0, y0);
        y2 = ElementOp::op(pg, x2, y2);
        y4 = ElementOp::op(pg, x4, y4);
        y6 = ElementOp::op(pg, x6, y6);
        y0 = ElementOp::merge(pg, y0, x1, y1);
        y2 = ElementOp::merge(pg, y2, x3, y3);
        y4 = ElementOp::merge(pg, y4, x5, y5);
        y6 = ElementOp::merge(pg, y6, x7, y7);
        y0 = svadd_f32_x(pg, y0, y2);
        y4 = svadd_f32_x(pg, y4, y6);
        y0 = svadd_f32_x(pg, y0, y4);
        svst1_f32(pg, dis, y0);
        y += lanes8;
        dis += lanes;
    }
    // Tail predicates:
    // - pg0 counts result vectors: lane k is active iff i + k < ny.
    // - pga / pgb count 4-float groups (two per vector). In ya, group g
    //   belongs to vector i + g/2, so it is valid iff 2*i + g < 2*ny.
    // - pgb is the same test shifted by `lanes` groups for yb.
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2);
    const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2);
    const svfloat32x4_t ya = svld4_f32(pga, y);
    const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4);
    const svfloat32_t ya0 = svget4_f32(ya, 0);
    const svfloat32_t ya1 = svget4_f32(ya, 1);
    const svfloat32_t ya2 = svget4_f32(ya, 2);
    const svfloat32_t ya3 = svget4_f32(ya, 3);
    const svfloat32_t yb0 = svget4_f32(yb, 0);
    const svfloat32_t yb1 = svget4_f32(yb, 1);
    const svfloat32_t yb2 = svget4_f32(yb, 2);
    const svfloat32_t yb3 = svget4_f32(yb, 3);
    svfloat32_t y0 = svuzp1(ya0, yb0);
    const svfloat32_t y1 = svuzp1(ya1, yb1);
    svfloat32_t y2 = svuzp1(ya2, yb2);
    const svfloat32_t y3 = svuzp1(ya3, yb3);
    svfloat32_t y4 = svuzp2(ya0, yb0);
    const svfloat32_t y5 = svuzp2(ya1, yb1);
    svfloat32_t y6 = svuzp2(ya2, yb2);
    const svfloat32_t y7 = svuzp2(ya3, yb3);
    y0 = ElementOp::op(pg0, x0, y0);
    y2 = ElementOp::op(pg0, x2, y2);
    y4 = ElementOp::op(pg0, x4, y4);
    y6 = ElementOp::op(pg0, x6, y6);
    y0 = ElementOp::merge(pg0, y0, x1, y1);
    y2 = ElementOp::merge(pg0, y2, x3, y3);
    y4 = ElementOp::merge(pg0, y4, x5, y5);
    y6 = ElementOp::merge(pg0, y6, x7, y7);
    y0 = svadd_f32_x(pg0, y0, y2);
    y4 = svadd_f32_x(pg0, y4, y6);
    y0 = svadd_f32_x(pg0, y0, y4);
    svst1_f32(pg0, dis, y0);
    // NOTE(review): the two pointer advances below are dead; the function
    // returns right after and nothing reads y or dis again.
    y += lanes8;
    dis += lanes;
}
|
|
2909
|
-
|
|
2910
|
-
// SVE kernel for vectors whose length equals one SVE register (`lanes`
// floats).
// - `x` and each of the ny vectors in `y` (stored contiguously) are loaded
//   whole.
// - The result is dis[k] = sum over lanes of ElementOp::op(x, y_k), reduced
//   with svaddv. With ElementOpIP this is the dot product <x, y_k>.
// - Only ElementOp::op is used (there is one register per vector, so nothing
//   to merge).
// The first loop handles 4 vectors per iteration; the second handles the
// remaining 0..3.
template <typename ElementOp>
void fvec_op_ny_sve_lanes1(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes2 = lanes * 2;
    const size_t lanes3 = lanes * 3;
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svld1_f32(pg, x);
    size_t i = 0;
    for (; i + 3 < ny; i += 4) {
        svfloat32_t y0 = svld1_f32(pg, y);
        svfloat32_t y1 = svld1_f32(pg, y + lanes);
        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
        y += lanes4;
        y0 = ElementOp::op(pg, x0, y0);
        y1 = ElementOp::op(pg, x0, y1);
        y2 = ElementOp::op(pg, x0, y2);
        y3 = ElementOp::op(pg, x0, y3);
        // horizontal sum of each register gives one result per vector.
        dis[i] = svaddv_f32(pg, y0);
        dis[i + 1] = svaddv_f32(pg, y1);
        dis[i + 2] = svaddv_f32(pg, y2);
        dis[i + 3] = svaddv_f32(pg, y3);
    }
    for (; i < ny; ++i) {
        svfloat32_t y0 = svld1_f32(pg, y);
        y += lanes;
        y0 = ElementOp::op(pg, x0, y0);
        dis[i] = svaddv_f32(pg, y0);
    }
}
|
|
2945
|
-
|
|
2946
|
-
template <typename ElementOp>
|
|
2947
|
-
void fvec_op_ny_sve_lanes2(
|
|
2948
|
-
float* dis,
|
|
2949
|
-
const float* x,
|
|
2950
|
-
const float* y,
|
|
2951
|
-
size_t ny) {
|
|
2952
|
-
const size_t lanes = svcntw();
|
|
2953
|
-
const size_t lanes2 = lanes * 2;
|
|
2954
|
-
const size_t lanes3 = lanes * 3;
|
|
2955
|
-
const size_t lanes4 = lanes * 4;
|
|
2956
|
-
const svbool_t pg = svptrue_b32();
|
|
2957
|
-
const svfloat32_t x0 = svld1_f32(pg, x);
|
|
2958
|
-
const svfloat32_t x1 = svld1_f32(pg, x + lanes);
|
|
2959
|
-
size_t i = 0;
|
|
2960
|
-
for (; i + 1 < ny; i += 2) {
|
|
2961
|
-
svfloat32_t y00 = svld1_f32(pg, y);
|
|
2962
|
-
const svfloat32_t y01 = svld1_f32(pg, y + lanes);
|
|
2963
|
-
svfloat32_t y10 = svld1_f32(pg, y + lanes2);
|
|
2964
|
-
const svfloat32_t y11 = svld1_f32(pg, y + lanes3);
|
|
2965
|
-
y += lanes4;
|
|
2966
|
-
y00 = ElementOp::op(pg, x0, y00);
|
|
2967
|
-
y10 = ElementOp::op(pg, x0, y10);
|
|
2968
|
-
y00 = ElementOp::merge(pg, y00, x1, y01);
|
|
2969
|
-
y10 = ElementOp::merge(pg, y10, x1, y11);
|
|
2970
|
-
dis[i] = svaddv_f32(pg, y00);
|
|
2971
|
-
dis[i + 1] = svaddv_f32(pg, y10);
|
|
2972
|
-
}
|
|
2973
|
-
if (i < ny) {
|
|
2974
|
-
svfloat32_t y0 = svld1_f32(pg, y);
|
|
2975
|
-
const svfloat32_t y1 = svld1_f32(pg, y + lanes);
|
|
2976
|
-
y0 = ElementOp::op(pg, x0, y0);
|
|
2977
|
-
y0 = ElementOp::merge(pg, y0, x1, y1);
|
|
2978
|
-
dis[i] = svaddv_f32(pg, y0);
|
|
2979
|
-
}
|
|
2980
|
-
}
|
|
2981
|
-
|
|
2982
|
-
template <typename ElementOp>
|
|
2983
|
-
void fvec_op_ny_sve_lanes3(
|
|
2984
|
-
float* dis,
|
|
2985
|
-
const float* x,
|
|
2986
|
-
const float* y,
|
|
2987
|
-
size_t ny) {
|
|
2988
|
-
const size_t lanes = svcntw();
|
|
2989
|
-
const size_t lanes2 = lanes * 2;
|
|
2990
|
-
const size_t lanes3 = lanes * 3;
|
|
2991
|
-
const svbool_t pg = svptrue_b32();
|
|
2992
|
-
const svfloat32_t x0 = svld1_f32(pg, x);
|
|
2993
|
-
const svfloat32_t x1 = svld1_f32(pg, x + lanes);
|
|
2994
|
-
const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
|
|
2995
|
-
for (size_t i = 0; i < ny; ++i) {
|
|
2996
|
-
svfloat32_t y0 = svld1_f32(pg, y);
|
|
2997
|
-
const svfloat32_t y1 = svld1_f32(pg, y + lanes);
|
|
2998
|
-
svfloat32_t y2 = svld1_f32(pg, y + lanes2);
|
|
2999
|
-
y += lanes3;
|
|
3000
|
-
y0 = ElementOp::op(pg, x0, y0);
|
|
3001
|
-
y0 = ElementOp::merge(pg, y0, x1, y1);
|
|
3002
|
-
y0 = ElementOp::merge(pg, y0, x2, y2);
|
|
3003
|
-
dis[i] = svaddv_f32(pg, y0);
|
|
3004
|
-
}
|
|
3005
|
-
}
|
|
3006
|
-
|
|
3007
|
-
template <typename ElementOp>
|
|
3008
|
-
void fvec_op_ny_sve_lanes4(
|
|
3009
|
-
float* dis,
|
|
3010
|
-
const float* x,
|
|
3011
|
-
const float* y,
|
|
3012
|
-
size_t ny) {
|
|
3013
|
-
const size_t lanes = svcntw();
|
|
3014
|
-
const size_t lanes2 = lanes * 2;
|
|
3015
|
-
const size_t lanes3 = lanes * 3;
|
|
3016
|
-
const size_t lanes4 = lanes * 4;
|
|
3017
|
-
const svbool_t pg = svptrue_b32();
|
|
3018
|
-
const svfloat32_t x0 = svld1_f32(pg, x);
|
|
3019
|
-
const svfloat32_t x1 = svld1_f32(pg, x + lanes);
|
|
3020
|
-
const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
|
|
3021
|
-
const svfloat32_t x3 = svld1_f32(pg, x + lanes3);
|
|
3022
|
-
for (size_t i = 0; i < ny; ++i) {
|
|
3023
|
-
svfloat32_t y0 = svld1_f32(pg, y);
|
|
3024
|
-
const svfloat32_t y1 = svld1_f32(pg, y + lanes);
|
|
3025
|
-
svfloat32_t y2 = svld1_f32(pg, y + lanes2);
|
|
3026
|
-
const svfloat32_t y3 = svld1_f32(pg, y + lanes3);
|
|
3027
|
-
y += lanes4;
|
|
3028
|
-
y0 = ElementOp::op(pg, x0, y0);
|
|
3029
|
-
y2 = ElementOp::op(pg, x2, y2);
|
|
3030
|
-
y0 = ElementOp::merge(pg, y0, x1, y1);
|
|
3031
|
-
y2 = ElementOp::merge(pg, y2, x3, y3);
|
|
3032
|
-
y0 = svadd_f32_x(pg, y0, y2);
|
|
3033
|
-
dis[i] = svaddv_f32(pg, y0);
|
|
3034
|
-
}
|
|
3035
|
-
}
|
|
3036
|
-
|
|
3037
|
-
void fvec_L2sqr_ny(
|
|
3038
|
-
float* dis,
|
|
3039
|
-
const float* x,
|
|
3040
|
-
const float* y,
|
|
3041
|
-
size_t d,
|
|
3042
|
-
size_t ny) {
|
|
3043
|
-
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
3044
|
-
}
|
|
3045
|
-
|
|
3046
|
-
void fvec_L2sqr_ny_transposed(
|
|
3047
|
-
float* dis,
|
|
3048
|
-
const float* x,
|
|
3049
|
-
const float* y,
|
|
3050
|
-
const float* y_sqlen,
|
|
3051
|
-
size_t d,
|
|
3052
|
-
size_t d_offset,
|
|
3053
|
-
size_t ny) {
|
|
3054
|
-
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
3055
|
-
}
|
|
3056
|
-
|
|
3057
|
-
size_t fvec_L2sqr_ny_nearest(
|
|
3058
|
-
float* distances_tmp_buffer,
|
|
3059
|
-
const float* x,
|
|
3060
|
-
const float* y,
|
|
3061
|
-
size_t d,
|
|
3062
|
-
size_t ny) {
|
|
3063
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
3064
|
-
}
|
|
3065
|
-
|
|
3066
|
-
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
3067
|
-
float* distances_tmp_buffer,
|
|
3068
|
-
const float* x,
|
|
3069
|
-
const float* y,
|
|
3070
|
-
const float* y_sqlen,
|
|
3071
|
-
size_t d,
|
|
3072
|
-
size_t d_offset,
|
|
3073
|
-
size_t ny) {
|
|
3074
|
-
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
3075
|
-
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
3076
|
-
}
|
|
3077
|
-
|
|
3078
|
-
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
3079
|
-
return fvec_L1_ref(x, y, d);
|
|
3080
|
-
}
|
|
3081
|
-
|
|
3082
|
-
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
3083
|
-
return fvec_Linf_ref(x, y, d);
|
|
3084
|
-
}
|
|
3085
|
-
|
|
3086
|
-
void fvec_inner_products_ny(
|
|
3087
|
-
float* dis,
|
|
3088
|
-
const float* x,
|
|
3089
|
-
const float* y,
|
|
3090
|
-
size_t d,
|
|
3091
|
-
size_t ny) {
|
|
3092
|
-
const size_t lanes = svcntw();
|
|
3093
|
-
switch (d) {
|
|
3094
|
-
case 1:
|
|
3095
|
-
fvec_op_ny_sve_d1<ElementOpIP>(dis, x, y, ny);
|
|
3096
|
-
break;
|
|
3097
|
-
case 2:
|
|
3098
|
-
fvec_op_ny_sve_d2<ElementOpIP>(dis, x, y, ny);
|
|
3099
|
-
break;
|
|
3100
|
-
case 4:
|
|
3101
|
-
fvec_op_ny_sve_d4<ElementOpIP>(dis, x, y, ny);
|
|
3102
|
-
break;
|
|
3103
|
-
case 8:
|
|
3104
|
-
fvec_op_ny_sve_d8<ElementOpIP>(dis, x, y, ny);
|
|
3105
|
-
break;
|
|
3106
|
-
default:
|
|
3107
|
-
if (d == lanes)
|
|
3108
|
-
fvec_op_ny_sve_lanes1<ElementOpIP>(dis, x, y, ny);
|
|
3109
|
-
else if (d == lanes * 2)
|
|
3110
|
-
fvec_op_ny_sve_lanes2<ElementOpIP>(dis, x, y, ny);
|
|
3111
|
-
else if (d == lanes * 3)
|
|
3112
|
-
fvec_op_ny_sve_lanes3<ElementOpIP>(dis, x, y, ny);
|
|
3113
|
-
else if (d == lanes * 4)
|
|
3114
|
-
fvec_op_ny_sve_lanes4<ElementOpIP>(dis, x, y, ny);
|
|
3115
|
-
else
|
|
3116
|
-
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
3117
|
-
break;
|
|
3118
|
-
}
|
|
3119
|
-
}
|
|
3120
|
-
|
|
3121
|
-
#elif defined(__aarch64__)
|
|
3122
|
-
|
|
3123
|
-
// not optimized for ARM
|
|
3124
|
-
void fvec_L2sqr_ny(
|
|
3125
|
-
float* dis,
|
|
3126
|
-
const float* x,
|
|
3127
|
-
const float* y,
|
|
3128
|
-
size_t d,
|
|
3129
|
-
size_t ny) {
|
|
3130
|
-
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
3131
|
-
}
|
|
3132
|
-
|
|
3133
|
-
void fvec_L2sqr_ny_transposed(
|
|
3134
|
-
float* dis,
|
|
3135
|
-
const float* x,
|
|
3136
|
-
const float* y,
|
|
3137
|
-
const float* y_sqlen,
|
|
3138
|
-
size_t d,
|
|
3139
|
-
size_t d_offset,
|
|
3140
|
-
size_t ny) {
|
|
3141
|
-
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
3142
|
-
}
|
|
3143
|
-
|
|
3144
|
-
size_t fvec_L2sqr_ny_nearest(
|
|
3145
|
-
float* distances_tmp_buffer,
|
|
3146
|
-
const float* x,
|
|
3147
|
-
const float* y,
|
|
3148
|
-
size_t d,
|
|
3149
|
-
size_t ny) {
|
|
3150
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
3151
|
-
}
|
|
3152
|
-
|
|
3153
|
-
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
3154
|
-
float* distances_tmp_buffer,
|
|
3155
|
-
const float* x,
|
|
3156
|
-
const float* y,
|
|
3157
|
-
const float* y_sqlen,
|
|
3158
|
-
size_t d,
|
|
3159
|
-
size_t d_offset,
|
|
3160
|
-
size_t ny) {
|
|
3161
|
-
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
3162
|
-
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
3163
|
-
}
|
|
3164
|
-
|
|
3165
|
-
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
3166
|
-
return fvec_L1_ref(x, y, d);
|
|
3167
|
-
}
|
|
3168
|
-
|
|
3169
|
-
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
3170
|
-
return fvec_Linf_ref(x, y, d);
|
|
3171
|
-
}
|
|
3172
|
-
|
|
3173
|
-
void fvec_inner_products_ny(
|
|
3174
|
-
float* dis,
|
|
3175
|
-
const float* x,
|
|
3176
|
-
const float* y,
|
|
3177
|
-
size_t d,
|
|
3178
|
-
size_t ny) {
|
|
3179
|
-
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
3180
|
-
}
|
|
3181
|
-
|
|
3182
|
-
#else
|
|
3183
|
-
// scalar implementation
|
|
3184
|
-
|
|
3185
|
-
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
3186
|
-
return fvec_L1_ref(x, y, d);
|
|
3187
|
-
}
|
|
3188
|
-
|
|
3189
|
-
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
3190
|
-
return fvec_Linf_ref(x, y, d);
|
|
3191
|
-
}
|
|
3192
|
-
|
|
3193
|
-
void fvec_L2sqr_ny(
|
|
3194
|
-
float* dis,
|
|
3195
|
-
const float* x,
|
|
3196
|
-
const float* y,
|
|
3197
|
-
size_t d,
|
|
3198
|
-
size_t ny) {
|
|
3199
|
-
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
3200
|
-
}
|
|
3201
|
-
|
|
3202
|
-
void fvec_L2sqr_ny_transposed(
|
|
3203
|
-
float* dis,
|
|
3204
|
-
const float* x,
|
|
3205
|
-
const float* y,
|
|
3206
|
-
const float* y_sqlen,
|
|
3207
|
-
size_t d,
|
|
3208
|
-
size_t d_offset,
|
|
3209
|
-
size_t ny) {
|
|
3210
|
-
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
3211
|
-
}
|
|
3212
|
-
|
|
3213
|
-
size_t fvec_L2sqr_ny_nearest(
|
|
3214
|
-
float* distances_tmp_buffer,
|
|
3215
|
-
const float* x,
|
|
3216
|
-
const float* y,
|
|
3217
|
-
size_t d,
|
|
3218
|
-
size_t ny) {
|
|
3219
|
-
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
3220
|
-
}
|
|
3221
|
-
|
|
3222
|
-
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
3223
|
-
float* distances_tmp_buffer,
|
|
3224
|
-
const float* x,
|
|
3225
|
-
const float* y,
|
|
3226
|
-
const float* y_sqlen,
|
|
3227
|
-
size_t d,
|
|
3228
|
-
size_t d_offset,
|
|
3229
|
-
size_t ny) {
|
|
3230
|
-
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
3231
|
-
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
3232
|
-
}
|
|
3233
|
-
|
|
3234
|
-
void fvec_inner_products_ny(
|
|
3235
|
-
float* dis,
|
|
3236
|
-
const float* x,
|
|
3237
|
-
const float* y,
|
|
3238
|
-
size_t d,
|
|
3239
|
-
size_t ny) {
|
|
3240
|
-
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
3241
|
-
}
|
|
3242
|
-
|
|
3243
|
-
#endif
|
|
3244
|
-
|
|
3245
|
-
/***************************************************************************
|
|
3246
|
-
* heavily optimized table computations
|
|
3247
|
-
***************************************************************************/
|
|
3248
|
-
|
|
3249
|
-
/// Portable fallback for c[i] = a[i] + bf * b[i], i in [0, n).
[[maybe_unused]] static inline void fvec_madd_ref(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    for (size_t remaining = n; remaining > 0; --remaining) {
        *c++ = *a++ + bf * *b++;
    }
}
|
|
3259
|
-
|
|
3260
|
-
#if defined(__AVX512F__)
|
|
3261
|
-
|
|
3262
|
-
static inline void fvec_madd_avx512(
|
|
3263
|
-
const size_t n,
|
|
3264
|
-
const float* __restrict a,
|
|
3265
|
-
const float bf,
|
|
3266
|
-
const float* __restrict b,
|
|
3267
|
-
float* __restrict c) {
|
|
3268
|
-
const size_t n16 = n / 16;
|
|
3269
|
-
const size_t n_for_masking = n % 16;
|
|
3270
|
-
|
|
3271
|
-
const __m512 bfmm = _mm512_set1_ps(bf);
|
|
3272
|
-
|
|
3273
|
-
size_t idx = 0;
|
|
3274
|
-
for (idx = 0; idx < n16 * 16; idx += 16) {
|
|
3275
|
-
const __m512 ax = _mm512_loadu_ps(a + idx);
|
|
3276
|
-
const __m512 bx = _mm512_loadu_ps(b + idx);
|
|
3277
|
-
const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
|
|
3278
|
-
_mm512_storeu_ps(c + idx, abmul);
|
|
3279
|
-
}
|
|
3280
|
-
|
|
3281
|
-
if (n_for_masking > 0) {
|
|
3282
|
-
const __mmask16 mask = (1 << n_for_masking) - 1;
|
|
3283
|
-
|
|
3284
|
-
const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx);
|
|
3285
|
-
const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
|
|
3286
|
-
const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
|
|
3287
|
-
_mm512_mask_storeu_ps(c + idx, mask, abmul);
|
|
3288
|
-
}
|
|
3289
|
-
}
|
|
3290
|
-
|
|
3291
|
-
#elif defined(__AVX2__)
|
|
3292
|
-
|
|
3293
|
-
static inline void fvec_madd_avx2(
|
|
3294
|
-
const size_t n,
|
|
3295
|
-
const float* __restrict a,
|
|
3296
|
-
const float bf,
|
|
3297
|
-
const float* __restrict b,
|
|
3298
|
-
float* __restrict c) {
|
|
3299
|
-
//
|
|
3300
|
-
const size_t n8 = n / 8;
|
|
3301
|
-
const size_t n_for_masking = n % 8;
|
|
3302
|
-
|
|
3303
|
-
const __m256 bfmm = _mm256_set1_ps(bf);
|
|
3304
|
-
|
|
3305
|
-
size_t idx = 0;
|
|
3306
|
-
for (idx = 0; idx < n8 * 8; idx += 8) {
|
|
3307
|
-
const __m256 ax = _mm256_loadu_ps(a + idx);
|
|
3308
|
-
const __m256 bx = _mm256_loadu_ps(b + idx);
|
|
3309
|
-
const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
|
|
3310
|
-
_mm256_storeu_ps(c + idx, abmul);
|
|
3311
|
-
}
|
|
3312
|
-
|
|
3313
|
-
if (n_for_masking > 0) {
|
|
3314
|
-
__m256i mask;
|
|
3315
|
-
switch (n_for_masking) {
|
|
3316
|
-
case 1:
|
|
3317
|
-
mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
|
|
3318
|
-
break;
|
|
3319
|
-
case 2:
|
|
3320
|
-
mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
|
|
3321
|
-
break;
|
|
3322
|
-
case 3:
|
|
3323
|
-
mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
|
|
3324
|
-
break;
|
|
3325
|
-
case 4:
|
|
3326
|
-
mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
|
|
3327
|
-
break;
|
|
3328
|
-
case 5:
|
|
3329
|
-
mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
|
|
3330
|
-
break;
|
|
3331
|
-
case 6:
|
|
3332
|
-
mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
|
|
3333
|
-
break;
|
|
3334
|
-
case 7:
|
|
3335
|
-
mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
|
|
3336
|
-
break;
|
|
3337
|
-
}
|
|
3338
|
-
|
|
3339
|
-
const __m256 ax = _mm256_maskload_ps(a + idx, mask);
|
|
3340
|
-
const __m256 bx = _mm256_maskload_ps(b + idx, mask);
|
|
3341
|
-
const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
|
|
3342
|
-
_mm256_maskstore_ps(c + idx, mask, abmul);
|
|
3343
|
-
}
|
|
3344
|
-
}
|
|
3345
|
-
|
|
3346
|
-
#endif
|
|
3347
|
-
|
|
3348
|
-
#ifdef __SSE3__
|
|
3349
|
-
|
|
3350
|
-
[[maybe_unused]] static inline void fvec_madd_sse(
|
|
3351
|
-
size_t n,
|
|
3352
|
-
const float* a,
|
|
3353
|
-
float bf,
|
|
3354
|
-
const float* b,
|
|
3355
|
-
float* c) {
|
|
3356
|
-
n >>= 2;
|
|
3357
|
-
__m128 bf4 = _mm_set_ps1(bf);
|
|
3358
|
-
__m128* a4 = (__m128*)a;
|
|
3359
|
-
__m128* b4 = (__m128*)b;
|
|
3360
|
-
__m128* c4 = (__m128*)c;
|
|
3361
|
-
|
|
3362
|
-
while (n--) {
|
|
3363
|
-
*c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
|
|
3364
|
-
b4++;
|
|
3365
|
-
a4++;
|
|
3366
|
-
c4++;
|
|
3367
|
-
}
|
|
3368
|
-
}
|
|
3369
|
-
|
|
3370
|
-
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
3371
|
-
#ifdef __AVX512F__
|
|
3372
|
-
fvec_madd_avx512(n, a, bf, b, c);
|
|
3373
|
-
#elif __AVX2__
|
|
3374
|
-
fvec_madd_avx2(n, a, bf, b, c);
|
|
3375
|
-
#else
|
|
3376
|
-
if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
3377
|
-
fvec_madd_sse(n, a, bf, b, c);
|
|
3378
|
-
else
|
|
3379
|
-
fvec_madd_ref(n, a, bf, b, c);
|
|
3380
|
-
#endif
|
|
3381
|
-
}
|
|
3382
|
-
|
|
3383
|
-
#elif defined(__ARM_FEATURE_SVE)
|
|
3384
|
-
|
|
3385
|
-
void fvec_madd(
|
|
3386
|
-
const size_t n,
|
|
3387
|
-
const float* __restrict a,
|
|
3388
|
-
const float bf,
|
|
3389
|
-
const float* __restrict b,
|
|
3390
|
-
float* __restrict c) {
|
|
3391
|
-
const size_t lanes = static_cast<size_t>(svcntw());
|
|
3392
|
-
const size_t lanes2 = lanes * 2;
|
|
3393
|
-
const size_t lanes3 = lanes * 3;
|
|
3394
|
-
const size_t lanes4 = lanes * 4;
|
|
3395
|
-
size_t i = 0;
|
|
3396
|
-
for (; i + lanes4 < n; i += lanes4) {
|
|
3397
|
-
const auto mask = svptrue_b32();
|
|
3398
|
-
const auto ai0 = svld1_f32(mask, a + i);
|
|
3399
|
-
const auto ai1 = svld1_f32(mask, a + i + lanes);
|
|
3400
|
-
const auto ai2 = svld1_f32(mask, a + i + lanes2);
|
|
3401
|
-
const auto ai3 = svld1_f32(mask, a + i + lanes3);
|
|
3402
|
-
const auto bi0 = svld1_f32(mask, b + i);
|
|
3403
|
-
const auto bi1 = svld1_f32(mask, b + i + lanes);
|
|
3404
|
-
const auto bi2 = svld1_f32(mask, b + i + lanes2);
|
|
3405
|
-
const auto bi3 = svld1_f32(mask, b + i + lanes3);
|
|
3406
|
-
const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf);
|
|
3407
|
-
const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf);
|
|
3408
|
-
const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf);
|
|
3409
|
-
const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf);
|
|
3410
|
-
svst1_f32(mask, c + i, ci0);
|
|
3411
|
-
svst1_f32(mask, c + i + lanes, ci1);
|
|
3412
|
-
svst1_f32(mask, c + i + lanes2, ci2);
|
|
3413
|
-
svst1_f32(mask, c + i + lanes3, ci3);
|
|
3414
|
-
}
|
|
3415
|
-
const auto mask0 = svwhilelt_b32_u64(i, n);
|
|
3416
|
-
const auto mask1 = svwhilelt_b32_u64(i + lanes, n);
|
|
3417
|
-
const auto mask2 = svwhilelt_b32_u64(i + lanes2, n);
|
|
3418
|
-
const auto mask3 = svwhilelt_b32_u64(i + lanes3, n);
|
|
3419
|
-
const auto ai0 = svld1_f32(mask0, a + i);
|
|
3420
|
-
const auto ai1 = svld1_f32(mask1, a + i + lanes);
|
|
3421
|
-
const auto ai2 = svld1_f32(mask2, a + i + lanes2);
|
|
3422
|
-
const auto ai3 = svld1_f32(mask3, a + i + lanes3);
|
|
3423
|
-
const auto bi0 = svld1_f32(mask0, b + i);
|
|
3424
|
-
const auto bi1 = svld1_f32(mask1, b + i + lanes);
|
|
3425
|
-
const auto bi2 = svld1_f32(mask2, b + i + lanes2);
|
|
3426
|
-
const auto bi3 = svld1_f32(mask3, b + i + lanes3);
|
|
3427
|
-
const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf);
|
|
3428
|
-
const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf);
|
|
3429
|
-
const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf);
|
|
3430
|
-
const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf);
|
|
3431
|
-
svst1_f32(mask0, c + i, ci0);
|
|
3432
|
-
svst1_f32(mask1, c + i + lanes, ci1);
|
|
3433
|
-
svst1_f32(mask2, c + i + lanes2, ci2);
|
|
3434
|
-
svst1_f32(mask3, c + i + lanes3, ci3);
|
|
3435
|
-
}
|
|
3436
|
-
|
|
3437
|
-
#elif defined(__aarch64__)
|
|
3438
|
-
|
|
3439
|
-
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
3440
|
-
const size_t n_simd = n - (n & 3);
|
|
3441
|
-
const float32x4_t bfv = vdupq_n_f32(bf);
|
|
3442
|
-
size_t i;
|
|
3443
|
-
for (i = 0; i < n_simd; i += 4) {
|
|
3444
|
-
const float32x4_t ai = vld1q_f32(a + i);
|
|
3445
|
-
const float32x4_t bi = vld1q_f32(b + i);
|
|
3446
|
-
const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
|
|
3447
|
-
vst1q_f32(c + i, ci);
|
|
3448
|
-
}
|
|
3449
|
-
for (; i < n; ++i)
|
|
3450
|
-
c[i] = a[i] + bf * b[i];
|
|
3451
|
-
}
|
|
3452
|
-
|
|
3453
|
-
#else
|
|
3454
|
-
|
|
3455
|
-
/// Scalar build: c[i] = a[i] + bf * b[i] via the reference loop.
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    fvec_madd_ref(
            n,
            a,
            bf,
            b,
            c);
}
|
|
3458
|
-
|
|
3459
|
-
#endif
|
|
3460
|
-
|
|
3461
|
-
/// Computes c[i] = a[i] + bf * b[i] for i in [0, n) and returns the index of
/// the first strict minimum of c.
/// Returns -1 when n == 0 or when no c[i] is below the 1e20 sentinel (the
/// same sentinel the SIMD variants start from).
static inline int fvec_madd_and_argmin_ref(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    float vmin = 1e20;
    int imin = -1;

    for (size_t i = 0; i < n; i++) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = static_cast<int>(i); // the API returns an int index
        }
    }
    return imin;
}
|
|
3479
171
|
|
|
3480
|
-
#ifdef __SSE3__
|
|
3481
|
-
|
|
3482
|
-
static inline int fvec_madd_and_argmin_sse(
|
|
3483
|
-
size_t n,
|
|
3484
|
-
const float* a,
|
|
3485
|
-
float bf,
|
|
3486
|
-
const float* b,
|
|
3487
|
-
float* c) {
|
|
3488
|
-
n >>= 2;
|
|
3489
|
-
__m128 bf4 = _mm_set_ps1(bf);
|
|
3490
|
-
__m128 vmin4 = _mm_set_ps1(1e20);
|
|
3491
|
-
__m128i imin4 = _mm_set1_epi32(-1);
|
|
3492
|
-
__m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
|
|
3493
|
-
__m128i inc4 = _mm_set1_epi32(4);
|
|
3494
|
-
__m128* a4 = (__m128*)a;
|
|
3495
|
-
__m128* b4 = (__m128*)b;
|
|
3496
|
-
__m128* c4 = (__m128*)c;
|
|
3497
|
-
|
|
3498
|
-
while (n--) {
|
|
3499
|
-
__m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
|
|
3500
|
-
*c4 = vc4;
|
|
3501
|
-
__m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
|
|
3502
|
-
// imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
|
|
3503
|
-
|
|
3504
|
-
imin4 = _mm_or_si128(
|
|
3505
|
-
_mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
|
|
3506
|
-
vmin4 = _mm_min_ps(vmin4, vc4);
|
|
3507
|
-
b4++;
|
|
3508
|
-
a4++;
|
|
3509
|
-
c4++;
|
|
3510
|
-
idx4 = _mm_add_epi32(idx4, inc4);
|
|
3511
|
-
}
|
|
3512
|
-
|
|
3513
|
-
// 4 values -> 2
|
|
3514
|
-
{
|
|
3515
|
-
idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
|
|
3516
|
-
__m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
|
|
3517
|
-
__m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
|
|
3518
|
-
imin4 = _mm_or_si128(
|
|
3519
|
-
_mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
|
|
3520
|
-
vmin4 = _mm_min_ps(vmin4, vc4);
|
|
3521
|
-
}
|
|
3522
|
-
// 2 values -> 1
|
|
3523
|
-
{
|
|
3524
|
-
idx4 = _mm_shuffle_epi32(imin4, 1);
|
|
3525
|
-
__m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
|
|
3526
|
-
__m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
|
|
3527
|
-
imin4 = _mm_or_si128(
|
|
3528
|
-
_mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
|
|
3529
|
-
// vmin4 = _mm_min_ps (vmin4, vc4);
|
|
3530
|
-
}
|
|
3531
|
-
return _mm_cvtsi128_si32(imin4);
|
|
3532
|
-
}
|
|
3533
|
-
|
|
3534
|
-
int fvec_madd_and_argmin(
|
|
3535
|
-
size_t n,
|
|
3536
|
-
const float* a,
|
|
3537
|
-
float bf,
|
|
3538
|
-
const float* b,
|
|
3539
|
-
float* c) {
|
|
3540
|
-
if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) {
|
|
3541
|
-
return fvec_madd_and_argmin_sse(n, a, bf, b, c);
|
|
3542
|
-
} else {
|
|
3543
|
-
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
|
|
3544
|
-
}
|
|
3545
|
-
}
|
|
3546
|
-
|
|
3547
|
-
#elif defined(__aarch64__)
|
|
3548
|
-
|
|
3549
|
-
int fvec_madd_and_argmin(
|
|
3550
|
-
size_t n,
|
|
3551
|
-
const float* a,
|
|
3552
|
-
float bf,
|
|
3553
|
-
const float* b,
|
|
3554
|
-
float* c) {
|
|
3555
|
-
float32x4_t vminv = vdupq_n_f32(1e20);
|
|
3556
|
-
uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
|
|
3557
|
-
size_t i;
|
|
3558
|
-
{
|
|
3559
|
-
const size_t n_simd = n - (n & 3);
|
|
3560
|
-
const uint32_t iota[] = {0, 1, 2, 3};
|
|
3561
|
-
uint32x4_t iv = vld1q_u32(iota);
|
|
3562
|
-
const uint32x4_t incv = vdupq_n_u32(4);
|
|
3563
|
-
const float32x4_t bfv = vdupq_n_f32(bf);
|
|
3564
|
-
for (i = 0; i < n_simd; i += 4) {
|
|
3565
|
-
const float32x4_t ai = vld1q_f32(a + i);
|
|
3566
|
-
const float32x4_t bi = vld1q_f32(b + i);
|
|
3567
|
-
const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
|
|
3568
|
-
vst1q_f32(c + i, ci);
|
|
3569
|
-
const uint32x4_t less_than = vcltq_f32(ci, vminv);
|
|
3570
|
-
vminv = vminq_f32(ci, vminv);
|
|
3571
|
-
iminv = vorrq_u32(
|
|
3572
|
-
vandq_u32(less_than, iv),
|
|
3573
|
-
vandq_u32(vmvnq_u32(less_than), iminv));
|
|
3574
|
-
iv = vaddq_u32(iv, incv);
|
|
3575
|
-
}
|
|
3576
|
-
}
|
|
3577
|
-
float vmin = vminvq_f32(vminv);
|
|
3578
|
-
uint32_t imin;
|
|
3579
|
-
{
|
|
3580
|
-
const float32x4_t vminy = vdupq_n_f32(vmin);
|
|
3581
|
-
const uint32x4_t equals = vceqq_f32(vminv, vminy);
|
|
3582
|
-
imin = vminvq_u32(vorrq_u32(
|
|
3583
|
-
vandq_u32(equals, iminv),
|
|
3584
|
-
vandq_u32(
|
|
3585
|
-
vmvnq_u32(equals),
|
|
3586
|
-
vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
|
|
3587
|
-
}
|
|
3588
|
-
for (; i < n; ++i) {
|
|
3589
|
-
c[i] = a[i] + bf * b[i];
|
|
3590
|
-
if (c[i] < vmin) {
|
|
3591
|
-
vmin = c[i];
|
|
3592
|
-
imin = static_cast<uint32_t>(i);
|
|
3593
|
-
}
|
|
3594
|
-
}
|
|
3595
|
-
return static_cast<int>(imin);
|
|
3596
|
-
}
|
|
3597
|
-
|
|
3598
|
-
#else
|
|
3599
|
-
|
|
3600
|
-
int fvec_madd_and_argmin(
|
|
3601
|
-
size_t n,
|
|
3602
|
-
const float* a,
|
|
3603
|
-
float bf,
|
|
3604
|
-
const float* b,
|
|
3605
|
-
float* c) {
|
|
3606
|
-
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
|
|
3607
|
-
}
|
|
3608
|
-
|
|
3609
|
-
#endif
|
|
3610
|
-
|
|
3611
|
-
/***************************************************************************
|
|
3612
|
-
* PQ tables computations
|
|
3613
|
-
***************************************************************************/
|
|
3614
|
-
|
|
3615
|
-
namespace {
|
|
3616
|
-
|
|
3617
|
-
/// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
|
|
3618
|
-
template <bool is_inner_product>
|
|
3619
|
-
void pq2_8cents_table(
|
|
3620
|
-
const simd8float32 centroids[8],
|
|
3621
|
-
const simd8float32 x,
|
|
3622
|
-
float* out,
|
|
3623
|
-
size_t ldo,
|
|
3624
|
-
size_t nout = 4) {
|
|
3625
|
-
simd8float32 ips[4];
|
|
3626
|
-
|
|
3627
|
-
for (int i = 0; i < 4; i++) {
|
|
3628
|
-
simd8float32 p1, p2;
|
|
3629
|
-
if (is_inner_product) {
|
|
3630
|
-
p1 = x * centroids[2 * i];
|
|
3631
|
-
p2 = x * centroids[2 * i + 1];
|
|
3632
|
-
} else {
|
|
3633
|
-
p1 = (x - centroids[2 * i]);
|
|
3634
|
-
p1 = p1 * p1;
|
|
3635
|
-
p2 = (x - centroids[2 * i + 1]);
|
|
3636
|
-
p2 = p2 * p2;
|
|
3637
|
-
}
|
|
3638
|
-
ips[i] = hadd(p1, p2);
|
|
3639
|
-
}
|
|
3640
|
-
|
|
3641
|
-
simd8float32 ip02a = geteven(ips[0], ips[1]);
|
|
3642
|
-
simd8float32 ip02b = geteven(ips[2], ips[3]);
|
|
3643
|
-
simd8float32 ip0 = getlow128(ip02a, ip02b);
|
|
3644
|
-
simd8float32 ip2 = gethigh128(ip02a, ip02b);
|
|
3645
|
-
|
|
3646
|
-
simd8float32 ip13a = getodd(ips[0], ips[1]);
|
|
3647
|
-
simd8float32 ip13b = getodd(ips[2], ips[3]);
|
|
3648
|
-
simd8float32 ip1 = getlow128(ip13a, ip13b);
|
|
3649
|
-
simd8float32 ip3 = gethigh128(ip13a, ip13b);
|
|
3650
|
-
|
|
3651
|
-
switch (nout) {
|
|
3652
|
-
case 4:
|
|
3653
|
-
ip3.storeu(out + 3 * ldo);
|
|
3654
|
-
[[fallthrough]];
|
|
3655
|
-
case 3:
|
|
3656
|
-
ip2.storeu(out + 2 * ldo);
|
|
3657
|
-
[[fallthrough]];
|
|
3658
|
-
case 2:
|
|
3659
|
-
ip1.storeu(out + 1 * ldo);
|
|
3660
|
-
[[fallthrough]];
|
|
3661
|
-
case 1:
|
|
3662
|
-
ip0.storeu(out);
|
|
3663
|
-
}
|
|
3664
|
-
}
|
|
3665
|
-
|
|
3666
|
-
/// Loads the first n floats of x into a simd8float32, zero-filling the
/// remaining lanes. n must not exceed 8 (the size of the staging buffer).
simd8float32 load_simd8float32_partial(const float* x, int n) {
    ALIGNED(32) float buf[8] = {};
    for (int k = 0; k < n; ++k) {
        buf[k] = x[k];
    }
    return simd8float32(buf);
}
|
|
3674
|
-
|
|
3675
|
-
} // anonymous namespace
|
|
3676
|
-
|
|
3677
|
-
void compute_PQ_dis_tables_dsub2(
|
|
3678
|
-
size_t d,
|
|
3679
|
-
size_t ksub,
|
|
3680
|
-
const float* all_centroids,
|
|
3681
|
-
size_t nx,
|
|
3682
|
-
const float* x,
|
|
3683
|
-
bool is_inner_product,
|
|
3684
|
-
float* dis_tables) {
|
|
3685
|
-
size_t M = d / 2;
|
|
3686
|
-
FAISS_THROW_IF_NOT(ksub % 8 == 0);
|
|
3687
|
-
|
|
3688
|
-
for (size_t m0 = 0; m0 < M; m0 += 4) {
|
|
3689
|
-
int m1 = std::min(M, m0 + 4);
|
|
3690
|
-
for (int k0 = 0; k0 < ksub; k0 += 8) {
|
|
3691
|
-
simd8float32 centroids[8];
|
|
3692
|
-
for (int k = 0; k < 8; k++) {
|
|
3693
|
-
ALIGNED(32) float centroid[8];
|
|
3694
|
-
size_t wp = 0;
|
|
3695
|
-
size_t rp = (m0 * ksub + k + k0) * 2;
|
|
3696
|
-
for (int m = m0; m < m1; m++) {
|
|
3697
|
-
centroid[wp++] = all_centroids[rp];
|
|
3698
|
-
centroid[wp++] = all_centroids[rp + 1];
|
|
3699
|
-
rp += 2 * ksub;
|
|
3700
|
-
}
|
|
3701
|
-
centroids[k] = simd8float32(centroid);
|
|
3702
|
-
}
|
|
3703
|
-
for (size_t i = 0; i < nx; i++) {
|
|
3704
|
-
simd8float32 xi;
|
|
3705
|
-
if (m1 == m0 + 4) {
|
|
3706
|
-
xi.loadu(x + i * d + m0 * 2);
|
|
3707
|
-
} else {
|
|
3708
|
-
xi = load_simd8float32_partial(
|
|
3709
|
-
x + i * d + m0 * 2, 2 * (m1 - m0));
|
|
3710
|
-
}
|
|
3711
|
-
|
|
3712
|
-
if (is_inner_product) {
|
|
3713
|
-
pq2_8cents_table<true>(
|
|
3714
|
-
centroids,
|
|
3715
|
-
xi,
|
|
3716
|
-
dis_tables + (i * M + m0) * ksub + k0,
|
|
3717
|
-
ksub,
|
|
3718
|
-
m1 - m0);
|
|
3719
|
-
} else {
|
|
3720
|
-
pq2_8cents_table<false>(
|
|
3721
|
-
centroids,
|
|
3722
|
-
xi,
|
|
3723
|
-
dis_tables + (i * M + m0) * ksub + k0,
|
|
3724
|
-
ksub,
|
|
3725
|
-
m1 - m0);
|
|
3726
|
-
}
|
|
3727
|
-
}
|
|
3728
|
-
}
|
|
3729
|
-
}
|
|
3730
|
-
}
|
|
3731
|
-
|
|
3732
|
-
/*********************************************************
|
|
3733
|
-
* Vector to vector functions
|
|
3734
|
-
*********************************************************/
|
|
3735
|
-
|
|
3736
|
-
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
|
|
3737
|
-
size_t i;
|
|
3738
|
-
for (i = 0; i + 7 < d; i += 8) {
|
|
3739
|
-
simd8float32 ci, ai, bi;
|
|
3740
|
-
ai.loadu(a + i);
|
|
3741
|
-
bi.loadu(b + i);
|
|
3742
|
-
ci = ai - bi;
|
|
3743
|
-
ci.storeu(c + i);
|
|
3744
|
-
}
|
|
3745
|
-
// finish non-multiple of 8 remainder
|
|
3746
|
-
for (; i < d; i++) {
|
|
3747
|
-
c[i] = a[i] - b[i];
|
|
3748
|
-
}
|
|
3749
|
-
}
|
|
3750
|
-
|
|
3751
|
-
void fvec_add(size_t d, const float* a, const float* b, float* c) {
|
|
3752
|
-
size_t i;
|
|
3753
|
-
for (i = 0; i + 7 < d; i += 8) {
|
|
3754
|
-
simd8float32 ci, ai, bi;
|
|
3755
|
-
ai.loadu(a + i);
|
|
3756
|
-
bi.loadu(b + i);
|
|
3757
|
-
ci = ai + bi;
|
|
3758
|
-
ci.storeu(c + i);
|
|
3759
|
-
}
|
|
3760
|
-
// finish non-multiple of 8 remainder
|
|
3761
|
-
for (; i < d; i++) {
|
|
3762
|
-
c[i] = a[i] + b[i];
|
|
3763
|
-
}
|
|
3764
|
-
}
|
|
3765
|
-
|
|
3766
|
-
void fvec_add(size_t d, const float* a, float b, float* c) {
|
|
3767
|
-
size_t i;
|
|
3768
|
-
simd8float32 bv(b);
|
|
3769
|
-
for (i = 0; i + 7 < d; i += 8) {
|
|
3770
|
-
simd8float32 ci, ai;
|
|
3771
|
-
ai.loadu(a + i);
|
|
3772
|
-
ci = ai + bv;
|
|
3773
|
-
ci.storeu(c + i);
|
|
3774
|
-
}
|
|
3775
|
-
// finish non-multiple of 8 remainder
|
|
3776
|
-
for (; i < d; i++) {
|
|
3777
|
-
c[i] = a[i] + b;
|
|
3778
|
-
}
|
|
3779
|
-
}
|
|
3780
|
-
|
|
3781
172
|
} // namespace faiss
|