faiss 0.3.1 → 0.3.3
This diff shows the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those released versions.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -2
- data/vendor/faiss/faiss/AutoTune.h +3 -3
- data/vendor/faiss/faiss/Clustering.cpp +37 -6
- data/vendor/faiss/faiss/Clustering.h +12 -3
- data/vendor/faiss/faiss/IVFlib.cpp +6 -3
- data/vendor/faiss/faiss/IVFlib.h +2 -2
- data/vendor/faiss/faiss/Index.cpp +6 -2
- data/vendor/faiss/faiss/Index.h +30 -8
- data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
- data/vendor/faiss/faiss/IndexBinary.h +8 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
- data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
- data/vendor/faiss/faiss/IndexFastScan.h +11 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
- data/vendor/faiss/faiss/IndexFlat.h +2 -2
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
- data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
- data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
- data/vendor/faiss/faiss/IndexHNSW.h +54 -5
- data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
- data/vendor/faiss/faiss/IndexIDMap.h +5 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
- data/vendor/faiss/faiss/IndexIVF.h +13 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
- data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
- data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
- data/vendor/faiss/faiss/IndexLSH.h +2 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
- data/vendor/faiss/faiss/IndexLattice.h +5 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
- data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
- data/vendor/faiss/faiss/IndexNSG.h +3 -3
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
- data/vendor/faiss/faiss/IndexPQ.h +2 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
- data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
- data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
- data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
- data/vendor/faiss/faiss/IndexRefine.h +9 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexShards.cpp +2 -2
- data/vendor/faiss/faiss/IndexShards.h +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
- data/vendor/faiss/faiss/MatrixStats.h +2 -2
- data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
- data/vendor/faiss/faiss/MetaIndexes.h +2 -2
- data/vendor/faiss/faiss/MetricType.h +9 -4
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +2 -2
- data/vendor/faiss/faiss/clone_index.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
- data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
- data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
- data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
- data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
- data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
- data/vendor/faiss/faiss/impl/FaissException.h +2 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
- data/vendor/faiss/faiss/impl/HNSW.h +55 -24
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
- data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
- data/vendor/faiss/faiss/impl/NSG.h +20 -8
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
- data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
- data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
- data/vendor/faiss/faiss/impl/io.cpp +15 -7
- data/vendor/faiss/faiss/impl/io.h +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +8 -9
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
- data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
- data/vendor/faiss/faiss/index_factory.cpp +51 -34
- data/vendor/faiss/faiss/index_factory.h +2 -2
- data/vendor/faiss/faiss/index_io.h +14 -7
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
- data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
- data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
- data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +107 -2
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
- data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +249 -90
- data/vendor/faiss/faiss/utils/distances.h +8 -8
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
- data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
- data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
- data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
- data/vendor/faiss/faiss/utils/fp16.h +2 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
- data/vendor/faiss/faiss/utils/hamming.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
- data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
- data/vendor/faiss/faiss/utils/random.cpp +45 -2
- data/vendor/faiss/faiss/utils/random.h +27 -2
- data/vendor/faiss/faiss/utils/simdlib.h +12 -3
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
- data/vendor/faiss/faiss/utils/sorting.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +17 -10
- data/vendor/faiss/faiss/utils/utils.h +7 -3
- metadata +22 -11
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/utils/distances_simd.cpp

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
@@ -23,10 +23,16 @@
 #include <immintrin.h>
 #endif
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+#include <faiss/utils/transpose/transpose-avx512-inl.h>
+#elif defined(__AVX2__)
 #include <faiss/utils/transpose/transpose-avx2-inl.h>
 #endif
 
+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>
+#endif
+
 #ifdef __aarch64__
 #include <arm_neon.h>
 #endif
@@ -346,6 +352,14 @@ inline float horizontal_sum(const __m256 v) {
 }
 #endif
 
+#ifdef __AVX512F__
+/// helper function for AVX512
+inline float horizontal_sum(const __m512 v) {
+    // performs better than adding the high and low parts
+    return _mm512_reduce_add_ps(v);
+}
+#endif
+
 /// Function that does a component-wise operation between x and y
 /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
 /// functions below
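The new `horizontal_sum` overload reduces a whole `__m512` in one step with `_mm512_reduce_add_ps`. For contrast, here is a minimal sketch of the split-and-reuse alternative that the comment says it outperforms; the name `horizontal_sum_split` and the routine itself are illustrative, not part of the diff, and it reuses the AVX2 `horizontal_sum(__m256)` already defined in this file:

#ifdef __AVX512F__
// hypothetical alternative: split the 512-bit register into its two
// 256-bit halves (AVX512F-only extract via a float64 cast), add them,
// then reuse the AVX2 horizontal_sum(__m256) from earlier in the file.
inline float horizontal_sum_split(const __m512 v) {
    const __m256 lo = _mm512_castps512_ps256(v);
    const __m256 hi = _mm256_castpd_ps(
            _mm512_extractf64x4_pd(_mm512_castps_pd(v), 1));
    return horizontal_sum(_mm256_add_ps(lo, hi));
}
#endif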
@@ -366,6 +380,13 @@ struct ElementOpL2 {
         return _mm256_mul_ps(tmp, tmp);
     }
 #endif
+
+#ifdef __AVX512F__
+    static __m512 op(__m512 x, __m512 y) {
+        __m512 tmp = _mm512_sub_ps(x, y);
+        return _mm512_mul_ps(tmp, tmp);
+    }
+#endif
 };
 
 /// Function that does a component-wise operation between x and y
@@ -384,6 +405,12 @@ struct ElementOpIP {
         return _mm256_mul_ps(x, y);
     }
 #endif
+
+#ifdef __AVX512F__
+    static __m512 op(__m512 x, __m512 y) {
+        return _mm512_mul_ps(x, y);
+    }
+#endif
 };
 
 template <class ElementOp>
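Both structs now expose a `__m512` overload next to the existing narrower ones, so the `template <class ElementOp>` kernels referenced in the context line above can pick the widest available operation. A minimal sketch of how such a kernel consumes them, assuming this file's AVX2 helpers (the function name `fvec_op_d8_once` is hypothetical):

#ifdef __AVX2__
// apply ElementOp::op lane-wise to one 8-d vector pair, then reduce.
// With ElementOpIP the result is the dot product <x, y>;
// with ElementOpL2 it is the squared distance ||x - y||^2.
template <class ElementOp>
float fvec_op_d8_once(const float* x, const float* y) {
    __m256 accu = ElementOp::op(_mm256_loadu_ps(x), _mm256_loadu_ps(y));
    return horizontal_sum(accu);
}
#endif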
@@ -426,7 +453,130 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <>
+void fvec_op_ny_D2<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D2-vectors per loop.
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            // load 16x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            // compute distances (dot product)
+            __m512 distances = _mm512_mul_ps(m0, v0);
+            distances = _mm512_fmadd_ps(m1, v1, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 32; // move to the next set of 16x2 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float distance = x0 * y[0] + x1 * y[1];
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D2<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D2-vectors per loop.
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            // load 16x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            // compute differences
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+            // compute squares of differences
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 32; // move to the next set of 16x2 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
+
+#elif defined(__AVX2__)
 
 template <>
 void fvec_op_ny_D2<ElementOpIP>(
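The AVX512 specializations batch 16 interleaved 2-d vectors, transpose them in registers (`transpose_16x2`) so that each coordinate lands in its own register, and then produce 16 results per iteration. Functionally, each specialization matches a simple scalar loop like this sketch (the name is hypothetical; the body mirrors the leftover loop in the diff above):

// scalar model of fvec_op_ny_D2<ElementOpIP> (a sketch for clarity):
// dis[i] = <x, y_i> for ny contiguous 2-d vectors y_i.
void fvec_inner_products_ny_d2_ref(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    for (size_t i = 0; i < ny; i++) {
        dis[i] = x[0] * y[0] + x[1] * y[1];
        y += 2;
    }
}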
@@ -562,7 +712,137 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <>
+void fvec_op_ny_D4<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D4-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x4 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            // compute distances
+            __m512 distances = _mm512_mul_ps(m0, v0);
+            distances = _mm512_fmadd_ps(m1, v1, distances);
+            distances = _mm512_fmadd_ps(m2, v2, distances);
+            distances = _mm512_fmadd_ps(m3, v3, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 64; // move to the next set of 16x4 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D4<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D4-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x4 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            // compute differences
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 64; // move to the next set of 16x4 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#elif defined(__AVX2__)
 
 template <>
 void fvec_op_ny_D4<ElementOpIP>(
@@ -710,7 +990,181 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <>
+void fvec_op_ny_D8<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D16-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+        const __m512 m4 = _mm512_set1_ps(x[4]);
+        const __m512 m5 = _mm512_set1_ps(x[5]);
+        const __m512 m6 = _mm512_set1_ps(x[6]);
+        const __m512 m7 = _mm512_set1_ps(x[7]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+            __m512 v4;
+            __m512 v5;
+            __m512 v6;
+            __m512 v7;
+
+            transpose_16x8(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    _mm512_loadu_ps(y + 4 * 16),
+                    _mm512_loadu_ps(y + 5 * 16),
+                    _mm512_loadu_ps(y + 6 * 16),
+                    _mm512_loadu_ps(y + 7 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute distances
+            __m512 distances = _mm512_mul_ps(m0, v0);
+            distances = _mm512_fmadd_ps(m1, v1, distances);
+            distances = _mm512_fmadd_ps(m2, v2, distances);
+            distances = _mm512_fmadd_ps(m3, v3, distances);
+            distances = _mm512_fmadd_ps(m4, v4, distances);
+            distances = _mm512_fmadd_ps(m5, v5, distances);
+            distances = _mm512_fmadd_ps(m6, v6, distances);
+            distances = _mm512_fmadd_ps(m7, v7, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 128; // 16 floats * 8 rows
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D8<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D16-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+        const __m512 m4 = _mm512_set1_ps(x[4]);
+        const __m512 m5 = _mm512_set1_ps(x[5]);
+        const __m512 m6 = _mm512_set1_ps(x[6]);
+        const __m512 m7 = _mm512_set1_ps(x[7]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+            __m512 v4;
+            __m512 v5;
+            __m512 v6;
+            __m512 v7;
+
+            transpose_16x8(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    _mm512_loadu_ps(y + 4 * 16),
+                    _mm512_loadu_ps(y + 5 * 16),
+                    _mm512_loadu_ps(y + 6 * 16),
+                    _mm512_loadu_ps(y + 7 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute differences
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+            const __m512 d4 = _mm512_sub_ps(m4, v4);
+            const __m512 d5 = _mm512_sub_ps(m5, v5);
+            const __m512 d6 = _mm512_sub_ps(m6, v6);
+            const __m512 d7 = _mm512_sub_ps(m7, v7);
+
+            // compute squares of differences
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+            distances = _mm512_fmadd_ps(d4, d4, distances);
+            distances = _mm512_fmadd_ps(d5, d5, distances);
+            distances = _mm512_fmadd_ps(d6, d6, distances);
+            distances = _mm512_fmadd_ps(d7, d7, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 128; // 16 floats * 8 rows
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#elif defined(__AVX2__)
 
 template <>
 void fvec_op_ny_D8<ElementOpIP>(
@@ -955,7 +1409,83 @@ void fvec_inner_products_ny(
 #undef DISPATCH
 }
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <size_t DIM>
+void fvec_L2sqr_ny_y_transposed_D(
+        float* distances,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // current index being processed
+    size_t i = 0;
+
+    // squared length of x
+    float x_sqlen = 0;
+    for (size_t j = 0; j < DIM; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    // process 16 vectors per loop
+    const size_t ny16 = ny / 16;
+
+    if (ny16 > 0) {
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m512 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm512_set1_ps(x[j]);
+            m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
+        }
+
+        __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen);
+
+        for (; i < ny16 * 16; i += 16) {
+            // Load vectors for 16 dimensions
+            __m512 v[DIM];
+            for (size_t j = 0; j < DIM; j++) {
+                v[j] = _mm512_loadu_ps(y + j * d_offset);
+            }
+
+            // Compute dot products
+            __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm);
+            for (size_t j = 1; j < DIM; j++) {
+                dp = _mm512_fnmadd_ps(m[j], v[j], dp);
+            }
+
+            // Compute y^2 - (2 * x, y) + x^2
+            __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+            _mm512_storeu_ps(distances + i, distances_v);
+
+            // Scroll y and y_sqlen forward
+            y += 16;
+            y_sqlen += 16;
+        }
+    }
+
+    if (i < ny) {
+        // Process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+            distances[i] = distance;
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+}
+
+#elif defined(__AVX2__)
+
 template <size_t DIM>
 void fvec_L2sqr_ny_y_transposed_D(
         float* distances,
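The transposed-layout kernel rests on the usual expansion of the squared Euclidean distance,

\[ \|x - y_i\|^2 = \|x\|^2 - 2\,\langle x, y_i \rangle + \|y_i\|^2, \]

so with \(\|y_i\|^2\) precomputed in `y_sqlen`, only the dot product has to be evaluated per candidate vector. The `_mm512_fnmadd_ps` chain accumulates \(\|x\|^2 - 2\langle x, y_i\rangle\) directly: `dp` is seeded with `x_sqlen` and each step subtracts a `2 * x[j] * y[j]` product.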
@@ -1014,58 +1544,368 @@ void fvec_L2sqr_ny_y_transposed_D(
     }
 
     if (i < ny) {
-        // process leftovers
-        for (; i < ny; i++) {
-            float dp = 0;
-            for (size_t j = 0; j < DIM; j++) {
-                dp += x[j] * y[j * d_offset];
-            }
+        // process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+            distances[i] = distance;
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+}
+
+#endif
+
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    // optimized for a few special cases
+
+#ifdef __AVX2__
+#define DISPATCH(dval) \
+    case dval: \
+        return fvec_L2sqr_ny_y_transposed_D<dval>( \
+                dis, x, y, y_sqlen, d_offset, ny);
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        default:
+            return fvec_L2sqr_ny_y_transposed_ref(
+                    dis, x, y, y_sqlen, d, d_offset, ny);
+    }
+#undef DISPATCH
+#else
+    // non-AVX2 case
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+#endif
+}
+
+#if defined(__AVX512F__)
+
+size_t fvec_L2sqr_ny_nearest_D2(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    const size_t ny16 = ny / 16;
+    if (ny16 > 0) {
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
+
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 32;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
+}
+
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    const size_t ny16 = ny / 16;
+
+    if (ny16 > 0) {
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+
+        for (; i < ny16 * 16; i += 16) {
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
+
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 64;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            const float distance = horizontal_sum(accu);
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
+}
+
+size_t fvec_L2sqr_ny_nearest_D8(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    const size_t ny16 = ny / 16;
+    if (ny16 > 0) {
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+
+        const __m512 m4 = _mm512_set1_ps(x[4]);
+        const __m512 m5 = _mm512_set1_ps(x[5]);
+        const __m512 m6 = _mm512_set1_ps(x[6]);
+        const __m512 m7 = _mm512_set1_ps(x[7]);
+
+        for (; i < ny16 * 16; i += 16) {
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+            __m512 v4;
+            __m512 v5;
+            __m512 v6;
+            __m512 v7;
+
+            transpose_16x8(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    _mm512_loadu_ps(y + 4 * 16),
+                    _mm512_loadu_ps(y + 5 * 16),
+                    _mm512_loadu_ps(y + 6 * 16),
+                    _mm512_loadu_ps(y + 7 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+            const __m512 d4 = _mm512_sub_ps(m4, v4);
+            const __m512 d5 = _mm512_sub_ps(m5, v5);
+            const __m512 d6 = _mm512_sub_ps(m6, v6);
+            const __m512 d7 = _mm512_sub_ps(m7, v7);
+
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+            distances = _mm512_fmadd_ps(d4, d4, distances);
+            distances = _mm512_fmadd_ps(d5, d5, distances);
+            distances = _mm512_fmadd_ps(d6, d6, distances);
+            distances = _mm512_fmadd_ps(d7, d7, distances);
+
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
+
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 128;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        __m256 x0 = _mm256_loadu_ps(x);
 
-            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
-            // lowest distance.
-            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
-            distances[i] = distance;
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            const float distance = horizontal_sum(accu);
 
-            y += 1;
-            y_sqlen += 1;
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
         }
     }
-}
-#endif
-
-void fvec_L2sqr_ny_transposed(
-        float* dis,
-        const float* x,
-        const float* y,
-        const float* y_sqlen,
-        size_t d,
-        size_t d_offset,
-        size_t ny) {
-    // optimized for a few special cases
-
-#ifdef __AVX2__
-#define DISPATCH(dval) \
-    case dval: \
-        return fvec_L2sqr_ny_y_transposed_D<dval>( \
-                dis, x, y, y_sqlen, d_offset, ny);
 
-    switch (d) {
-        DISPATCH(1)
-        DISPATCH(2)
-        DISPATCH(4)
-        DISPATCH(8)
-        default:
-            return fvec_L2sqr_ny_y_transposed_ref(
-                    dis, x, y, y_sqlen, d, d_offset, ny);
-    }
-#undef DISPATCH
-#else
-    // non-AVX2 case
-    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
-#endif
+    return current_min_index;
 }
 
-#ifdef __AVX2__
+#elif defined(__AVX2__)
 
 size_t fvec_L2sqr_ny_nearest_D2(
         float* distances_tmp_buffer,
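The `fvec_L2sqr_ny_nearest_*` kernels keep a per-lane running minimum plus its index (`_mm512_cmp_ps_mask` followed by a blend), then merge the 16 lane winners in scalar code. A scalar model of the same argmin strategy, for clarity only (`argmin16_model` is a hypothetical name, and tie-breaking and tail handling may differ slightly from the SIMD version):

#include <cmath>
#include <cstddef>
// lane k owns candidates k, k+16, k+32, ...; a final pass over the
// 16 lane winners picks the overall nearest index.
size_t argmin16_model(const float* dist, size_t n) {
    float best[16];
    size_t best_idx[16];
    for (int k = 0; k < 16; k++) {
        best[k] = HUGE_VALF;
        best_idx[k] = 0;
    }
    for (size_t i = 0; i < n; i++) {
        const int lane = i % 16;
        if (dist[i] < best[lane]) {
            best[lane] = dist[i];
            best_idx[lane] = i;
        }
    }
    size_t winner = 0;
    float w = HUGE_VALF;
    for (int k = 0; k < 16; k++) {
        if (best[k] < w) {
            w = best[k];
            winner = best_idx[k];
        }
    }
    return winner;
}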
@@ -1476,7 +2316,123 @@ size_t fvec_L2sqr_ny_nearest(
 #undef DISPATCH
 }
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <size_t DIM>
+size_t fvec_L2sqr_ny_nearest_y_transposed_D(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // This implementation does not use distances_tmp_buffer.
+
+    // Current index being processed
+    size_t i = 0;
+
+    // Min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // Process 16 vectors per loop
+    const size_t ny16 = ny / 16;
+
+    if (ny16 > 0) {
+        // Track min distance and the closest vector independently
+        // for each of 16 AVX-512 components.
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m512 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm512_set1_ps(x[j]);
+            m[j] = _mm512_add_ps(m[j], m[j]);
+        }
+
+        for (; i < ny16 * 16; i += 16) {
+            // Compute dot products
+            const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
+            __m512 dp = _mm512_mul_ps(m[0], v0);
+            for (size_t j = 1; j < DIM; j++) {
+                const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
+                dp = _mm512_fmadd_ps(m[j], vj, dp);
+            }
+
+            // Compute y^2 - (2 * x, y), which is sufficient for looking for the
+            // lowest distance.
+            // x^2 is the constant that can be avoided.
+            const __m512 distances =
+                    _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+            // Compare the new distances to the min distances
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
+
+            // Update min distances and indices with closest vectors if needed
+            min_distances =
+                    _mm512_mask_blend_ps(comparison, distances, min_distances);
+            min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
+                    comparison,
+                    _mm512_castsi512_ps(current_indices),
+                    _mm512_castsi512_ps(min_indices)));
+
+            // Update current indices values. Basically, +16 to each of the 16
+            // AVX-512 components.
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            // Scroll y and y_sqlen forward.
+            y += 16;
+            y_sqlen += 16;
+        }
+
+        // Dump values and find the minimum distance / minimum index
+        float min_distances_scalar[16];
+        uint32_t min_indices_scalar[16];
+        _mm512_storeu_ps(min_distances_scalar, min_distances);
+        _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // Process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+
+    return current_min_index;
+}
+
+#elif defined(__AVX2__)
+
 template <size_t DIM>
 size_t fvec_L2sqr_ny_nearest_y_transposed_D(
         float* distances_tmp_buffer,
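Unlike `fvec_L2sqr_ny_y_transposed_D`, this nearest-neighbor variant never adds `x_sqlen` back: since \(\|x\|^2\) is constant across all candidates,

\[ \arg\min_i \|x - y_i\|^2 = \arg\min_i \left( \|y_i\|^2 - 2\,\langle x, y_i \rangle \right), \]

which is exactly what the in-code comment "x^2 is the constant that can be avoided" refers to.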
@@ -1592,6 +2548,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
 
     return current_min_index;
 }
+
 #endif
 
 size_t fvec_L2sqr_ny_nearest_y_transposed(
@@ -1632,6 +2589,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
 
 float fvec_L1(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
+    // signmask used for absolute value
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
 
     while (d >= 8) {
@@ -1639,7 +2597,9 @@ float fvec_L1(const float* x, const float* y, size_t d) {
         x += 8;
         __m256 my = _mm256_loadu_ps(y);
         y += 8;
+        // subtract
         const __m256 a_m_b = _mm256_sub_ps(mx, my);
+        // find sum of absolute value of distances (manhattan distance)
         msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
         d -= 8;
     }
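The `signmask` trick annotated above works because IEEE-754 binary32 stores the sign in bit 31; clearing it is a branch-free absolute value. A scalar sketch of the same idea (the helper name is hypothetical, not part of the diff):

#include <cstdint>
#include <cstring>
// clearing bit 31 negates negative values and leaves positive
// values untouched, i.e. it computes fabsf without a branch.
inline float fabs_via_signmask(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof bits); // safe type-pun
    bits &= 0x7fffffffu;                 // clear the sign bit
    std::memcpy(&v, &bits, sizeof bits);
    return v;
}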
@@ -1672,6 +2632,7 @@ float fvec_L1(const float* x, const float* y, size_t d) {
 
 float fvec_Linf(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
+    // signmask used for absolute value
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
 
     while (d >= 8) {
@@ -1679,7 +2640,9 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
         x += 8;
         __m256 my = _mm256_loadu_ps(y);
         y += 8;
+        // subtract
         const __m256 a_m_b = _mm256_sub_ps(mx, my);
+        // find max of absolute value of distances (chebyshev distance)
         msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
         d -= 8;
     }
@@ -1720,6 +2683,441 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
1720
2683
|
return fvec_Linf_ref(x, y, d);
|
1721
2684
|
}
|
1722
2685
|
|
2686
|
+
#elif defined(__ARM_FEATURE_SVE)
|
2687
|
+
|
2688
|
+
struct ElementOpIP {
|
2689
|
+
static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) {
|
2690
|
+
return svmul_f32_x(pg, x, y);
|
2691
|
+
}
|
2692
|
+
static svfloat32_t merge(
|
2693
|
+
svbool_t pg,
|
2694
|
+
svfloat32_t z,
|
2695
|
+
svfloat32_t x,
|
2696
|
+
svfloat32_t y) {
|
2697
|
+
return svmla_f32_x(pg, z, x, y);
|
2698
|
+
}
|
2699
|
+
};
|
2700
|
+
+template <typename ElementOp>
+void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    size_t i = 0;
+    for (; i + lanes4 < ny; i += lanes4) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+        y0 = ElementOp::op(pg, x0, y0);
+        y1 = ElementOp::op(pg, x0, y1);
+        y2 = ElementOp::op(pg, x0, y2);
+        y3 = ElementOp::op(pg, x0, y3);
+        svst1_f32(pg, dis, y0);
+        svst1_f32(pg, dis + lanes, y1);
+        svst1_f32(pg, dis + lanes2, y2);
+        svst1_f32(pg, dis + lanes3, y3);
+        y += lanes4;
+        dis += lanes4;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
+    const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny);
+    const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny);
+    svfloat32_t y0 = svld1_f32(pg0, y);
+    svfloat32_t y1 = svld1_f32(pg1, y + lanes);
+    svfloat32_t y2 = svld1_f32(pg2, y + lanes2);
+    svfloat32_t y3 = svld1_f32(pg3, y + lanes3);
+    y0 = ElementOp::op(pg0, x0, y0);
+    y1 = ElementOp::op(pg1, x0, y1);
+    y2 = ElementOp::op(pg2, x0, y2);
+    y3 = ElementOp::op(pg3, x0, y3);
+    svst1_f32(pg0, dis, y0);
+    svst1_f32(pg1, dis + lanes, y1);
+    svst1_f32(pg2, dis + lanes2, y2);
+    svst1_f32(pg3, dis + lanes3, y3);
+}
+
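Note: the kernel above needs no scalar tail loop. The unrolled main loop runs under an all-true predicate, and the remainder is covered by per-chunk predicates from svwhilelt_b32_u64; a chunk whose start index is already at or past ny gets an all-false predicate, which turns its loads and stores into no-ops. What svwhilelt computes, in scalar form (illustrative only, hypothetical helper):

    #include <cstddef>
    #include <vector>

    // Scalar meaning of svwhilelt_b32_u64(i, n): lane k is active iff
    // i + k < n, so a chunk starting at or past n is entirely inactive.
    std::vector<bool> whilelt(std::size_t i, std::size_t n, std::size_t lanes) {
        std::vector<bool> pred(lanes);
        for (std::size_t k = 0; k < lanes; ++k) {
            pred[k] = (i + k) < n;
        }
        return pred;
    }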
+template <typename ElementOp>
+void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    const svfloat32_t x1 = svdup_n_f32(x[1]);
+    size_t i = 0;
+    for (; i + lanes2 < ny; i += lanes2) {
+        const svfloat32x2_t y0 = svld2_f32(pg, y);
+        const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2);
+        svfloat32_t y00 = svget2_f32(y0, 0);
+        const svfloat32_t y01 = svget2_f32(y0, 1);
+        svfloat32_t y10 = svget2_f32(y1, 0);
+        const svfloat32_t y11 = svget2_f32(y1, 1);
+        y00 = ElementOp::op(pg, x0, y00);
+        y10 = ElementOp::op(pg, x0, y10);
+        y00 = ElementOp::merge(pg, y00, x1, y01);
+        y10 = ElementOp::merge(pg, y10, x1, y11);
+        svst1_f32(pg, dis, y00);
+        svst1_f32(pg, dis + lanes, y10);
+        y += lanes4;
+        dis += lanes2;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
+    const svfloat32x2_t y0 = svld2_f32(pg0, y);
+    const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2);
+    svfloat32_t y00 = svget2_f32(y0, 0);
+    const svfloat32_t y01 = svget2_f32(y0, 1);
+    svfloat32_t y10 = svget2_f32(y1, 0);
+    const svfloat32_t y11 = svget2_f32(y1, 1);
+    y00 = ElementOp::op(pg0, x0, y00);
+    y10 = ElementOp::op(pg1, x0, y10);
+    y00 = ElementOp::merge(pg0, y00, x1, y01);
+    y10 = ElementOp::merge(pg1, y10, x1, y11);
+    svst1_f32(pg0, dis, y00);
+    svst1_f32(pg1, dis + lanes, y10);
+}
+
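Note: for d == 2 the structured load svld2_f32 de-interleaves consecutive (component 0, component 1) pairs into two separate vectors, which is why y advances by lanes4 floats per iteration while dis advances by lanes2 results. A scalar picture of the de-interleave, illustrative only:

    #include <cstddef>

    // Scalar picture of svld2_f32: n two-dimensional points stored as
    // interleaved pairs are split into two streams, one per component.
    void deinterleave2(const float* y, std::size_t n, float* c0, float* c1) {
        for (std::size_t i = 0; i < n; ++i) {
            c0[i] = y[2 * i];     // component 0 of point i
            c1[i] = y[2 * i + 1]; // component 1 of point i
        }
    }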
+template <typename ElementOp>
+void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    const svfloat32_t x1 = svdup_n_f32(x[1]);
+    const svfloat32_t x2 = svdup_n_f32(x[2]);
+    const svfloat32_t x3 = svdup_n_f32(x[3]);
+    size_t i = 0;
+    for (; i + lanes < ny; i += lanes) {
+        const svfloat32x4_t y0 = svld4_f32(pg, y);
+        svfloat32_t y00 = svget4_f32(y0, 0);
+        const svfloat32_t y01 = svget4_f32(y0, 1);
+        svfloat32_t y02 = svget4_f32(y0, 2);
+        const svfloat32_t y03 = svget4_f32(y0, 3);
+        y00 = ElementOp::op(pg, x0, y00);
+        y02 = ElementOp::op(pg, x2, y02);
+        y00 = ElementOp::merge(pg, y00, x1, y01);
+        y02 = ElementOp::merge(pg, y02, x3, y03);
+        y00 = svadd_f32_x(pg, y00, y02);
+        svst1_f32(pg, dis, y00);
+        y += lanes4;
+        dis += lanes;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svfloat32x4_t y0 = svld4_f32(pg0, y);
+    svfloat32_t y00 = svget4_f32(y0, 0);
+    const svfloat32_t y01 = svget4_f32(y0, 1);
+    svfloat32_t y02 = svget4_f32(y0, 2);
+    const svfloat32_t y03 = svget4_f32(y0, 3);
+    y00 = ElementOp::op(pg0, x0, y00);
+    y02 = ElementOp::op(pg0, x2, y02);
+    y00 = ElementOp::merge(pg0, y00, x1, y01);
+    y02 = ElementOp::merge(pg0, y02, x3, y03);
+    y00 = svadd_f32_x(pg0, y00, y02);
+    svst1_f32(pg0, dis, y00);
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes4 = lanes * 4;
+    const size_t lanes8 = lanes * 8;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    const svfloat32_t x1 = svdup_n_f32(x[1]);
+    const svfloat32_t x2 = svdup_n_f32(x[2]);
+    const svfloat32_t x3 = svdup_n_f32(x[3]);
+    const svfloat32_t x4 = svdup_n_f32(x[4]);
+    const svfloat32_t x5 = svdup_n_f32(x[5]);
+    const svfloat32_t x6 = svdup_n_f32(x[6]);
+    const svfloat32_t x7 = svdup_n_f32(x[7]);
+    size_t i = 0;
+    for (; i + lanes < ny; i += lanes) {
+        const svfloat32x4_t ya = svld4_f32(pg, y);
+        const svfloat32x4_t yb = svld4_f32(pg, y + lanes4);
+        const svfloat32_t ya0 = svget4_f32(ya, 0);
+        const svfloat32_t ya1 = svget4_f32(ya, 1);
+        const svfloat32_t ya2 = svget4_f32(ya, 2);
+        const svfloat32_t ya3 = svget4_f32(ya, 3);
+        const svfloat32_t yb0 = svget4_f32(yb, 0);
+        const svfloat32_t yb1 = svget4_f32(yb, 1);
+        const svfloat32_t yb2 = svget4_f32(yb, 2);
+        const svfloat32_t yb3 = svget4_f32(yb, 3);
+        svfloat32_t y0 = svuzp1(ya0, yb0);
+        const svfloat32_t y1 = svuzp1(ya1, yb1);
+        svfloat32_t y2 = svuzp1(ya2, yb2);
+        const svfloat32_t y3 = svuzp1(ya3, yb3);
+        svfloat32_t y4 = svuzp2(ya0, yb0);
+        const svfloat32_t y5 = svuzp2(ya1, yb1);
+        svfloat32_t y6 = svuzp2(ya2, yb2);
+        const svfloat32_t y7 = svuzp2(ya3, yb3);
+        y0 = ElementOp::op(pg, x0, y0);
+        y2 = ElementOp::op(pg, x2, y2);
+        y4 = ElementOp::op(pg, x4, y4);
+        y6 = ElementOp::op(pg, x6, y6);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        y2 = ElementOp::merge(pg, y2, x3, y3);
+        y4 = ElementOp::merge(pg, y4, x5, y5);
+        y6 = ElementOp::merge(pg, y6, x7, y7);
+        y0 = svadd_f32_x(pg, y0, y2);
+        y4 = svadd_f32_x(pg, y4, y6);
+        y0 = svadd_f32_x(pg, y0, y4);
+        svst1_f32(pg, dis, y0);
+        y += lanes8;
+        dis += lanes;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2);
+    const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2);
+    const svfloat32x4_t ya = svld4_f32(pga, y);
+    const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4);
+    const svfloat32_t ya0 = svget4_f32(ya, 0);
+    const svfloat32_t ya1 = svget4_f32(ya, 1);
+    const svfloat32_t ya2 = svget4_f32(ya, 2);
+    const svfloat32_t ya3 = svget4_f32(ya, 3);
+    const svfloat32_t yb0 = svget4_f32(yb, 0);
+    const svfloat32_t yb1 = svget4_f32(yb, 1);
+    const svfloat32_t yb2 = svget4_f32(yb, 2);
+    const svfloat32_t yb3 = svget4_f32(yb, 3);
+    svfloat32_t y0 = svuzp1(ya0, yb0);
+    const svfloat32_t y1 = svuzp1(ya1, yb1);
+    svfloat32_t y2 = svuzp1(ya2, yb2);
+    const svfloat32_t y3 = svuzp1(ya3, yb3);
+    svfloat32_t y4 = svuzp2(ya0, yb0);
+    const svfloat32_t y5 = svuzp2(ya1, yb1);
+    svfloat32_t y6 = svuzp2(ya2, yb2);
+    const svfloat32_t y7 = svuzp2(ya3, yb3);
+    y0 = ElementOp::op(pg0, x0, y0);
+    y2 = ElementOp::op(pg0, x2, y2);
+    y4 = ElementOp::op(pg0, x4, y4);
+    y6 = ElementOp::op(pg0, x6, y6);
+    y0 = ElementOp::merge(pg0, y0, x1, y1);
+    y2 = ElementOp::merge(pg0, y2, x3, y3);
+    y4 = ElementOp::merge(pg0, y4, x5, y5);
+    y6 = ElementOp::merge(pg0, y6, x7, y7);
+    y0 = svadd_f32_x(pg0, y0, y2);
+    y4 = svadd_f32_x(pg0, y4, y6);
+    y0 = svadd_f32_x(pg0, y0, y4);
+    svst1_f32(pg0, dis, y0);
+    y += lanes8;
+    dis += lanes;
+}
+
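Note: for d == 8 the kernel loads two groups with svld4_f32 and then separates even- and odd-indexed points with the svuzp1/svuzp2 unzip operations, so each of the eight components ends up in its own vector. The instruction semantics in scalar terms (illustrative only):

    #include <cstddef>

    // Scalar picture of svuzp1/svuzp2: concatenate vectors a and b, then
    // keep the even-indexed elements (uzp1) or the odd-indexed ones (uzp2).
    void unzip(const float* a, const float* b, std::size_t lanes,
               float* even, float* odd) {
        for (std::size_t k = 0; k < lanes; ++k) {
            even[k] = (2 * k < lanes) ? a[2 * k] : b[2 * k - lanes];
            odd[k] = (2 * k + 1 < lanes) ? a[2 * k + 1] : b[2 * k + 1 - lanes];
        }
    }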
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes1(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    size_t i = 0;
+    for (; i + 3 < ny; i += 4) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+        y += lanes4;
+        y0 = ElementOp::op(pg, x0, y0);
+        y1 = ElementOp::op(pg, x0, y1);
+        y2 = ElementOp::op(pg, x0, y2);
+        y3 = ElementOp::op(pg, x0, y3);
+        dis[i] = svaddv_f32(pg, y0);
+        dis[i + 1] = svaddv_f32(pg, y1);
+        dis[i + 2] = svaddv_f32(pg, y2);
+        dis[i + 3] = svaddv_f32(pg, y3);
+    }
+    for (; i < ny; ++i) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        y += lanes;
+        y0 = ElementOp::op(pg, x0, y0);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
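Note: the lanesN kernels handle dimensions that fill whole SVE registers (d equal to one to four times the vector length), so each point's distance is finished with the horizontal reduction svaddv_f32 rather than a lane-wise store. Its scalar meaning (illustrative only):

    #include <cstddef>

    // Scalar meaning of svaddv_f32: sum every active lane of a vector,
    // here modeled as an array of `lanes` floats.
    float horizontal_sum(const float* v, std::size_t lanes) {
        float acc = 0.0f;
        for (std::size_t k = 0; k < lanes; ++k) {
            acc += v[k];
        }
        return acc;
    }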
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes2(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+    size_t i = 0;
+    for (; i + 1 < ny; i += 2) {
+        svfloat32_t y00 = svld1_f32(pg, y);
+        const svfloat32_t y01 = svld1_f32(pg, y + lanes);
+        svfloat32_t y10 = svld1_f32(pg, y + lanes2);
+        const svfloat32_t y11 = svld1_f32(pg, y + lanes3);
+        y += lanes4;
+        y00 = ElementOp::op(pg, x0, y00);
+        y10 = ElementOp::op(pg, x0, y10);
+        y00 = ElementOp::merge(pg, y00, x1, y01);
+        y10 = ElementOp::merge(pg, y10, x1, y11);
+        dis[i] = svaddv_f32(pg, y00);
+        dis[i + 1] = svaddv_f32(pg, y10);
+    }
+    if (i < ny) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        y0 = ElementOp::op(pg, x0, y0);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes3(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+    const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
+    for (size_t i = 0; i < ny; ++i) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        y += lanes3;
+        y0 = ElementOp::op(pg, x0, y0);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        y0 = ElementOp::merge(pg, y0, x2, y2);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes4(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+    const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
+    const svfloat32_t x3 = svld1_f32(pg, x + lanes3);
+    for (size_t i = 0; i < ny; ++i) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        const svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+        y += lanes4;
+        y0 = ElementOp::op(pg, x0, y0);
+        y2 = ElementOp::op(pg, x2, y2);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        y2 = ElementOp::merge(pg, y2, x3, y3);
+        y0 = svadd_f32_x(pg, y0, y2);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    fvec_L2sqr_ny_ref(dis, x, y, d, ny);
+}
+
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+}
+
+float fvec_L1(const float* x, const float* y, size_t d) {
+    return fvec_L1_ref(x, y, d);
+}
+
+float fvec_Linf(const float* x, const float* y, size_t d) {
+    return fvec_Linf_ref(x, y, d);
+}
+
+void fvec_inner_products_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    switch (d) {
+        case 1:
+            fvec_op_ny_sve_d1<ElementOpIP>(dis, x, y, ny);
+            break;
+        case 2:
+            fvec_op_ny_sve_d2<ElementOpIP>(dis, x, y, ny);
+            break;
+        case 4:
+            fvec_op_ny_sve_d4<ElementOpIP>(dis, x, y, ny);
+            break;
+        case 8:
+            fvec_op_ny_sve_d8<ElementOpIP>(dis, x, y, ny);
+            break;
+        default:
+            if (d == lanes)
+                fvec_op_ny_sve_lanes1<ElementOpIP>(dis, x, y, ny);
+            else if (d == lanes * 2)
+                fvec_op_ny_sve_lanes2<ElementOpIP>(dis, x, y, ny);
+            else if (d == lanes * 3)
+                fvec_op_ny_sve_lanes3<ElementOpIP>(dis, x, y, ny);
+            else if (d == lanes * 4)
+                fvec_op_ny_sve_lanes4<ElementOpIP>(dis, x, y, ny);
+            else
+                fvec_inner_products_ny_ref(dis, x, y, d, ny);
+            break;
+    }
+}
+
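Note: the dispatcher above only takes an SVE fast path when d is 1, 2, 4, 8, or a multiple (up to 4x) of the runtime vector length svcntw(); every other dimension falls back to the reference loop. A hypothetical call site, assuming the public faiss header and namespace (illustrative only, not part of the diff):

    #include <cstddef>
    #include <vector>
    #include <faiss/utils/distances.h> // declares fvec_inner_products_ny

    // One query x against ny database vectors y (row-major, dimension d);
    // dis[j] receives the inner product of x with the j-th vector.
    void example() {
        const std::size_t d = 4, ny = 1000;
        std::vector<float> x(d, 1.0f);
        std::vector<float> y(d * ny, 0.5f);
        std::vector<float> dis(ny);
        faiss::fvec_inner_products_ny(dis.data(), x.data(), y.data(), d, ny);
    }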
 #elif defined(__aarch64__)

 // not optimized for ARM
@@ -1858,7 +3256,39 @@ void fvec_inner_products_ny(
         c[i] = a[i] + bf * b[i];
 }

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+static inline void fvec_madd_avx512(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    const size_t n16 = n / 16;
+    const size_t n_for_masking = n % 16;
+
+    const __m512 bfmm = _mm512_set1_ps(bf);
+
+    size_t idx = 0;
+    for (idx = 0; idx < n16 * 16; idx += 16) {
+        const __m512 ax = _mm512_loadu_ps(a + idx);
+        const __m512 bx = _mm512_loadu_ps(b + idx);
+        const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+        _mm512_storeu_ps(c + idx, abmul);
+    }
+
+    if (n_for_masking > 0) {
+        const __mmask16 mask = (1 << n_for_masking) - 1;
+
+        const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx);
+        const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
+        const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+        _mm512_mask_storeu_ps(c + idx, mask, abmul);
+    }
+}
+
+#elif defined(__AVX2__)
+
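Note: fvec_madd_avx512 also avoids a scalar tail loop: (1 << n_for_masking) - 1 builds a 16-bit mask with one set bit per leftover element, and the maskz load / masked store intrinsics touch only those lanes. The mask arithmetic in isolation (illustrative only):

    #include <cstdint>
    #include <cstdio>

    // For r leftover elements (0 < r < 16), (1 << r) - 1 sets the low r
    // bits, one per lane that still needs processing.
    int main() {
        for (unsigned r = 1; r < 16; ++r) {
            const std::uint16_t mask = static_cast<std::uint16_t>((1u << r) - 1);
            std::printf("r=%2u mask=0x%04x\n", r, static_cast<unsigned>(mask));
        }
        return 0;
    }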
 static inline void fvec_madd_avx2(
         const size_t n,
         const float* __restrict a,
@@ -1911,6 +3341,7 @@ static inline void fvec_madd_avx2(
         _mm256_maskstore_ps(c + idx, mask, abmul);
     }
 }
+
 #endif

 #ifdef __SSE3__
@@ -1936,7 +3367,9 @@ static inline void fvec_madd_avx2(
 }

 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
-#ifdef __AVX2__
+#ifdef __AVX512F__
+    fvec_madd_avx512(n, a, bf, b, c);
+#elif __AVX2__
     fvec_madd_avx2(n, a, bf, b, c);
 #else
     if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
@@ -1946,6 +3379,60 @@ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
 #endif
 }

+#elif defined(__ARM_FEATURE_SVE)
+
+void fvec_madd(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    const size_t lanes = static_cast<size_t>(svcntw());
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    size_t i = 0;
+    for (; i + lanes4 < n; i += lanes4) {
+        const auto mask = svptrue_b32();
+        const auto ai0 = svld1_f32(mask, a + i);
+        const auto ai1 = svld1_f32(mask, a + i + lanes);
+        const auto ai2 = svld1_f32(mask, a + i + lanes2);
+        const auto ai3 = svld1_f32(mask, a + i + lanes3);
+        const auto bi0 = svld1_f32(mask, b + i);
+        const auto bi1 = svld1_f32(mask, b + i + lanes);
+        const auto bi2 = svld1_f32(mask, b + i + lanes2);
+        const auto bi3 = svld1_f32(mask, b + i + lanes3);
+        const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf);
+        const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf);
+        const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf);
+        const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf);
+        svst1_f32(mask, c + i, ci0);
+        svst1_f32(mask, c + i + lanes, ci1);
+        svst1_f32(mask, c + i + lanes2, ci2);
+        svst1_f32(mask, c + i + lanes3, ci3);
+    }
+    const auto mask0 = svwhilelt_b32_u64(i, n);
+    const auto mask1 = svwhilelt_b32_u64(i + lanes, n);
+    const auto mask2 = svwhilelt_b32_u64(i + lanes2, n);
+    const auto mask3 = svwhilelt_b32_u64(i + lanes3, n);
+    const auto ai0 = svld1_f32(mask0, a + i);
+    const auto ai1 = svld1_f32(mask1, a + i + lanes);
+    const auto ai2 = svld1_f32(mask2, a + i + lanes2);
+    const auto ai3 = svld1_f32(mask3, a + i + lanes3);
+    const auto bi0 = svld1_f32(mask0, b + i);
+    const auto bi1 = svld1_f32(mask1, b + i + lanes);
+    const auto bi2 = svld1_f32(mask2, b + i + lanes2);
+    const auto bi3 = svld1_f32(mask3, b + i + lanes3);
+    const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf);
+    const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf);
+    const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf);
+    const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf);
+    svst1_f32(mask0, c + i, ci0);
+    svst1_f32(mask1, c + i + lanes, ci1);
+    svst1_f32(mask2, c + i + lanes2, ci2);
+    svst1_f32(mask3, c + i + lanes3, ci3);
+}
+
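Note: the AVX512, AVX2, and SVE fvec_madd paths all implement the same contract, visible in the context line c[i] = a[i] + bf * b[i] earlier in this diff. A plain reference loop that any of the SIMD paths can be checked against (illustrative only):

    #include <cstddef>

    // Reference semantics shared by every fvec_madd variant:
    // c[i] = a[i] + bf * b[i] for all i in [0, n).
    void fvec_madd_scalar(std::size_t n, const float* a, float bf,
                          const float* b, float* c) {
        for (std::size_t i = 0; i < n; ++i) {
            c[i] = a[i] + bf * b[i];
        }
    }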
 #elif defined(__aarch64__)

 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
@@ -2278,7 +3765,7 @@ void fvec_add(size_t d, const float* a, float b, float* c) {
     size_t i;
     simd8float32 bv(b);
     for (i = 0; i + 7 < d; i += 8) {
-        simd8float32 ci, ai
+        simd8float32 ci, ai;
         ai.loadu(a + i);
         ci = ai + bv;
         ci.storeu(c + i);