faiss 0.3.1 → 0.3.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -2
- data/vendor/faiss/faiss/AutoTune.h +3 -3
- data/vendor/faiss/faiss/Clustering.cpp +37 -6
- data/vendor/faiss/faiss/Clustering.h +12 -3
- data/vendor/faiss/faiss/IVFlib.cpp +6 -3
- data/vendor/faiss/faiss/IVFlib.h +2 -2
- data/vendor/faiss/faiss/Index.cpp +6 -2
- data/vendor/faiss/faiss/Index.h +30 -8
- data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
- data/vendor/faiss/faiss/IndexBinary.h +8 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
- data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
- data/vendor/faiss/faiss/IndexFastScan.h +11 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
- data/vendor/faiss/faiss/IndexFlat.h +2 -2
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
- data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
- data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
- data/vendor/faiss/faiss/IndexHNSW.h +54 -5
- data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
- data/vendor/faiss/faiss/IndexIDMap.h +5 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
- data/vendor/faiss/faiss/IndexIVF.h +13 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
- data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
- data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
- data/vendor/faiss/faiss/IndexLSH.h +2 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
- data/vendor/faiss/faiss/IndexLattice.h +5 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
- data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
- data/vendor/faiss/faiss/IndexNSG.h +3 -3
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
- data/vendor/faiss/faiss/IndexPQ.h +2 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
- data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
- data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
- data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
- data/vendor/faiss/faiss/IndexRefine.h +9 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/IndexShards.cpp +2 -2
- data/vendor/faiss/faiss/IndexShards.h +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
- data/vendor/faiss/faiss/MatrixStats.h +2 -2
- data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
- data/vendor/faiss/faiss/MetaIndexes.h +2 -2
- data/vendor/faiss/faiss/MetricType.h +9 -4
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +2 -2
- data/vendor/faiss/faiss/clone_index.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
- data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
- data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
- data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
- data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
- data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
- data/vendor/faiss/faiss/impl/FaissException.h +2 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
- data/vendor/faiss/faiss/impl/HNSW.h +55 -24
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
- data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
- data/vendor/faiss/faiss/impl/NSG.h +20 -8
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
- data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
- data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
- data/vendor/faiss/faiss/impl/io.cpp +15 -7
- data/vendor/faiss/faiss/impl/io.h +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +8 -9
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
- data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
- data/vendor/faiss/faiss/index_factory.cpp +51 -34
- data/vendor/faiss/faiss/index_factory.h +2 -2
- data/vendor/faiss/faiss/index_io.h +14 -7
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
- data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
- data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
- data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +107 -2
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
- data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +249 -90
- data/vendor/faiss/faiss/utils/distances.h +8 -8
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
- data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
- data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
- data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
- data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
- data/vendor/faiss/faiss/utils/fp16.h +2 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
- data/vendor/faiss/faiss/utils/hamming.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
- data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
- data/vendor/faiss/faiss/utils/random.cpp +45 -2
- data/vendor/faiss/faiss/utils/random.h +27 -2
- data/vendor/faiss/faiss/utils/simdlib.h +12 -3
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
- data/vendor/faiss/faiss/utils/sorting.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +17 -10
- data/vendor/faiss/faiss/utils/utils.h +7 -3
- metadata +22 -11
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
|
|
1
|
-
|
2
|
-
* Copyright (c)
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
3
|
*
|
4
4
|
* This source code is licensed under the MIT license found in the
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
@@ -23,10 +23,16 @@
|
|
23
23
|
#include <immintrin.h>
|
24
24
|
#endif
|
25
25
|
|
26
|
-
#
|
26
|
+
#if defined(__AVX512F__)
|
27
|
+
#include <faiss/utils/transpose/transpose-avx512-inl.h>
|
28
|
+
#elif defined(__AVX2__)
|
27
29
|
#include <faiss/utils/transpose/transpose-avx2-inl.h>
|
28
30
|
#endif
|
29
31
|
|
32
|
+
#ifdef __ARM_FEATURE_SVE
|
33
|
+
#include <arm_sve.h>
|
34
|
+
#endif
|
35
|
+
|
30
36
|
#ifdef __aarch64__
|
31
37
|
#include <arm_neon.h>
|
32
38
|
#endif
|
@@ -346,6 +352,14 @@ inline float horizontal_sum(const __m256 v) {
|
|
346
352
|
}
|
347
353
|
#endif
|
348
354
|
|
355
|
+
#ifdef __AVX512F__
|
356
|
+
/// helper function for AVX512
|
357
|
+
inline float horizontal_sum(const __m512 v) {
|
358
|
+
// performs better than adding the high and low parts
|
359
|
+
return _mm512_reduce_add_ps(v);
|
360
|
+
}
|
361
|
+
#endif
|
362
|
+
|
349
363
|
/// Function that does a component-wise operation between x and y
|
350
364
|
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
|
351
365
|
/// functions below
|
@@ -366,6 +380,13 @@ struct ElementOpL2 {
|
|
366
380
|
return _mm256_mul_ps(tmp, tmp);
|
367
381
|
}
|
368
382
|
#endif
|
383
|
+
|
384
|
+
#ifdef __AVX512F__
|
385
|
+
static __m512 op(__m512 x, __m512 y) {
|
386
|
+
__m512 tmp = _mm512_sub_ps(x, y);
|
387
|
+
return _mm512_mul_ps(tmp, tmp);
|
388
|
+
}
|
389
|
+
#endif
|
369
390
|
};
|
370
391
|
|
371
392
|
/// Function that does a component-wise operation between x and y
|
@@ -384,6 +405,12 @@ struct ElementOpIP {
|
|
384
405
|
return _mm256_mul_ps(x, y);
|
385
406
|
}
|
386
407
|
#endif
|
408
|
+
|
409
|
+
#ifdef __AVX512F__
|
410
|
+
static __m512 op(__m512 x, __m512 y) {
|
411
|
+
return _mm512_mul_ps(x, y);
|
412
|
+
}
|
413
|
+
#endif
|
387
414
|
};
|
388
415
|
|
389
416
|
template <class ElementOp>
|
@@ -426,7 +453,130 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
|
|
426
453
|
}
|
427
454
|
}
|
428
455
|
|
429
|
-
#
|
456
|
+
#if defined(__AVX512F__)
|
457
|
+
|
458
|
+
template <>
|
459
|
+
void fvec_op_ny_D2<ElementOpIP>(
|
460
|
+
float* dis,
|
461
|
+
const float* x,
|
462
|
+
const float* y,
|
463
|
+
size_t ny) {
|
464
|
+
const size_t ny16 = ny / 16;
|
465
|
+
size_t i = 0;
|
466
|
+
|
467
|
+
if (ny16 > 0) {
|
468
|
+
// process 16 D2-vectors per loop.
|
469
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
470
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
471
|
+
|
472
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
473
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
474
|
+
|
475
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
476
|
+
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
477
|
+
|
478
|
+
// load 16x2 matrix and transpose it in registers.
|
479
|
+
// the typical bottleneck is memory access, so
|
480
|
+
// let's trade instructions for the bandwidth.
|
481
|
+
|
482
|
+
__m512 v0;
|
483
|
+
__m512 v1;
|
484
|
+
|
485
|
+
transpose_16x2(
|
486
|
+
_mm512_loadu_ps(y + 0 * 16),
|
487
|
+
_mm512_loadu_ps(y + 1 * 16),
|
488
|
+
v0,
|
489
|
+
v1);
|
490
|
+
|
491
|
+
// compute distances (dot product)
|
492
|
+
__m512 distances = _mm512_mul_ps(m0, v0);
|
493
|
+
distances = _mm512_fmadd_ps(m1, v1, distances);
|
494
|
+
|
495
|
+
// store
|
496
|
+
_mm512_storeu_ps(dis + i, distances);
|
497
|
+
|
498
|
+
y += 32; // move to the next set of 16x2 elements
|
499
|
+
}
|
500
|
+
}
|
501
|
+
|
502
|
+
if (i < ny) {
|
503
|
+
// process leftovers
|
504
|
+
float x0 = x[0];
|
505
|
+
float x1 = x[1];
|
506
|
+
|
507
|
+
for (; i < ny; i++) {
|
508
|
+
float distance = x0 * y[0] + x1 * y[1];
|
509
|
+
y += 2;
|
510
|
+
dis[i] = distance;
|
511
|
+
}
|
512
|
+
}
|
513
|
+
}
|
514
|
+
|
515
|
+
template <>
|
516
|
+
void fvec_op_ny_D2<ElementOpL2>(
|
517
|
+
float* dis,
|
518
|
+
const float* x,
|
519
|
+
const float* y,
|
520
|
+
size_t ny) {
|
521
|
+
const size_t ny16 = ny / 16;
|
522
|
+
size_t i = 0;
|
523
|
+
|
524
|
+
if (ny16 > 0) {
|
525
|
+
// process 16 D2-vectors per loop.
|
526
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
527
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
528
|
+
|
529
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
530
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
531
|
+
|
532
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
533
|
+
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
534
|
+
|
535
|
+
// load 16x2 matrix and transpose it in registers.
|
536
|
+
// the typical bottleneck is memory access, so
|
537
|
+
// let's trade instructions for the bandwidth.
|
538
|
+
|
539
|
+
__m512 v0;
|
540
|
+
__m512 v1;
|
541
|
+
|
542
|
+
transpose_16x2(
|
543
|
+
_mm512_loadu_ps(y + 0 * 16),
|
544
|
+
_mm512_loadu_ps(y + 1 * 16),
|
545
|
+
v0,
|
546
|
+
v1);
|
547
|
+
|
548
|
+
// compute differences
|
549
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
550
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
551
|
+
|
552
|
+
// compute squares of differences
|
553
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
554
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
555
|
+
|
556
|
+
// store
|
557
|
+
_mm512_storeu_ps(dis + i, distances);
|
558
|
+
|
559
|
+
y += 32; // move to the next set of 16x2 elements
|
560
|
+
}
|
561
|
+
}
|
562
|
+
|
563
|
+
if (i < ny) {
|
564
|
+
// process leftovers
|
565
|
+
float x0 = x[0];
|
566
|
+
float x1 = x[1];
|
567
|
+
|
568
|
+
for (; i < ny; i++) {
|
569
|
+
float sub0 = x0 - y[0];
|
570
|
+
float sub1 = x1 - y[1];
|
571
|
+
float distance = sub0 * sub0 + sub1 * sub1;
|
572
|
+
|
573
|
+
y += 2;
|
574
|
+
dis[i] = distance;
|
575
|
+
}
|
576
|
+
}
|
577
|
+
}
|
578
|
+
|
579
|
+
#elif defined(__AVX2__)
|
430
580
|
|
431
581
|
template <>
|
432
582
|
void fvec_op_ny_D2<ElementOpIP>(
|
@@ -562,7 +712,137 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
562
712
|
}
|
563
713
|
}
|
564
714
|
|
565
|
-
#
|
715
|
+
#if defined(__AVX512F__)
|
716
|
+
|
717
|
+
template <>
|
718
|
+
void fvec_op_ny_D4<ElementOpIP>(
|
719
|
+
float* dis,
|
720
|
+
const float* x,
|
721
|
+
const float* y,
|
722
|
+
size_t ny) {
|
723
|
+
const size_t ny16 = ny / 16;
|
724
|
+
size_t i = 0;
|
725
|
+
|
726
|
+
if (ny16 > 0) {
|
727
|
+
// process 16 D4-vectors per loop.
|
728
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
729
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
730
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
731
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
732
|
+
|
733
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
734
|
+
// load 16x4 matrix and transpose it in registers.
|
735
|
+
// the typical bottleneck is memory access, so
|
736
|
+
// let's trade instructions for the bandwidth.
|
737
|
+
|
738
|
+
__m512 v0;
|
739
|
+
__m512 v1;
|
740
|
+
__m512 v2;
|
741
|
+
__m512 v3;
|
742
|
+
|
743
|
+
transpose_16x4(
|
744
|
+
_mm512_loadu_ps(y + 0 * 16),
|
745
|
+
_mm512_loadu_ps(y + 1 * 16),
|
746
|
+
_mm512_loadu_ps(y + 2 * 16),
|
747
|
+
_mm512_loadu_ps(y + 3 * 16),
|
748
|
+
v0,
|
749
|
+
v1,
|
750
|
+
v2,
|
751
|
+
v3);
|
752
|
+
|
753
|
+
// compute distances
|
754
|
+
__m512 distances = _mm512_mul_ps(m0, v0);
|
755
|
+
distances = _mm512_fmadd_ps(m1, v1, distances);
|
756
|
+
distances = _mm512_fmadd_ps(m2, v2, distances);
|
757
|
+
distances = _mm512_fmadd_ps(m3, v3, distances);
|
758
|
+
|
759
|
+
// store
|
760
|
+
_mm512_storeu_ps(dis + i, distances);
|
761
|
+
|
762
|
+
y += 64; // move to the next set of 16x4 elements
|
763
|
+
}
|
764
|
+
}
|
765
|
+
|
766
|
+
if (i < ny) {
|
767
|
+
// process leftovers
|
768
|
+
__m128 x0 = _mm_loadu_ps(x);
|
769
|
+
|
770
|
+
for (; i < ny; i++) {
|
771
|
+
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
772
|
+
y += 4;
|
773
|
+
dis[i] = horizontal_sum(accu);
|
774
|
+
}
|
775
|
+
}
|
776
|
+
}
|
777
|
+
|
778
|
+
template <>
|
779
|
+
void fvec_op_ny_D4<ElementOpL2>(
|
780
|
+
float* dis,
|
781
|
+
const float* x,
|
782
|
+
const float* y,
|
783
|
+
size_t ny) {
|
784
|
+
const size_t ny16 = ny / 16;
|
785
|
+
size_t i = 0;
|
786
|
+
|
787
|
+
if (ny16 > 0) {
|
788
|
+
// process 16 D4-vectors per loop.
|
789
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
790
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
791
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
792
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
793
|
+
|
794
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
795
|
+
// load 16x4 matrix and transpose it in registers.
|
796
|
+
// the typical bottleneck is memory access, so
|
797
|
+
// let's trade instructions for the bandwidth.
|
798
|
+
|
799
|
+
__m512 v0;
|
800
|
+
__m512 v1;
|
801
|
+
__m512 v2;
|
802
|
+
__m512 v3;
|
803
|
+
|
804
|
+
transpose_16x4(
|
805
|
+
_mm512_loadu_ps(y + 0 * 16),
|
806
|
+
_mm512_loadu_ps(y + 1 * 16),
|
807
|
+
_mm512_loadu_ps(y + 2 * 16),
|
808
|
+
_mm512_loadu_ps(y + 3 * 16),
|
809
|
+
v0,
|
810
|
+
v1,
|
811
|
+
v2,
|
812
|
+
v3);
|
813
|
+
|
814
|
+
// compute differences
|
815
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
816
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
817
|
+
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
818
|
+
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
819
|
+
|
820
|
+
// compute squares of differences
|
821
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
822
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
823
|
+
distances = _mm512_fmadd_ps(d2, d2, distances);
|
824
|
+
distances = _mm512_fmadd_ps(d3, d3, distances);
|
825
|
+
|
826
|
+
// store
|
827
|
+
_mm512_storeu_ps(dis + i, distances);
|
828
|
+
|
829
|
+
y += 64; // move to the next set of 16x4 elements
|
830
|
+
}
|
831
|
+
}
|
832
|
+
|
833
|
+
if (i < ny) {
|
834
|
+
// process leftovers
|
835
|
+
__m128 x0 = _mm_loadu_ps(x);
|
836
|
+
|
837
|
+
for (; i < ny; i++) {
|
838
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
839
|
+
y += 4;
|
840
|
+
dis[i] = horizontal_sum(accu);
|
841
|
+
}
|
842
|
+
}
|
843
|
+
}
|
844
|
+
|
845
|
+
#elif defined(__AVX2__)
|
566
846
|
|
567
847
|
template <>
|
568
848
|
void fvec_op_ny_D4<ElementOpIP>(
|
@@ -710,7 +990,181 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
710
990
|
}
|
711
991
|
}
|
712
992
|
|
713
|
-
#
|
993
|
+
#if defined(__AVX512F__)
|
994
|
+
|
995
|
+
template <>
|
996
|
+
void fvec_op_ny_D8<ElementOpIP>(
|
997
|
+
float* dis,
|
998
|
+
const float* x,
|
999
|
+
const float* y,
|
1000
|
+
size_t ny) {
|
1001
|
+
const size_t ny16 = ny / 16;
|
1002
|
+
size_t i = 0;
|
1003
|
+
|
1004
|
+
if (ny16 > 0) {
|
1005
|
+
// process 16 D16-vectors per loop.
|
1006
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
1007
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
1008
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
1009
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
1010
|
+
const __m512 m4 = _mm512_set1_ps(x[4]);
|
1011
|
+
const __m512 m5 = _mm512_set1_ps(x[5]);
|
1012
|
+
const __m512 m6 = _mm512_set1_ps(x[6]);
|
1013
|
+
const __m512 m7 = _mm512_set1_ps(x[7]);
|
1014
|
+
|
1015
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
1016
|
+
// load 16x8 matrix and transpose it in registers.
|
1017
|
+
// the typical bottleneck is memory access, so
|
1018
|
+
// let's trade instructions for the bandwidth.
|
1019
|
+
|
1020
|
+
__m512 v0;
|
1021
|
+
__m512 v1;
|
1022
|
+
__m512 v2;
|
1023
|
+
__m512 v3;
|
1024
|
+
__m512 v4;
|
1025
|
+
__m512 v5;
|
1026
|
+
__m512 v6;
|
1027
|
+
__m512 v7;
|
1028
|
+
|
1029
|
+
transpose_16x8(
|
1030
|
+
_mm512_loadu_ps(y + 0 * 16),
|
1031
|
+
_mm512_loadu_ps(y + 1 * 16),
|
1032
|
+
_mm512_loadu_ps(y + 2 * 16),
|
1033
|
+
_mm512_loadu_ps(y + 3 * 16),
|
1034
|
+
_mm512_loadu_ps(y + 4 * 16),
|
1035
|
+
_mm512_loadu_ps(y + 5 * 16),
|
1036
|
+
_mm512_loadu_ps(y + 6 * 16),
|
1037
|
+
_mm512_loadu_ps(y + 7 * 16),
|
1038
|
+
v0,
|
1039
|
+
v1,
|
1040
|
+
v2,
|
1041
|
+
v3,
|
1042
|
+
v4,
|
1043
|
+
v5,
|
1044
|
+
v6,
|
1045
|
+
v7);
|
1046
|
+
|
1047
|
+
// compute distances
|
1048
|
+
__m512 distances = _mm512_mul_ps(m0, v0);
|
1049
|
+
distances = _mm512_fmadd_ps(m1, v1, distances);
|
1050
|
+
distances = _mm512_fmadd_ps(m2, v2, distances);
|
1051
|
+
distances = _mm512_fmadd_ps(m3, v3, distances);
|
1052
|
+
distances = _mm512_fmadd_ps(m4, v4, distances);
|
1053
|
+
distances = _mm512_fmadd_ps(m5, v5, distances);
|
1054
|
+
distances = _mm512_fmadd_ps(m6, v6, distances);
|
1055
|
+
distances = _mm512_fmadd_ps(m7, v7, distances);
|
1056
|
+
|
1057
|
+
// store
|
1058
|
+
_mm512_storeu_ps(dis + i, distances);
|
1059
|
+
|
1060
|
+
y += 128; // 16 floats * 8 rows
|
1061
|
+
}
|
1062
|
+
}
|
1063
|
+
|
1064
|
+
if (i < ny) {
|
1065
|
+
// process leftovers
|
1066
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
1067
|
+
|
1068
|
+
for (; i < ny; i++) {
|
1069
|
+
__m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
|
1070
|
+
y += 8;
|
1071
|
+
dis[i] = horizontal_sum(accu);
|
1072
|
+
}
|
1073
|
+
}
|
1074
|
+
}
|
1075
|
+
|
1076
|
+
template <>
|
1077
|
+
void fvec_op_ny_D8<ElementOpL2>(
|
1078
|
+
float* dis,
|
1079
|
+
const float* x,
|
1080
|
+
const float* y,
|
1081
|
+
size_t ny) {
|
1082
|
+
const size_t ny16 = ny / 16;
|
1083
|
+
size_t i = 0;
|
1084
|
+
|
1085
|
+
if (ny16 > 0) {
|
1086
|
+
// process 16 D16-vectors per loop.
|
1087
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
1088
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
1089
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
1090
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
1091
|
+
const __m512 m4 = _mm512_set1_ps(x[4]);
|
1092
|
+
const __m512 m5 = _mm512_set1_ps(x[5]);
|
1093
|
+
const __m512 m6 = _mm512_set1_ps(x[6]);
|
1094
|
+
const __m512 m7 = _mm512_set1_ps(x[7]);
|
1095
|
+
|
1096
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
1097
|
+
// load 16x8 matrix and transpose it in registers.
|
1098
|
+
// the typical bottleneck is memory access, so
|
1099
|
+
// let's trade instructions for the bandwidth.
|
1100
|
+
|
1101
|
+
__m512 v0;
|
1102
|
+
__m512 v1;
|
1103
|
+
__m512 v2;
|
1104
|
+
__m512 v3;
|
1105
|
+
__m512 v4;
|
1106
|
+
__m512 v5;
|
1107
|
+
__m512 v6;
|
1108
|
+
__m512 v7;
|
1109
|
+
|
1110
|
+
transpose_16x8(
|
1111
|
+
_mm512_loadu_ps(y + 0 * 16),
|
1112
|
+
_mm512_loadu_ps(y + 1 * 16),
|
1113
|
+
_mm512_loadu_ps(y + 2 * 16),
|
1114
|
+
_mm512_loadu_ps(y + 3 * 16),
|
1115
|
+
_mm512_loadu_ps(y + 4 * 16),
|
1116
|
+
_mm512_loadu_ps(y + 5 * 16),
|
1117
|
+
_mm512_loadu_ps(y + 6 * 16),
|
1118
|
+
_mm512_loadu_ps(y + 7 * 16),
|
1119
|
+
v0,
|
1120
|
+
v1,
|
1121
|
+
v2,
|
1122
|
+
v3,
|
1123
|
+
v4,
|
1124
|
+
v5,
|
1125
|
+
v6,
|
1126
|
+
v7);
|
1127
|
+
|
1128
|
+
// compute differences
|
1129
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
1130
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
1131
|
+
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
1132
|
+
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
1133
|
+
const __m512 d4 = _mm512_sub_ps(m4, v4);
|
1134
|
+
const __m512 d5 = _mm512_sub_ps(m5, v5);
|
1135
|
+
const __m512 d6 = _mm512_sub_ps(m6, v6);
|
1136
|
+
const __m512 d7 = _mm512_sub_ps(m7, v7);
|
1137
|
+
|
1138
|
+
// compute squares of differences
|
1139
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
1140
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
1141
|
+
distances = _mm512_fmadd_ps(d2, d2, distances);
|
1142
|
+
distances = _mm512_fmadd_ps(d3, d3, distances);
|
1143
|
+
distances = _mm512_fmadd_ps(d4, d4, distances);
|
1144
|
+
distances = _mm512_fmadd_ps(d5, d5, distances);
|
1145
|
+
distances = _mm512_fmadd_ps(d6, d6, distances);
|
1146
|
+
distances = _mm512_fmadd_ps(d7, d7, distances);
|
1147
|
+
|
1148
|
+
// store
|
1149
|
+
_mm512_storeu_ps(dis + i, distances);
|
1150
|
+
|
1151
|
+
y += 128; // 16 floats * 8 rows
|
1152
|
+
}
|
1153
|
+
}
|
1154
|
+
|
1155
|
+
if (i < ny) {
|
1156
|
+
// process leftovers
|
1157
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
1158
|
+
|
1159
|
+
for (; i < ny; i++) {
|
1160
|
+
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
1161
|
+
y += 8;
|
1162
|
+
dis[i] = horizontal_sum(accu);
|
1163
|
+
}
|
1164
|
+
}
|
1165
|
+
}
|
1166
|
+
|
1167
|
+
#elif defined(__AVX2__)
|
714
1168
|
|
715
1169
|
template <>
|
716
1170
|
void fvec_op_ny_D8<ElementOpIP>(
|
@@ -955,7 +1409,83 @@ void fvec_inner_products_ny(
|
|
955
1409
|
#undef DISPATCH
|
956
1410
|
}
|
957
1411
|
|
958
|
-
#
|
1412
|
+
#if defined(__AVX512F__)
|
1413
|
+
|
1414
|
+
template <size_t DIM>
|
1415
|
+
void fvec_L2sqr_ny_y_transposed_D(
|
1416
|
+
float* distances,
|
1417
|
+
const float* x,
|
1418
|
+
const float* y,
|
1419
|
+
const float* y_sqlen,
|
1420
|
+
const size_t d_offset,
|
1421
|
+
size_t ny) {
|
1422
|
+
// current index being processed
|
1423
|
+
size_t i = 0;
|
1424
|
+
|
1425
|
+
// squared length of x
|
1426
|
+
float x_sqlen = 0;
|
1427
|
+
for (size_t j = 0; j < DIM; j++) {
|
1428
|
+
x_sqlen += x[j] * x[j];
|
1429
|
+
}
|
1430
|
+
|
1431
|
+
// process 16 vectors per loop
|
1432
|
+
const size_t ny16 = ny / 16;
|
1433
|
+
|
1434
|
+
if (ny16 > 0) {
|
1435
|
+
// m[i] = (2 * x[i], ... 2 * x[i])
|
1436
|
+
__m512 m[DIM];
|
1437
|
+
for (size_t j = 0; j < DIM; j++) {
|
1438
|
+
m[j] = _mm512_set1_ps(x[j]);
|
1439
|
+
m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
|
1440
|
+
}
|
1441
|
+
|
1442
|
+
__m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen);
|
1443
|
+
|
1444
|
+
for (; i < ny16 * 16; i += 16) {
|
1445
|
+
// Load vectors for 16 dimensions
|
1446
|
+
__m512 v[DIM];
|
1447
|
+
for (size_t j = 0; j < DIM; j++) {
|
1448
|
+
v[j] = _mm512_loadu_ps(y + j * d_offset);
|
1449
|
+
}
|
1450
|
+
|
1451
|
+
// Compute dot products
|
1452
|
+
__m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm);
|
1453
|
+
for (size_t j = 1; j < DIM; j++) {
|
1454
|
+
dp = _mm512_fnmadd_ps(m[j], v[j], dp);
|
1455
|
+
}
|
1456
|
+
|
1457
|
+
// Compute y^2 - (2 * x, y) + x^2
|
1458
|
+
__m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);
|
1459
|
+
|
1460
|
+
_mm512_storeu_ps(distances + i, distances_v);
|
1461
|
+
|
1462
|
+
// Scroll y and y_sqlen forward
|
1463
|
+
y += 16;
|
1464
|
+
y_sqlen += 16;
|
1465
|
+
}
|
1466
|
+
}
|
1467
|
+
|
1468
|
+
if (i < ny) {
|
1469
|
+
// Process leftovers
|
1470
|
+
for (; i < ny; i++) {
|
1471
|
+
float dp = 0;
|
1472
|
+
for (size_t j = 0; j < DIM; j++) {
|
1473
|
+
dp += x[j] * y[j * d_offset];
|
1474
|
+
}
|
1475
|
+
|
1476
|
+
// Compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
1477
|
+
// lowest distance.
|
1478
|
+
const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
|
1479
|
+
distances[i] = distance;
|
1480
|
+
|
1481
|
+
y += 1;
|
1482
|
+
y_sqlen += 1;
|
1483
|
+
}
|
1484
|
+
}
|
1485
|
+
}
|
1486
|
+
|
1487
|
+
#elif defined(__AVX2__)
|
1488
|
+
|
959
1489
|
template <size_t DIM>
|
960
1490
|
void fvec_L2sqr_ny_y_transposed_D(
|
961
1491
|
float* distances,
|
@@ -1014,58 +1544,368 @@ void fvec_L2sqr_ny_y_transposed_D(
|
|
1014
1544
|
}
|
1015
1545
|
|
1016
1546
|
if (i < ny) {
|
1017
|
-
// process leftovers
|
1018
|
-
for (; i < ny; i++) {
|
1019
|
-
float dp = 0;
|
1020
|
-
for (size_t j = 0; j < DIM; j++) {
|
1021
|
-
dp += x[j] * y[j * d_offset];
|
1022
|
-
}
|
1547
|
+
// process leftovers
|
1548
|
+
for (; i < ny; i++) {
|
1549
|
+
float dp = 0;
|
1550
|
+
for (size_t j = 0; j < DIM; j++) {
|
1551
|
+
dp += x[j] * y[j * d_offset];
|
1552
|
+
}
|
1553
|
+
|
1554
|
+
// compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
1555
|
+
// lowest distance.
|
1556
|
+
const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
|
1557
|
+
distances[i] = distance;
|
1558
|
+
|
1559
|
+
y += 1;
|
1560
|
+
y_sqlen += 1;
|
1561
|
+
}
|
1562
|
+
}
|
1563
|
+
}
|
1564
|
+
|
1565
|
+
#endif
|
1566
|
+
|
1567
|
+
void fvec_L2sqr_ny_transposed(
|
1568
|
+
float* dis,
|
1569
|
+
const float* x,
|
1570
|
+
const float* y,
|
1571
|
+
const float* y_sqlen,
|
1572
|
+
size_t d,
|
1573
|
+
size_t d_offset,
|
1574
|
+
size_t ny) {
|
1575
|
+
// optimized for a few special cases
|
1576
|
+
|
1577
|
+
#ifdef __AVX2__
|
1578
|
+
#define DISPATCH(dval) \
|
1579
|
+
case dval: \
|
1580
|
+
return fvec_L2sqr_ny_y_transposed_D<dval>( \
|
1581
|
+
dis, x, y, y_sqlen, d_offset, ny);
|
1582
|
+
|
1583
|
+
switch (d) {
|
1584
|
+
DISPATCH(1)
|
1585
|
+
DISPATCH(2)
|
1586
|
+
DISPATCH(4)
|
1587
|
+
DISPATCH(8)
|
1588
|
+
default:
|
1589
|
+
return fvec_L2sqr_ny_y_transposed_ref(
|
1590
|
+
dis, x, y, y_sqlen, d, d_offset, ny);
|
1591
|
+
}
|
1592
|
+
#undef DISPATCH
|
1593
|
+
#else
|
1594
|
+
// non-AVX2 case
|
1595
|
+
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
1596
|
+
#endif
|
1597
|
+
}
|
1598
|
+
|
1599
|
+
#if defined(__AVX512F__)
|
1600
|
+
|
1601
|
+
size_t fvec_L2sqr_ny_nearest_D2(
|
1602
|
+
float* distances_tmp_buffer,
|
1603
|
+
const float* x,
|
1604
|
+
const float* y,
|
1605
|
+
size_t ny) {
|
1606
|
+
// this implementation does not use distances_tmp_buffer.
|
1607
|
+
|
1608
|
+
size_t i = 0;
|
1609
|
+
float current_min_distance = HUGE_VALF;
|
1610
|
+
size_t current_min_index = 0;
|
1611
|
+
|
1612
|
+
const size_t ny16 = ny / 16;
|
1613
|
+
if (ny16 > 0) {
|
1614
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
1615
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
1616
|
+
|
1617
|
+
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
1618
|
+
__m512i min_indices = _mm512_set1_epi32(0);
|
1619
|
+
|
1620
|
+
__m512i current_indices = _mm512_setr_epi32(
|
1621
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
1622
|
+
const __m512i indices_increment = _mm512_set1_epi32(16);
|
1623
|
+
|
1624
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
1625
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
1626
|
+
|
1627
|
+
for (; i < ny16 * 16; i += 16) {
|
1628
|
+
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
1629
|
+
|
1630
|
+
__m512 v0;
|
1631
|
+
__m512 v1;
|
1632
|
+
|
1633
|
+
transpose_16x2(
|
1634
|
+
_mm512_loadu_ps(y + 0 * 16),
|
1635
|
+
_mm512_loadu_ps(y + 1 * 16),
|
1636
|
+
v0,
|
1637
|
+
v1);
|
1638
|
+
|
1639
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
1640
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
1641
|
+
|
1642
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
1643
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
1644
|
+
|
1645
|
+
__mmask16 comparison =
|
1646
|
+
_mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
|
1647
|
+
|
1648
|
+
min_distances = _mm512_min_ps(distances, min_distances);
|
1649
|
+
min_indices = _mm512_mask_blend_epi32(
|
1650
|
+
comparison, min_indices, current_indices);
|
1651
|
+
|
1652
|
+
current_indices =
|
1653
|
+
_mm512_add_epi32(current_indices, indices_increment);
|
1654
|
+
|
1655
|
+
y += 32;
|
1656
|
+
}
|
1657
|
+
|
1658
|
+
alignas(64) float min_distances_scalar[16];
|
1659
|
+
alignas(64) uint32_t min_indices_scalar[16];
|
1660
|
+
_mm512_store_ps(min_distances_scalar, min_distances);
|
1661
|
+
_mm512_store_epi32(min_indices_scalar, min_indices);
|
1662
|
+
|
1663
|
+
for (size_t j = 0; j < 16; j++) {
|
1664
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
1665
|
+
current_min_distance = min_distances_scalar[j];
|
1666
|
+
current_min_index = min_indices_scalar[j];
|
1667
|
+
}
|
1668
|
+
}
|
1669
|
+
}
|
1670
|
+
|
1671
|
+
if (i < ny) {
|
1672
|
+
float x0 = x[0];
|
1673
|
+
float x1 = x[1];
|
1674
|
+
|
1675
|
+
for (; i < ny; i++) {
|
1676
|
+
float sub0 = x0 - y[0];
|
1677
|
+
float sub1 = x1 - y[1];
|
1678
|
+
float distance = sub0 * sub0 + sub1 * sub1;
|
1679
|
+
|
1680
|
+
y += 2;
|
1681
|
+
|
1682
|
+
if (current_min_distance > distance) {
|
1683
|
+
current_min_distance = distance;
|
1684
|
+
current_min_index = i;
|
1685
|
+
}
|
1686
|
+
}
|
1687
|
+
}
|
1688
|
+
|
1689
|
+
return current_min_index;
|
1690
|
+
}
|
1691
|
+
|
1692
|
+
size_t fvec_L2sqr_ny_nearest_D4(
|
1693
|
+
float* distances_tmp_buffer,
|
1694
|
+
const float* x,
|
1695
|
+
const float* y,
|
1696
|
+
size_t ny) {
|
1697
|
+
// this implementation does not use distances_tmp_buffer.
|
1698
|
+
|
1699
|
+
size_t i = 0;
|
1700
|
+
float current_min_distance = HUGE_VALF;
|
1701
|
+
size_t current_min_index = 0;
|
1702
|
+
|
1703
|
+
const size_t ny16 = ny / 16;
|
1704
|
+
|
1705
|
+
if (ny16 > 0) {
|
1706
|
+
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
1707
|
+
__m512i min_indices = _mm512_set1_epi32(0);
|
1708
|
+
|
1709
|
+
__m512i current_indices = _mm512_setr_epi32(
|
1710
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
1711
|
+
const __m512i indices_increment = _mm512_set1_epi32(16);
|
1712
|
+
|
1713
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
1714
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
1715
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
1716
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
1717
|
+
|
1718
|
+
for (; i < ny16 * 16; i += 16) {
|
1719
|
+
__m512 v0;
|
1720
|
+
__m512 v1;
|
1721
|
+
__m512 v2;
|
1722
|
+
__m512 v3;
|
1723
|
+
|
1724
|
+
transpose_16x4(
|
1725
|
+
_mm512_loadu_ps(y + 0 * 16),
|
1726
|
+
_mm512_loadu_ps(y + 1 * 16),
|
1727
|
+
_mm512_loadu_ps(y + 2 * 16),
|
1728
|
+
_mm512_loadu_ps(y + 3 * 16),
|
1729
|
+
v0,
|
1730
|
+
v1,
|
1731
|
+
v2,
|
1732
|
+
v3);
|
1733
|
+
|
1734
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
1735
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
1736
|
+
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
1737
|
+
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
1738
|
+
|
1739
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
1740
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
1741
|
+
distances = _mm512_fmadd_ps(d2, d2, distances);
|
1742
|
+
distances = _mm512_fmadd_ps(d3, d3, distances);
|
1743
|
+
|
1744
|
+
__mmask16 comparison =
|
1745
|
+
_mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
|
1746
|
+
|
1747
|
+
min_distances = _mm512_min_ps(distances, min_distances);
|
1748
|
+
min_indices = _mm512_mask_blend_epi32(
|
1749
|
+
comparison, min_indices, current_indices);
|
1750
|
+
|
1751
|
+
current_indices =
|
1752
|
+
_mm512_add_epi32(current_indices, indices_increment);
|
1753
|
+
|
1754
|
+
y += 64;
|
1755
|
+
}
|
1756
|
+
|
1757
|
+
alignas(64) float min_distances_scalar[16];
|
1758
|
+
alignas(64) uint32_t min_indices_scalar[16];
|
1759
|
+
_mm512_store_ps(min_distances_scalar, min_distances);
|
1760
|
+
_mm512_store_epi32(min_indices_scalar, min_indices);
|
1761
|
+
|
1762
|
+
for (size_t j = 0; j < 16; j++) {
|
1763
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
1764
|
+
current_min_distance = min_distances_scalar[j];
|
1765
|
+
current_min_index = min_indices_scalar[j];
|
1766
|
+
}
|
1767
|
+
}
|
1768
|
+
}
|
1769
|
+
|
1770
|
+
if (i < ny) {
|
1771
|
+
__m128 x0 = _mm_loadu_ps(x);
|
1772
|
+
|
1773
|
+
for (; i < ny; i++) {
|
1774
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
1775
|
+
y += 4;
|
1776
|
+
const float distance = horizontal_sum(accu);
|
1777
|
+
|
1778
|
+
if (current_min_distance > distance) {
|
1779
|
+
current_min_distance = distance;
|
1780
|
+
current_min_index = i;
|
1781
|
+
}
|
1782
|
+
}
|
1783
|
+
}
|
1784
|
+
|
1785
|
+
return current_min_index;
|
1786
|
+
}
|
1787
|
+
|
1788
|
+
size_t fvec_L2sqr_ny_nearest_D8(
|
1789
|
+
float* distances_tmp_buffer,
|
1790
|
+
const float* x,
|
1791
|
+
const float* y,
|
1792
|
+
size_t ny) {
|
1793
|
+
// this implementation does not use distances_tmp_buffer.
|
1794
|
+
|
1795
|
+
size_t i = 0;
|
1796
|
+
float current_min_distance = HUGE_VALF;
|
1797
|
+
size_t current_min_index = 0;
|
1798
|
+
|
1799
|
+
const size_t ny16 = ny / 16;
|
1800
|
+
if (ny16 > 0) {
|
1801
|
+
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
1802
|
+
__m512i min_indices = _mm512_set1_epi32(0);
|
1803
|
+
|
1804
|
+
__m512i current_indices = _mm512_setr_epi32(
|
1805
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
1806
|
+
const __m512i indices_increment = _mm512_set1_epi32(16);
|
1807
|
+
|
1808
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
1809
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
1810
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
1811
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
1812
|
+
|
1813
|
+
const __m512 m4 = _mm512_set1_ps(x[4]);
|
1814
|
+
const __m512 m5 = _mm512_set1_ps(x[5]);
|
1815
|
+
const __m512 m6 = _mm512_set1_ps(x[6]);
|
1816
|
+
const __m512 m7 = _mm512_set1_ps(x[7]);
|
1817
|
+
|
1818
|
+
for (; i < ny16 * 16; i += 16) {
|
1819
|
+
__m512 v0;
|
1820
|
+
__m512 v1;
|
1821
|
+
__m512 v2;
|
1822
|
+
__m512 v3;
|
1823
|
+
__m512 v4;
|
1824
|
+
__m512 v5;
|
1825
|
+
__m512 v6;
|
1826
|
+
__m512 v7;
|
1827
|
+
|
1828
|
+
transpose_16x8(
|
1829
|
+
_mm512_loadu_ps(y + 0 * 16),
|
1830
|
+
_mm512_loadu_ps(y + 1 * 16),
|
1831
|
+
_mm512_loadu_ps(y + 2 * 16),
|
1832
|
+
_mm512_loadu_ps(y + 3 * 16),
|
1833
|
+
_mm512_loadu_ps(y + 4 * 16),
|
1834
|
+
_mm512_loadu_ps(y + 5 * 16),
|
1835
|
+
_mm512_loadu_ps(y + 6 * 16),
|
1836
|
+
_mm512_loadu_ps(y + 7 * 16),
|
1837
|
+
v0,
|
1838
|
+
v1,
|
1839
|
+
v2,
|
1840
|
+
v3,
|
1841
|
+
v4,
|
1842
|
+
v5,
|
1843
|
+
v6,
|
1844
|
+
v7);
|
1845
|
+
|
1846
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
1847
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
1848
|
+
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
1849
|
+
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
1850
|
+
const __m512 d4 = _mm512_sub_ps(m4, v4);
|
1851
|
+
const __m512 d5 = _mm512_sub_ps(m5, v5);
|
1852
|
+
const __m512 d6 = _mm512_sub_ps(m6, v6);
|
1853
|
+
const __m512 d7 = _mm512_sub_ps(m7, v7);
|
1854
|
+
|
1855
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
1856
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
1857
|
+
distances = _mm512_fmadd_ps(d2, d2, distances);
|
1858
|
+
distances = _mm512_fmadd_ps(d3, d3, distances);
|
1859
|
+
distances = _mm512_fmadd_ps(d4, d4, distances);
|
1860
|
+
distances = _mm512_fmadd_ps(d5, d5, distances);
|
1861
|
+
distances = _mm512_fmadd_ps(d6, d6, distances);
|
1862
|
+
distances = _mm512_fmadd_ps(d7, d7, distances);
|
1863
|
+
|
1864
|
+
__mmask16 comparison =
|
1865
|
+
_mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
|
1866
|
+
|
1867
|
+
min_distances = _mm512_min_ps(distances, min_distances);
|
1868
|
+
min_indices = _mm512_mask_blend_epi32(
|
1869
|
+
comparison, min_indices, current_indices);
|
1870
|
+
|
1871
|
+
current_indices =
|
1872
|
+
_mm512_add_epi32(current_indices, indices_increment);
|
1873
|
+
|
1874
|
+
y += 128;
|
1875
|
+
}
|
1876
|
+
|
1877
|
+
alignas(64) float min_distances_scalar[16];
|
1878
|
+
alignas(64) uint32_t min_indices_scalar[16];
|
1879
|
+
_mm512_store_ps(min_distances_scalar, min_distances);
|
1880
|
+
_mm512_store_epi32(min_indices_scalar, min_indices);
|
1881
|
+
|
1882
|
+
for (size_t j = 0; j < 16; j++) {
|
1883
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
1884
|
+
current_min_distance = min_distances_scalar[j];
|
1885
|
+
current_min_index = min_indices_scalar[j];
|
1886
|
+
}
|
1887
|
+
}
|
1888
|
+
}
|
1889
|
+
|
1890
|
+
if (i < ny) {
|
1891
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
1023
1892
|
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1893
|
+
for (; i < ny; i++) {
|
1894
|
+
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
1895
|
+
y += 8;
|
1896
|
+
const float distance = horizontal_sum(accu);
|
1028
1897
|
|
1029
|
-
|
1030
|
-
|
1898
|
+
if (current_min_distance > distance) {
|
1899
|
+
current_min_distance = distance;
|
1900
|
+
current_min_index = i;
|
1901
|
+
}
|
1031
1902
|
}
|
1032
1903
|
}
|
1033
|
-
}
|
1034
|
-
#endif
|
1035
|
-
|
1036
|
-
void fvec_L2sqr_ny_transposed(
|
1037
|
-
float* dis,
|
1038
|
-
const float* x,
|
1039
|
-
const float* y,
|
1040
|
-
const float* y_sqlen,
|
1041
|
-
size_t d,
|
1042
|
-
size_t d_offset,
|
1043
|
-
size_t ny) {
|
1044
|
-
// optimized for a few special cases
|
1045
|
-
|
1046
|
-
#ifdef __AVX2__
|
1047
|
-
#define DISPATCH(dval) \
|
1048
|
-
case dval: \
|
1049
|
-
return fvec_L2sqr_ny_y_transposed_D<dval>( \
|
1050
|
-
dis, x, y, y_sqlen, d_offset, ny);
|
1051
1904
|
|
1052
|
-
|
1053
|
-
DISPATCH(1)
|
1054
|
-
DISPATCH(2)
|
1055
|
-
DISPATCH(4)
|
1056
|
-
DISPATCH(8)
|
1057
|
-
default:
|
1058
|
-
return fvec_L2sqr_ny_y_transposed_ref(
|
1059
|
-
dis, x, y, y_sqlen, d, d_offset, ny);
|
1060
|
-
}
|
1061
|
-
#undef DISPATCH
|
1062
|
-
#else
|
1063
|
-
// non-AVX2 case
|
1064
|
-
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
1065
|
-
#endif
|
1905
|
+
return current_min_index;
|
1066
1906
|
}
|
1067
1907
|
|
1068
|
-
#
|
1908
|
+
#elif defined(__AVX2__)
|
1069
1909
|
|
1070
1910
|
size_t fvec_L2sqr_ny_nearest_D2(
|
1071
1911
|
float* distances_tmp_buffer,
|
@@ -1476,7 +2316,123 @@ size_t fvec_L2sqr_ny_nearest(
|
|
1476
2316
|
#undef DISPATCH
|
1477
2317
|
}
|
1478
2318
|
|
1479
|
-
#
|
2319
|
+
#if defined(__AVX512F__)
|
2320
|
+
|
2321
|
+
template <size_t DIM>
|
2322
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
|
2323
|
+
float* distances_tmp_buffer,
|
2324
|
+
const float* x,
|
2325
|
+
const float* y,
|
2326
|
+
const float* y_sqlen,
|
2327
|
+
const size_t d_offset,
|
2328
|
+
size_t ny) {
|
2329
|
+
// This implementation does not use distances_tmp_buffer.
|
2330
|
+
|
2331
|
+
// Current index being processed
|
2332
|
+
size_t i = 0;
|
2333
|
+
|
2334
|
+
// Min distance and the index of the closest vector so far
|
2335
|
+
float current_min_distance = HUGE_VALF;
|
2336
|
+
size_t current_min_index = 0;
|
2337
|
+
|
2338
|
+
// Process 16 vectors per loop
|
2339
|
+
const size_t ny16 = ny / 16;
|
2340
|
+
|
2341
|
+
if (ny16 > 0) {
|
2342
|
+
// Track min distance and the closest vector independently
|
2343
|
+
// for each of 16 AVX-512 components.
|
2344
|
+
__m512 min_distances = _mm512_set1_ps(HUGE_VALF);
|
2345
|
+
__m512i min_indices = _mm512_set1_epi32(0);
|
2346
|
+
|
2347
|
+
__m512i current_indices = _mm512_setr_epi32(
|
2348
|
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
|
2349
|
+
const __m512i indices_increment = _mm512_set1_epi32(16);
|
2350
|
+
|
2351
|
+
// m[i] = (2 * x[i], ... 2 * x[i])
|
2352
|
+
__m512 m[DIM];
|
2353
|
+
for (size_t j = 0; j < DIM; j++) {
|
2354
|
+
m[j] = _mm512_set1_ps(x[j]);
|
2355
|
+
m[j] = _mm512_add_ps(m[j], m[j]);
|
2356
|
+
}
|
2357
|
+
|
2358
|
+
for (; i < ny16 * 16; i += 16) {
|
2359
|
+
// Compute dot products
|
2360
|
+
const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
|
2361
|
+
__m512 dp = _mm512_mul_ps(m[0], v0);
|
2362
|
+
for (size_t j = 1; j < DIM; j++) {
|
2363
|
+
const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
|
2364
|
+
dp = _mm512_fmadd_ps(m[j], vj, dp);
|
2365
|
+
}
|
2366
|
+
|
2367
|
+
// Compute y^2 - (2 * x, y), which is sufficient for looking for the
|
2368
|
+
// lowest distance.
|
2369
|
+
// x^2 is the constant that can be avoided.
|
2370
|
+
const __m512 distances =
|
2371
|
+
_mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);
|
2372
|
+
|
2373
|
+
// Compare the new distances to the min distances
|
2374
|
+
__mmask16 comparison =
|
2375
|
+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
|
2376
|
+
|
2377
|
+
// Update min distances and indices with closest vectors if needed
|
2378
|
+
min_distances =
|
2379
|
+
_mm512_mask_blend_ps(comparison, distances, min_distances);
|
2380
|
+
min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
|
2381
|
+
comparison,
|
2382
|
+
_mm512_castsi512_ps(current_indices),
|
2383
|
+
_mm512_castsi512_ps(min_indices)));
|
2384
|
+
|
2385
|
+
// Update current indices values. Basically, +16 to each of the 16
|
2386
|
+
// AVX-512 components.
|
2387
|
+
current_indices =
|
2388
|
+
_mm512_add_epi32(current_indices, indices_increment);
|
2389
|
+
|
2390
|
+
// Scroll y and y_sqlen forward.
|
2391
|
+
y += 16;
|
2392
|
+
y_sqlen += 16;
|
2393
|
+
}
|
2394
|
+
|
2395
|
+
// Dump values and find the minimum distance / minimum index
|
2396
|
+
float min_distances_scalar[16];
|
2397
|
+
uint32_t min_indices_scalar[16];
|
2398
|
+
_mm512_storeu_ps(min_distances_scalar, min_distances);
|
2399
|
+
_mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);
|
2400
|
+
|
2401
|
+
for (size_t j = 0; j < 16; j++) {
|
2402
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
2403
|
+
current_min_distance = min_distances_scalar[j];
|
2404
|
+
current_min_index = min_indices_scalar[j];
|
2405
|
+
}
|
2406
|
+
}
|
2407
|
+
}
|
2408
|
+
|
2409
|
+
if (i < ny) {
|
2410
|
+
// Process leftovers
|
2411
|
+
for (; i < ny; i++) {
|
2412
|
+
float dp = 0;
|
2413
|
+
for (size_t j = 0; j < DIM; j++) {
|
2414
|
+
dp += x[j] * y[j * d_offset];
|
2415
|
+
}
|
2416
|
+
|
2417
|
+
// Compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
2418
|
+
// lowest distance.
|
2419
|
+
const float distance = y_sqlen[0] - 2 * dp;
|
2420
|
+
|
2421
|
+
if (current_min_distance > distance) {
|
2422
|
+
current_min_distance = distance;
|
2423
|
+
current_min_index = i;
|
2424
|
+
}
|
2425
|
+
|
2426
|
+
y += 1;
|
2427
|
+
y_sqlen += 1;
|
2428
|
+
}
|
2429
|
+
}
|
2430
|
+
|
2431
|
+
return current_min_index;
|
2432
|
+
}
|
2433
|
+
|
2434
|
+
#elif defined(__AVX2__)
|
2435
|
+
|
1480
2436
|
template <size_t DIM>
|
1481
2437
|
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
|
1482
2438
|
float* distances_tmp_buffer,
|
@@ -1592,6 +2548,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
|
|
1592
2548
|
|
1593
2549
|
return current_min_index;
|
1594
2550
|
}
|
2551
|
+
|
1595
2552
|
#endif
|
1596
2553
|
|
1597
2554
|
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
@@ -1632,6 +2589,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
1632
2589
|
|
1633
2590
|
float fvec_L1(const float* x, const float* y, size_t d) {
|
1634
2591
|
__m256 msum1 = _mm256_setzero_ps();
|
2592
|
+
// signmask used for absolute value
|
1635
2593
|
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
1636
2594
|
|
1637
2595
|
while (d >= 8) {
|
@@ -1639,7 +2597,9 @@ float fvec_L1(const float* x, const float* y, size_t d) {
|
|
1639
2597
|
x += 8;
|
1640
2598
|
__m256 my = _mm256_loadu_ps(y);
|
1641
2599
|
y += 8;
|
2600
|
+
// subtract
|
1642
2601
|
const __m256 a_m_b = _mm256_sub_ps(mx, my);
|
2602
|
+
// find sum of absolute value of distances (manhattan distance)
|
1643
2603
|
msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
1644
2604
|
d -= 8;
|
1645
2605
|
}
|
@@ -1672,6 +2632,7 @@ float fvec_L1(const float* x, const float* y, size_t d) {
|
|
1672
2632
|
|
1673
2633
|
float fvec_Linf(const float* x, const float* y, size_t d) {
|
1674
2634
|
__m256 msum1 = _mm256_setzero_ps();
|
2635
|
+
// signmask used for absolute value
|
1675
2636
|
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
1676
2637
|
|
1677
2638
|
while (d >= 8) {
|
@@ -1679,7 +2640,9 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
1679
2640
|
x += 8;
|
1680
2641
|
__m256 my = _mm256_loadu_ps(y);
|
1681
2642
|
y += 8;
|
2643
|
+
// subtract
|
1682
2644
|
const __m256 a_m_b = _mm256_sub_ps(mx, my);
|
2645
|
+
// find max of absolute value of distances (chebyshev distance)
|
1683
2646
|
msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
1684
2647
|
d -= 8;
|
1685
2648
|
}
|
@@ -1720,6 +2683,441 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
     return fvec_Linf_ref(x, y, d);
 }
 
+#elif defined(__ARM_FEATURE_SVE)
+
+struct ElementOpIP {
+    static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) {
+        return svmul_f32_x(pg, x, y);
+    }
+    static svfloat32_t merge(
+            svbool_t pg,
+            svfloat32_t z,
+            svfloat32_t x,
+            svfloat32_t y) {
+        return svmla_f32_x(pg, z, x, y);
+    }
+};
+
+template <typename ElementOp>
+void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    size_t i = 0;
+    for (; i + lanes4 < ny; i += lanes4) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+        y0 = ElementOp::op(pg, x0, y0);
+        y1 = ElementOp::op(pg, x0, y1);
+        y2 = ElementOp::op(pg, x0, y2);
+        y3 = ElementOp::op(pg, x0, y3);
+        svst1_f32(pg, dis, y0);
+        svst1_f32(pg, dis + lanes, y1);
+        svst1_f32(pg, dis + lanes2, y2);
+        svst1_f32(pg, dis + lanes3, y3);
+        y += lanes4;
+        dis += lanes4;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
+    const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny);
+    const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny);
+    svfloat32_t y0 = svld1_f32(pg0, y);
+    svfloat32_t y1 = svld1_f32(pg1, y + lanes);
+    svfloat32_t y2 = svld1_f32(pg2, y + lanes2);
+    svfloat32_t y3 = svld1_f32(pg3, y + lanes3);
+    y0 = ElementOp::op(pg0, x0, y0);
+    y1 = ElementOp::op(pg1, x0, y1);
+    y2 = ElementOp::op(pg2, x0, y2);
+    y3 = ElementOp::op(pg3, x0, y3);
+    svst1_f32(pg0, dis, y0);
+    svst1_f32(pg1, dis + lanes, y1);
+    svst1_f32(pg2, dis + lanes2, y2);
+    svst1_f32(pg3, dis + lanes3, y3);
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    const svfloat32_t x1 = svdup_n_f32(x[1]);
+    size_t i = 0;
+    for (; i + lanes2 < ny; i += lanes2) {
+        const svfloat32x2_t y0 = svld2_f32(pg, y);
+        const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2);
+        svfloat32_t y00 = svget2_f32(y0, 0);
+        const svfloat32_t y01 = svget2_f32(y0, 1);
+        svfloat32_t y10 = svget2_f32(y1, 0);
+        const svfloat32_t y11 = svget2_f32(y1, 1);
+        y00 = ElementOp::op(pg, x0, y00);
+        y10 = ElementOp::op(pg, x0, y10);
+        y00 = ElementOp::merge(pg, y00, x1, y01);
+        y10 = ElementOp::merge(pg, y10, x1, y11);
+        svst1_f32(pg, dis, y00);
+        svst1_f32(pg, dis + lanes, y10);
+        y += lanes4;
+        dis += lanes2;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
+    const svfloat32x2_t y0 = svld2_f32(pg0, y);
+    const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2);
+    svfloat32_t y00 = svget2_f32(y0, 0);
+    const svfloat32_t y01 = svget2_f32(y0, 1);
+    svfloat32_t y10 = svget2_f32(y1, 0);
+    const svfloat32_t y11 = svget2_f32(y1, 1);
+    y00 = ElementOp::op(pg0, x0, y00);
+    y10 = ElementOp::op(pg1, x0, y10);
+    y00 = ElementOp::merge(pg0, y00, x1, y01);
+    y10 = ElementOp::merge(pg1, y10, x1, y11);
+    svst1_f32(pg0, dis, y00);
+    svst1_f32(pg1, dis + lanes, y10);
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    const svfloat32_t x1 = svdup_n_f32(x[1]);
+    const svfloat32_t x2 = svdup_n_f32(x[2]);
+    const svfloat32_t x3 = svdup_n_f32(x[3]);
+    size_t i = 0;
+    for (; i + lanes < ny; i += lanes) {
+        const svfloat32x4_t y0 = svld4_f32(pg, y);
+        svfloat32_t y00 = svget4_f32(y0, 0);
+        const svfloat32_t y01 = svget4_f32(y0, 1);
+        svfloat32_t y02 = svget4_f32(y0, 2);
+        const svfloat32_t y03 = svget4_f32(y0, 3);
+        y00 = ElementOp::op(pg, x0, y00);
+        y02 = ElementOp::op(pg, x2, y02);
+        y00 = ElementOp::merge(pg, y00, x1, y01);
+        y02 = ElementOp::merge(pg, y02, x3, y03);
+        y00 = svadd_f32_x(pg, y00, y02);
+        svst1_f32(pg, dis, y00);
+        y += lanes4;
+        dis += lanes;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svfloat32x4_t y0 = svld4_f32(pg0, y);
+    svfloat32_t y00 = svget4_f32(y0, 0);
+    const svfloat32_t y01 = svget4_f32(y0, 1);
+    svfloat32_t y02 = svget4_f32(y0, 2);
+    const svfloat32_t y03 = svget4_f32(y0, 3);
+    y00 = ElementOp::op(pg0, x0, y00);
+    y02 = ElementOp::op(pg0, x2, y02);
+    y00 = ElementOp::merge(pg0, y00, x1, y01);
+    y02 = ElementOp::merge(pg0, y02, x3, y03);
+    y00 = svadd_f32_x(pg0, y00, y02);
+    svst1_f32(pg0, dis, y00);
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes4 = lanes * 4;
+    const size_t lanes8 = lanes * 8;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svdup_n_f32(x[0]);
+    const svfloat32_t x1 = svdup_n_f32(x[1]);
+    const svfloat32_t x2 = svdup_n_f32(x[2]);
+    const svfloat32_t x3 = svdup_n_f32(x[3]);
+    const svfloat32_t x4 = svdup_n_f32(x[4]);
+    const svfloat32_t x5 = svdup_n_f32(x[5]);
+    const svfloat32_t x6 = svdup_n_f32(x[6]);
+    const svfloat32_t x7 = svdup_n_f32(x[7]);
+    size_t i = 0;
+    for (; i + lanes < ny; i += lanes) {
+        const svfloat32x4_t ya = svld4_f32(pg, y);
+        const svfloat32x4_t yb = svld4_f32(pg, y + lanes4);
+        const svfloat32_t ya0 = svget4_f32(ya, 0);
+        const svfloat32_t ya1 = svget4_f32(ya, 1);
+        const svfloat32_t ya2 = svget4_f32(ya, 2);
+        const svfloat32_t ya3 = svget4_f32(ya, 3);
+        const svfloat32_t yb0 = svget4_f32(yb, 0);
+        const svfloat32_t yb1 = svget4_f32(yb, 1);
+        const svfloat32_t yb2 = svget4_f32(yb, 2);
+        const svfloat32_t yb3 = svget4_f32(yb, 3);
+        svfloat32_t y0 = svuzp1(ya0, yb0);
+        const svfloat32_t y1 = svuzp1(ya1, yb1);
+        svfloat32_t y2 = svuzp1(ya2, yb2);
+        const svfloat32_t y3 = svuzp1(ya3, yb3);
+        svfloat32_t y4 = svuzp2(ya0, yb0);
+        const svfloat32_t y5 = svuzp2(ya1, yb1);
+        svfloat32_t y6 = svuzp2(ya2, yb2);
+        const svfloat32_t y7 = svuzp2(ya3, yb3);
+        y0 = ElementOp::op(pg, x0, y0);
+        y2 = ElementOp::op(pg, x2, y2);
+        y4 = ElementOp::op(pg, x4, y4);
+        y6 = ElementOp::op(pg, x6, y6);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        y2 = ElementOp::merge(pg, y2, x3, y3);
+        y4 = ElementOp::merge(pg, y4, x5, y5);
+        y6 = ElementOp::merge(pg, y6, x7, y7);
+        y0 = svadd_f32_x(pg, y0, y2);
+        y4 = svadd_f32_x(pg, y4, y6);
+        y0 = svadd_f32_x(pg, y0, y4);
+        svst1_f32(pg, dis, y0);
+        y += lanes8;
+        dis += lanes;
+    }
+    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+    const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2);
+    const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2);
+    const svfloat32x4_t ya = svld4_f32(pga, y);
+    const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4);
+    const svfloat32_t ya0 = svget4_f32(ya, 0);
+    const svfloat32_t ya1 = svget4_f32(ya, 1);
+    const svfloat32_t ya2 = svget4_f32(ya, 2);
+    const svfloat32_t ya3 = svget4_f32(ya, 3);
+    const svfloat32_t yb0 = svget4_f32(yb, 0);
+    const svfloat32_t yb1 = svget4_f32(yb, 1);
+    const svfloat32_t yb2 = svget4_f32(yb, 2);
+    const svfloat32_t yb3 = svget4_f32(yb, 3);
+    svfloat32_t y0 = svuzp1(ya0, yb0);
+    const svfloat32_t y1 = svuzp1(ya1, yb1);
+    svfloat32_t y2 = svuzp1(ya2, yb2);
+    const svfloat32_t y3 = svuzp1(ya3, yb3);
+    svfloat32_t y4 = svuzp2(ya0, yb0);
+    const svfloat32_t y5 = svuzp2(ya1, yb1);
+    svfloat32_t y6 = svuzp2(ya2, yb2);
+    const svfloat32_t y7 = svuzp2(ya3, yb3);
+    y0 = ElementOp::op(pg0, x0, y0);
+    y2 = ElementOp::op(pg0, x2, y2);
+    y4 = ElementOp::op(pg0, x4, y4);
+    y6 = ElementOp::op(pg0, x6, y6);
+    y0 = ElementOp::merge(pg0, y0, x1, y1);
+    y2 = ElementOp::merge(pg0, y2, x3, y3);
+    y4 = ElementOp::merge(pg0, y4, x5, y5);
+    y6 = ElementOp::merge(pg0, y6, x7, y7);
+    y0 = svadd_f32_x(pg0, y0, y2);
+    y4 = svadd_f32_x(pg0, y4, y6);
+    y0 = svadd_f32_x(pg0, y0, y4);
+    svst1_f32(pg0, dis, y0);
+    y += lanes8;
+    dis += lanes;
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes1(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    size_t i = 0;
+    for (; i + 3 < ny; i += 4) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+        y += lanes4;
+        y0 = ElementOp::op(pg, x0, y0);
+        y1 = ElementOp::op(pg, x0, y1);
+        y2 = ElementOp::op(pg, x0, y2);
+        y3 = ElementOp::op(pg, x0, y3);
+        dis[i] = svaddv_f32(pg, y0);
+        dis[i + 1] = svaddv_f32(pg, y1);
+        dis[i + 2] = svaddv_f32(pg, y2);
+        dis[i + 3] = svaddv_f32(pg, y3);
+    }
+    for (; i < ny; ++i) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        y += lanes;
+        y0 = ElementOp::op(pg, x0, y0);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes2(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+    size_t i = 0;
+    for (; i + 1 < ny; i += 2) {
+        svfloat32_t y00 = svld1_f32(pg, y);
+        const svfloat32_t y01 = svld1_f32(pg, y + lanes);
+        svfloat32_t y10 = svld1_f32(pg, y + lanes2);
+        const svfloat32_t y11 = svld1_f32(pg, y + lanes3);
+        y += lanes4;
+        y00 = ElementOp::op(pg, x0, y00);
+        y10 = ElementOp::op(pg, x0, y10);
+        y00 = ElementOp::merge(pg, y00, x1, y01);
+        y10 = ElementOp::merge(pg, y10, x1, y11);
+        dis[i] = svaddv_f32(pg, y00);
+        dis[i + 1] = svaddv_f32(pg, y10);
+    }
+    if (i < ny) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        y0 = ElementOp::op(pg, x0, y0);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes3(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+    const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
+    for (size_t i = 0; i < ny; ++i) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        y += lanes3;
+        y0 = ElementOp::op(pg, x0, y0);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        y0 = ElementOp::merge(pg, y0, x2, y2);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+template <typename ElementOp>
+void fvec_op_ny_sve_lanes4(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    const svbool_t pg = svptrue_b32();
+    const svfloat32_t x0 = svld1_f32(pg, x);
+    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+    const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
+    const svfloat32_t x3 = svld1_f32(pg, x + lanes3);
+    for (size_t i = 0; i < ny; ++i) {
+        svfloat32_t y0 = svld1_f32(pg, y);
+        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+        const svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+        y += lanes4;
+        y0 = ElementOp::op(pg, x0, y0);
+        y2 = ElementOp::op(pg, x2, y2);
+        y0 = ElementOp::merge(pg, y0, x1, y1);
+        y2 = ElementOp::merge(pg, y2, x3, y3);
+        y0 = svadd_f32_x(pg, y0, y2);
+        dis[i] = svaddv_f32(pg, y0);
+    }
+}
+
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    fvec_L2sqr_ny_ref(dis, x, y, d, ny);
+}
+
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+}
+
+float fvec_L1(const float* x, const float* y, size_t d) {
+    return fvec_L1_ref(x, y, d);
+}
+
+float fvec_Linf(const float* x, const float* y, size_t d) {
+    return fvec_Linf_ref(x, y, d);
+}
+
+void fvec_inner_products_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    const size_t lanes = svcntw();
+    switch (d) {
+        case 1:
+            fvec_op_ny_sve_d1<ElementOpIP>(dis, x, y, ny);
+            break;
+        case 2:
+            fvec_op_ny_sve_d2<ElementOpIP>(dis, x, y, ny);
+            break;
+        case 4:
+            fvec_op_ny_sve_d4<ElementOpIP>(dis, x, y, ny);
+            break;
+        case 8:
+            fvec_op_ny_sve_d8<ElementOpIP>(dis, x, y, ny);
+            break;
+        default:
+            if (d == lanes)
+                fvec_op_ny_sve_lanes1<ElementOpIP>(dis, x, y, ny);
+            else if (d == lanes * 2)
+                fvec_op_ny_sve_lanes2<ElementOpIP>(dis, x, y, ny);
+            else if (d == lanes * 3)
+                fvec_op_ny_sve_lanes3<ElementOpIP>(dis, x, y, ny);
+            else if (d == lanes * 4)
+                fvec_op_ny_sve_lanes4<ElementOpIP>(dis, x, y, ny);
+            else
+                fvec_inner_products_ny_ref(dis, x, y, d, ny);
+            break;
+    }
+}
+
 #elif defined(__aarch64__)
 
 // not optimized for ARM
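For orientation (commentary, not part of the vendored diff): the SVE branch above specializes the element-wise kernels by dimension, using structure loads (svld2/svld4) to de-interleave d = 1, 2, 4, 8, whole-register kernels when d equals 1 to 4 times the runtime vector length (svcntw() lanes), and the portable reference implementation otherwise. Callers only see the public entry point; a minimal caller sketch, assuming the declaration in faiss/utils/distances.h:

#include <faiss/utils/distances.h>
#include <vector>

// Hypothetical helper: inner products of one query x against ny database vectors.
std::vector<float> inner_products(const float* x, const float* y, size_t d, size_t ny) {
    std::vector<float> ip(ny);
    faiss::fvec_inner_products_ny(ip.data(), x, y, d, ny);
    return ip;
}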
@@ -1858,7 +3256,39 @@ void fvec_inner_products_ny(
         c[i] = a[i] + bf * b[i];
     }
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+static inline void fvec_madd_avx512(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    const size_t n16 = n / 16;
+    const size_t n_for_masking = n % 16;
+
+    const __m512 bfmm = _mm512_set1_ps(bf);
+
+    size_t idx = 0;
+    for (idx = 0; idx < n16 * 16; idx += 16) {
+        const __m512 ax = _mm512_loadu_ps(a + idx);
+        const __m512 bx = _mm512_loadu_ps(b + idx);
+        const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+        _mm512_storeu_ps(c + idx, abmul);
+    }
+
+    if (n_for_masking > 0) {
+        const __mmask16 mask = (1 << n_for_masking) - 1;
+
+        const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx);
+        const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
+        const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+        _mm512_mask_storeu_ps(c + idx, mask, abmul);
+    }
+}
+
+#elif defined(__AVX2__)
+
 static inline void fvec_madd_avx2(
         const size_t n,
         const float* __restrict a,
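The AVX-512 variant added above processes 16 floats per iteration and finishes the n % 16 remainder with a masked load/store ((1 << n_for_masking) - 1 keeps only the low lanes active) instead of a scalar tail loop. The result is unchanged; a scalar sketch of the operation for comparison (mirroring the reference loop in the context lines):

#include <cstddef>

// c[i] = a[i] + bf * b[i] -- the operation the AVX-512/AVX2 paths accelerate.
void madd_scalar(size_t n, const float* a, float bf, const float* b, float* c) {
    for (size_t i = 0; i < n; i++)
        c[i] = a[i] + bf * b[i];
}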
@@ -1911,6 +3341,7 @@ static inline void fvec_madd_avx2(
         _mm256_maskstore_ps(c + idx, mask, abmul);
     }
 }
+
 #endif
 
 #ifdef __SSE3__
@@ -1936,7 +3367,9 @@ static inline void fvec_madd_avx2(
 }
 
 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
-#ifdef __AVX2__
+#ifdef __AVX512F__
+    fvec_madd_avx512(n, a, bf, b, c);
+#elif __AVX2__
     fvec_madd_avx2(n, a, bf, b, c);
 #else
     if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
@@ -1946,6 +3379,60 @@ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
 #endif
 }
 
+#elif defined(__ARM_FEATURE_SVE)
+
+void fvec_madd(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    const size_t lanes = static_cast<size_t>(svcntw());
+    const size_t lanes2 = lanes * 2;
+    const size_t lanes3 = lanes * 3;
+    const size_t lanes4 = lanes * 4;
+    size_t i = 0;
+    for (; i + lanes4 < n; i += lanes4) {
+        const auto mask = svptrue_b32();
+        const auto ai0 = svld1_f32(mask, a + i);
+        const auto ai1 = svld1_f32(mask, a + i + lanes);
+        const auto ai2 = svld1_f32(mask, a + i + lanes2);
+        const auto ai3 = svld1_f32(mask, a + i + lanes3);
+        const auto bi0 = svld1_f32(mask, b + i);
+        const auto bi1 = svld1_f32(mask, b + i + lanes);
+        const auto bi2 = svld1_f32(mask, b + i + lanes2);
+        const auto bi3 = svld1_f32(mask, b + i + lanes3);
+        const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf);
+        const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf);
+        const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf);
+        const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf);
+        svst1_f32(mask, c + i, ci0);
+        svst1_f32(mask, c + i + lanes, ci1);
+        svst1_f32(mask, c + i + lanes2, ci2);
+        svst1_f32(mask, c + i + lanes3, ci3);
+    }
+    const auto mask0 = svwhilelt_b32_u64(i, n);
+    const auto mask1 = svwhilelt_b32_u64(i + lanes, n);
+    const auto mask2 = svwhilelt_b32_u64(i + lanes2, n);
+    const auto mask3 = svwhilelt_b32_u64(i + lanes3, n);
+    const auto ai0 = svld1_f32(mask0, a + i);
+    const auto ai1 = svld1_f32(mask1, a + i + lanes);
+    const auto ai2 = svld1_f32(mask2, a + i + lanes2);
+    const auto ai3 = svld1_f32(mask3, a + i + lanes3);
+    const auto bi0 = svld1_f32(mask0, b + i);
+    const auto bi1 = svld1_f32(mask1, b + i + lanes);
+    const auto bi2 = svld1_f32(mask2, b + i + lanes2);
+    const auto bi3 = svld1_f32(mask3, b + i + lanes3);
+    const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf);
+    const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf);
+    const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf);
+    const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf);
+    svst1_f32(mask0, c + i, ci0);
+    svst1_f32(mask1, c + i + lanes, ci1);
+    svst1_f32(mask2, c + i + lanes2, ci2);
+    svst1_f32(mask3, c + i + lanes3, ci3);
+}
+
 #elif defined(__aarch64__)
 
 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
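The SVE fvec_madd above unrolls by four vector registers and then handles the remainder with svwhilelt predicates, which simply deactivate lanes at or beyond n. A stripped-down sketch of that tail-predication idiom (illustrative only; assumes an SVE-enabled toolchain, not code from the diff):

#include <arm_sve.h>
#include <cstddef>

// c[i] = a[i] + bf * b[i], one predicated vector at a time.
void madd_sve_sketch(size_t n, const float* a, float bf, const float* b, float* c) {
    for (size_t i = 0; i < n; i += svcntw()) {
        svbool_t pg = svwhilelt_b32_u64(i, n);   // lanes with i + lane < n stay active
        svfloat32_t av = svld1_f32(pg, a + i);   // masked loads
        svfloat32_t bv = svld1_f32(pg, b + i);
        svst1_f32(pg, c + i, svmla_n_f32_x(pg, av, bv, bf)); // av + bv * bf, masked store
    }
}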
@@ -2278,7 +3765,7 @@ void fvec_add(size_t d, const float* a, float b, float* c) {
     size_t i;
     simd8float32 bv(b);
     for (i = 0; i + 7 < d; i += 8) {
-        simd8float32 ci, ai
+        simd8float32 ci, ai;
         ai.loadu(a + i);
         ci = ai + bv;
         ci.storeu(c + i);