mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,1146 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_simdgroup>
|
|
6
|
+
#include <metal_simdgroup_matrix>
|
|
7
|
+
#include <metal_stdlib>
|
|
8
|
+
|
|
9
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
10
|
+
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
|
11
|
+
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
|
12
|
+
|
|
13
|
+
using namespace metal;
|
|
14
|
+
|
|
15
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
16
|
+
// MMA helper
|
|
17
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
18
|
+
|
|
19
|
+
namespace mlx {
|
|
20
|
+
namespace steel {
|
|
21
|
+
|
|
22
|
+
// Primary template for an MMA fragment helper. It carries no members: only the
// BaseMMAFrag<T, 8, 8> specialization is usable, and instantiating this
// primary template with any other shape fails the static_assert below.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8 && kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};
|
|
31
|
+
|
|
32
|
+
// Fragment helper for the 8 x 8 simdgroup_matrix shape.
//
// The 64 elements of a fragment are divided across 32 lanes
// (kElemsPerFrag = (8 * 8) / 32 = 2). Each lane's share is modelled as a
// 1 x 2 row segment (kElemRows x kElemCols). frag_type is the per-lane vector
// holding those elements, and mat_type is the matching simdgroup_matrix; mma()
// moves data between the two through thread_elements().
//
// Every load/store helper addresses memory as
//   ptr[row * str_x + col * str_y]
// where (row, col) are lane-local element coordinates, plus off_x / off_y in
// the *_safe and *_slice variants. Callers presumably advance the pointer to
// this lane's position (see get_coord) beforehand — NOTE(review): confirm at
// call sites.
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  // Elements of one fragment held by each lane.
  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  // Shape of the per-lane share: one row, two consecutive columns.
  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  typedef metal::vec<T, kElemsPerFrag> frag_type;

  // Returns the position of this lane's first element inside the 8 x 8
  // fragment as short2{x = column (fn), y = row (fm)}. With kElemCols == 2,
  // the lane's second element is at column fn + 1.
  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                               [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  // Unchecked load: dst[i * kElemCols + j] = T(src[i * str_x + j * str_y]).
  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  // Bounds-checked load. Element (i, j) is read from
  //   src[(off_x + i) * str_x + (off_y + j) * str_y]
  // when off_x + i < lim_x and off_y + j < lim_y; otherwise it is set to T(0).
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          // The column term must use off_y, matching the bounds check above
          // and store_safe(). Using off_x here read the wrong column (possibly
          // outside the checked region) whenever off_x != off_y.
          dst[i * kElemCols + j] =
              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Unchecked store: dst[i * str_x + j * str_y] = U(src[i * kElemCols + j]),
  // where U is the pointee type of DstPtrType.
  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  // Bounds-checked store. Element (i, j) is written to
  //   dst[(off_x + i) * str_x + (off_y + j) * str_y]
  // only when off_x + i < lim_x and off_y + j < lim_y; otherwise it is skipped.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Windowed store. Element (i, j) is written only when its coordinate
  // (off_x + i, off_y + j) lies in [start_x, stop_x) x [start_y, stop_y).
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_slice(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
            (off_y + j) < stop_y && (off_y + j) >= start_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Fragment-level multiply-accumulate, D = A * B + C. The per-lane vectors
  // are copied into simdgroup_matrix thread_elements(), the mat_type overload
  // below is invoked, and the result is copied back into D.
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread frag_type& A,
      thread frag_type& B,
      thread frag_type& C) {
    mat_type D_mat;
    mat_type A_mat;
    mat_type B_mat;
    mat_type C_mat;

    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  // Matrix-level multiply-accumulate via simdgroup_multiply_accumulate.
  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread mat_type& A,
      thread mat_type& B,
      thread mat_type& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }
};
|
|
208
|
+
|
|
209
|
+
// A tile of kTileRows x kTileCols MMA fragments kept in registers.
//
// Each lane holds its share of every fragment in `val_frags`, stored
// fragment-row-major (fragment (i, j) is entry i * kTileCols + j). The
// load/store helpers move the tile to and from threadgroup or device memory
// one fragment at a time through MMAFrag_t.
//
// In the helpers, fragment (i, j) is located at row i * kFragRows * w_x and
// column j * kFragCols * w_y of the source/destination. w_x and w_y therefore
// set the spacing between neighbouring fragments.
template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  // Extent of the whole tile, in elements.
  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  // Fragment count, and elements held per lane across all fragments.
  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  using mat_type = typename MMAFrag_t::mat_type;
  using frag_type = typename MMAFrag_t::frag_type;

  frag_type val_frags[kNumFrags] = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  // Zeroes every fragment.
  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short f = 0; f < kNumFrags; ++f) {
      val_frags[f] = frag_type(0);
    }
  }

  // Fragment (i, j) of the tile.
  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kTileCols + j];
  }

  // Copies fragment (i, j) into a simdgroup_matrix via its thread_elements().
  METAL_FUNC mat_type mat_at(const short i, const short j) {
    mat_type result;
    thread frag_type& frag = frag_at(i, j);
    STEEL_PRAGMA_UNROLL
    for (short e = 0; e < kElemsPerFrag; ++e) {
      result.thread_elements()[e] = frag[e];
    }
    return result;
  }

  // Flat view over all per-lane elements of the tile.
  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  // Loads from threadgroup memory with compile-time strides str_x / str_y.
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x * str_x;
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y * str_y;
        MMAFrag_t::load(
            frag_at(r, c), &(src[row_off + col_off]), Int<str_x>{}, Int<str_y>{});
      }
    }
  }

  // Stores to threadgroup memory with compile-time strides str_x / str_y.
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x * str_x;
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y * str_y;
        MMAFrag_t::store(
            frag_at(r, c), &(dst[row_off + col_off]), Int<str_x>{}, Int<str_y>{});
      }
    }
  }

  // Loads from device memory with row stride `ld` and unit column stride.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x * ld;
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y;
        MMAFrag_t::load(frag_at(r, c), &(src[row_off + col_off]), ld, Int<1>{});
      }
    }
  }

  // Stores to device memory with row stride `ld` and unit column stride.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x * ld;
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y;
        MMAFrag_t::store(frag_at(r, c), &(dst[row_off + col_off]), ld, Int<1>{});
      }
    }
  }

  // Bounds-checked device load. src_tile_dims.y is the row limit and
  // src_tile_dims.x the column limit, both relative to `src`; elements outside
  // the limits are zero-filled by MMAFrag_t::load_safe.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x;
      STEEL_PRAGMA_UNROLL
      for (int c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y;
        MMAFrag_t::load_safe(
            frag_at(r, c),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            row_off,
            col_off);
      }
    }
  }

  // Bounds-checked device store. dst_tile_dims.y is the row limit and
  // dst_tile_dims.x the column limit, both relative to `dst`.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x;
      STEEL_PRAGMA_UNROLL
      for (int c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y;
        MMAFrag_t::store_safe(
            frag_at(r, c),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            row_off,
            col_off);
      }
    }
  }

  // Device store restricted to rows [start.y, stop.y) and columns
  // [start.x, stop.x), relative to `dst`.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store_slice(
      device U* dst,
      const int ld,
      const short2 start,
      const short2 stop) const {
    STEEL_PRAGMA_UNROLL
    for (int r = 0; r < kTileRows; ++r) {
      const int row_off = (r * kFragRows) * w_x;
      STEEL_PRAGMA_UNROLL
      for (int c = 0; c < kTileCols; ++c) {
        const int col_off = (c * kFragCols) * w_y;
        MMAFrag_t::store_slice(
            frag_at(r, c),
            dst,
            ld,
            Int<1>{},
            start.y,
            stop.y,
            start.x,
            stop.x,
            row_off,
            col_off);
      }
    }
  }
};
|
|
400
|
+
|
|
401
|
+
// Tile-level multiply-accumulate, D = A * B + C.
//
// For every output fragment (row, col) it performs K fragment MMAs,
// D(row, col) = A(row, kk) * B(kk, col) + C(row, col), through the fragment
// helper of the output tile type. Odd output rows visit columns in reverse
// order (serpentine traversal). NOTE(review): the reason for this order is not
// stated in the code; presumably it improves fragment reuse between
// consecutive rows.
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
    thread MMATile<T, M, N>& D,
    thread MMATile<U, M, K>& A,
    thread MMATile<U, K, N>& B,
    thread MMATile<T, M, N>& C) {
  using Frag = typename MMATile<T, M, N>::MMAFrag_t;

  STEEL_PRAGMA_UNROLL
  for (short row = 0; row < M; ++row) {
    const bool reversed = (row % 2) != 0;
    STEEL_PRAGMA_UNROLL
    for (short step = 0; step < N; ++step) {
      const short col = reversed ? (N - 1 - step) : step;
      STEEL_PRAGMA_UNROLL
      for (short kk = 0; kk < K; ++kk) {
        Frag::mma(
            D.frag_at(row, col),
            A.frag_at(row, kk),
            B.frag_at(kk, col),
            C.frag_at(row, col));
      }
    }
  }
}
|
|
423
|
+
|
|
424
|
+
// Identity transform for complex64_t outputs. Both overloads return `x`
// unchanged; the binary overload ignores its second operand.
template <typename InT>
struct TransformNone<complex64_t, InT> {
  static METAL_FUNC complex64_t apply(complex64_t x) {
    return x;
  }

  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t /* unused */) {
    return apply(x);
  }
};
|
|
433
|
+
|
|
434
|
+
template <
|
|
435
|
+
typename T,
|
|
436
|
+
typename U,
|
|
437
|
+
int BM,
|
|
438
|
+
int BN,
|
|
439
|
+
int BK,
|
|
440
|
+
int WM,
|
|
441
|
+
int WN,
|
|
442
|
+
bool transpose_a,
|
|
443
|
+
bool transpose_b,
|
|
444
|
+
short lda_tgp,
|
|
445
|
+
short ldb_tgp,
|
|
446
|
+
typename AccumType = float,
|
|
447
|
+
typename Epilogue = TransformNone<U, AccumType>>
|
|
448
|
+
struct BlockMMA {
|
|
449
|
+
// MMAFrag size
|
|
450
|
+
STEEL_CONST short kFragSize = 8;
|
|
451
|
+
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
|
452
|
+
|
|
453
|
+
// Warp tile simdgroup matrix strides along M
|
|
454
|
+
STEEL_CONST short TM_stride = kFragSize * WM;
|
|
455
|
+
// Warp tile simdgroup matrix strides along M
|
|
456
|
+
STEEL_CONST short TN_stride = kFragSize * WN;
|
|
457
|
+
|
|
458
|
+
// Warp tile size along M
|
|
459
|
+
STEEL_CONST short TM = BM / (kFragSize * WM);
|
|
460
|
+
// Warp tile size along N
|
|
461
|
+
STEEL_CONST short TN = BN / (kFragSize * WN);
|
|
462
|
+
|
|
463
|
+
// Threadgroup A strides
|
|
464
|
+
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
|
465
|
+
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
|
466
|
+
|
|
467
|
+
// Threadgroup B strides
|
|
468
|
+
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
|
469
|
+
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
|
470
|
+
|
|
471
|
+
// Threadgroup strides along K
|
|
472
|
+
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
|
473
|
+
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
|
474
|
+
|
|
475
|
+
// Simdgroup matrices
|
|
476
|
+
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
|
477
|
+
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
|
478
|
+
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
|
479
|
+
|
|
480
|
+
// Offsets within threadgroup
|
|
481
|
+
short sm;
|
|
482
|
+
short sn;
|
|
483
|
+
|
|
484
|
+
short As_offset;
|
|
485
|
+
short Bs_offset;
|
|
486
|
+
|
|
487
|
+
/* Constructor */
|
|
488
|
+
METAL_FUNC BlockMMA(
|
|
489
|
+
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
490
|
+
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
|
491
|
+
// Determine thread position in simdgroup matrix
|
|
492
|
+
short tm = kFragSize * (simd_group_id / WN);
|
|
493
|
+
short tn = kFragSize * (simd_group_id % WN);
|
|
494
|
+
|
|
495
|
+
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
|
496
|
+
sm = simd_coord.y;
|
|
497
|
+
sn = simd_coord.x;
|
|
498
|
+
|
|
499
|
+
// Determine thread and simdgroup offset
|
|
500
|
+
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
|
501
|
+
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
|
502
|
+
|
|
503
|
+
sm += tm;
|
|
504
|
+
sn += tn;
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
|
508
|
+
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
|
509
|
+
// Adjust for simdgroup and thread location
|
|
510
|
+
As += As_offset;
|
|
511
|
+
Bs += Bs_offset;
|
|
512
|
+
|
|
513
|
+
// Iterate over BK in blocks of kFragSize
|
|
514
|
+
STEEL_PRAGMA_UNROLL
|
|
515
|
+
for (short kk = 0; kk < BK; kk += kFragSize) {
|
|
516
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
517
|
+
|
|
518
|
+
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
|
519
|
+
|
|
520
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
521
|
+
|
|
522
|
+
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
|
523
|
+
|
|
524
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
525
|
+
|
|
526
|
+
tile_matmad(Ctile, Atile, Btile, Ctile);
|
|
527
|
+
|
|
528
|
+
// Progress to next simdgroup tile
|
|
529
|
+
As += tile_stride_a;
|
|
530
|
+
Bs += tile_stride_b;
|
|
531
|
+
}
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
/* Store results from simdgroup_matrix results into device memory */
|
|
535
|
+
METAL_FUNC void store_result(device U* D, const int ldd) {
|
|
536
|
+
// Apply epilogue
|
|
537
|
+
STEEL_PRAGMA_UNROLL
|
|
538
|
+
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
|
539
|
+
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
// Adjust for simdgroup and thread location
|
|
543
|
+
D += sm * ldd + sn;
|
|
544
|
+
|
|
545
|
+
Ctile.template store<U, WM, WN>(D, ldd);
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
// Apply the static epilogue in place, then store the part of this thread's
// share of Ctile that falls in the window described by start and stop.
// Both are short2(column, row) coordinates relative to the block origin.
METAL_FUNC void
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < decltype(Ctile)::kElemsPerTile; e++) {
    thread auto& value = Ctile.elems()[e];
    value = Epilogue::apply(value);
  }

  // Re-express the window relative to this thread's (sn, sm) origin.
  const short2 origin = short2(sn, sm);
  start -= origin;
  stop -= origin;

  // TODO: `start` is not used for this early exit. Only a window that ends
  // at or before this thread's origin on either axis is skipped here.
  const bool window_ends_before_thread = (stop.x <= 0) || (stop.y <= 0);
  if (window_ends_before_thread) {
    return;
  }

  Ctile.template store_slice<U, WM, WN>(D + (sm * ldd + sn), ldd, start, stop);
}
|
|
567
|
+
|
|
568
|
+
// Apply the static epilogue in place, then store this thread's share of
// Ctile clipped to a destination tile of extent dst_tile_dims, given as
// short2(columns, rows) relative to the block origin.
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < decltype(Ctile)::kElemsPerTile; e++) {
    thread auto& value = Ctile.elems()[e];
    value = Epilogue::apply(value);
  }

  // Remaining valid extent measured from this thread's (sn, sm) origin.
  dst_tile_dims -= short2(sn, sm);

  // This thread's origin is outside the destination tile: nothing to write.
  if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) {
    return;
  }

  device U* dst = D + (sm * ldd + sn);
  Ctile.template store_safe<U, WM, WN>(dst, ldd, dst_tile_dims);
}
|
|
585
|
+
|
|
586
|
+
// Apply a caller-supplied unary epilogue object to every accumulator element
// of Ctile, in place.
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
  constexpr short kTileElems = decltype(Ctile)::kElemsPerTile;

  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < kTileElems; ++e) {
    thread auto& value = Ctile.elems()[e];
    value = epilogue_op.apply(value);
  }
}
|
|
595
|
+
|
|
596
|
+
// Combine every accumulator element of Ctile, in place, with the matching
// element of C via epilogue_op.apply(accumulator, c). C has row stride ldc
// and column stride fdc. No bounds checking is performed.
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
    const device U* C,
    const int ldc,
    const int fdc,
    thread const BinaryEpilogue& epilogue_op) {
  // This thread's first element of C: row sm, column sn.
  const device U* c_base = C + (sm * ldc + sn * fdc);

  constexpr short kFragElems = decltype(Ctile)::kElemsPerFrag;

  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < TM; ti++) {
    const int row = ti * TM_stride;

    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < TN; tj++) {
      thread auto& frag = Ctile.frag_at(ti, tj);
      const int c_off = row * ldc + (tj * TN_stride) * fdc;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kFragElems; k++) {
        frag[k] = epilogue_op.apply(frag[k], c_base[c_off + k * fdc]);
      }
    }
  }
}
|
|
623
|
+
|
|
624
|
+
/* Apply epilogue */
// Bounds-checked variant of apply_epilogue(C, ldc, fdc, op).
//
// Combines each accumulator element of Ctile, in place, with the matching
// element of C via epilogue_op.apply(accumulator, c). Elements of C outside
// the destination tile of extent dst_tile_dims (short2(columns, rows),
// relative to the block origin) are never read; 0 is used in their place.
//
// Both the column and the row of each fragment are checked before reading C.
// Previously only columns were checked, so edge tiles could read C past the
// last valid row. The complex64_t specialization already guards both axes.
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
    const device U* C,
    const int ldc,
    const int fdc,
    short2 dst_tile_dims,
    thread const BinaryEpilogue& epilogue_op) {
  // Adjust for simdgroup and thread location
  C += (sm)*ldc + (sn)*fdc;
  dst_tile_dims -= short2(sn, sm);

  // This thread's origin lies outside the destination tile.
  if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
    return;

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  // Loop over all simdgroup tiles
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    // Fragment rows at or beyond dst_tile_dims.y lie outside C; never read
    // them.
    const bool row_in_bounds = (i * TM_stride) < dst_tile_dims.y;

    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      // Get accumulated result and associated offset in C
      thread auto& accum = Ctile.frag_at(i, j);
      int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

      // Read C, substituting 0 for elements outside the destination tile
      U c_elems[kelems] = {0};

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        if (row_in_bounds && (j * TN_stride + k) < dst_tile_dims.x) {
          c_elems[k] = C[offset_c + k * fdc];
        }
      }

      // Apply epilogue
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
      }
    }
  }
}
|
|
668
|
+
|
|
669
|
+
// Write D = epilogue_op.apply(accumulator, C) element-wise for this thread's
// share of Ctile, with no bounds checking.
// C has row stride ldc and column stride fdc; D has row stride ldd and unit
// column stride. Ctile itself is not modified.
METAL_FUNC void store_result(
    device U* D,
    const int ldd,
    const device U* C,
    const int ldc,
    const int fdc,
    thread const Epilogue& epilogue_op) const {
  // This thread's first element in C and D: row sm, column sn.
  const device U* c_base = C + (sm * ldc + sn * fdc);
  device U* d_base = D + (sm * ldd + sn);

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < TM; ti++) {
    const int row = ti * TM_stride;

    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < TN; tj++) {
      const int col = tj * TN_stride;
      thread const auto& frag = Ctile.frag_at(ti, tj);
      const int c_off = row * ldc + col * fdc;
      const int d_off = row * ldd + col;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        d_base[d_off + k] =
            epilogue_op.apply(frag[k], c_base[c_off + k * fdc]);
      }
    }
  }
}
|
|
701
|
+
|
|
702
|
+
// Bounds-checked variant of store_result(D, ldd, C, ldc, fdc, op).
// Writes D = epilogue_op.apply(accumulator, C) only for elements inside the
// destination tile of extent dst_tile_dims (short2(columns, rows), relative
// to the block origin). C is read only for elements that are written.
METAL_FUNC void store_result_safe(
    device U* D,
    const int ldd,
    const device U* C,
    const int ldc,
    const int fdc,
    short2 dst_tile_dims,
    thread const Epilogue& epilogue_op) const {
  // This thread's first element in C and D: row sm, column sn.
  const device U* c_base = C + (sm * ldc + sn * fdc);
  device U* d_base = D + (sm * ldd + sn);

  // Remaining valid extent measured from this thread's (sn, sm) origin.
  dst_tile_dims -= short2(sn, sm);
  if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) {
    return;
  }

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < TM; ti++) {
    const int row = ti * TM_stride;
    if (row >= dst_tile_dims.y) {
      continue; // this fragment row lies below the destination tile
    }

    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < TN; tj++) {
      const int col = tj * TN_stride;
      thread const auto& frag = Ctile.frag_at(ti, tj);
      const int c_off = row * ldc + col * fdc;
      const int d_off = row * ldd + col;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        if (col + k < dst_tile_dims.x) {
          d_base[d_off + k] =
              epilogue_op.apply(frag[k], c_base[c_off + k * fdc]);
        }
      }
    }
  }
}
|
|
742
|
+
};
|
|
743
|
+
|
|
744
|
+
// BlockMMA specialization for complex64_t operands.
//
// Threadgroup tiles of complex64_t are read through a float view in which
// each complex value occupies two consecutive floats: offset +0 is loaded as
// the real part and offset +1 as the imaginary part. Every complex-element
// stride is therefore doubled for that view (the *_f constants below).
// Each K fragment is multiplied with three real simdgroup MMAs (the
// Karatsuba / Gauss three-multiplication trick). The real and imaginary
// parts of the result are accumulated in two separate float tiles, Ctile_r
// and Ctile_i.
template <
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType,
    typename Epilogue>
struct BlockMMA<
    complex64_t,
    U,
    BM,
    BN,
    BK,
    WM,
    WN,
    transpose_a,
    transpose_b,
    lda_tgp,
    ldb_tgp,
    AccumType,
    Epilogue> {
  static_assert(
      metal::is_same_v<AccumType, float>,
      "BlockMMA<complex64_t,...> expects float accumulators");
  static_assert(
      metal::is_same_v<U, complex64_t>,
      "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");

  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / (kFragSize * WM);
  // Warp tile size along N
  STEEL_CONST short TN = BN / (kFragSize * WN);

  // Threadgroup A strides, in complex64_t elements
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides, in complex64_t elements
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K (one fragment), in complex64_t elements
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // The same strides in float units, for indexing complex64_t as float[2]
  STEEL_CONST short A_str_m_f = A_str_m * 2;
  STEEL_CONST short A_str_k_f = A_str_k * 2;
  STEEL_CONST short B_str_k_f = B_str_k * 2;
  STEEL_CONST short B_str_n_f = B_str_n * 2;
  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;

  // Accumulators for the real (Ctile_r) and imaginary (Ctile_i) parts
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;

  // sm / sn: this thread's row / column inside the threadgroup output block.
  // As_offset / Bs_offset: its starting offsets (in complex64_t elements)
  // into the threadgroup A and B tiles.
  short sm, sn;
  short As_offset, Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)

    sm += tm;
    sn += tn;
  }

  /* Karatsuba MMA: 3 real MMAs per K-chunk */
  // With A = Ar + i*Ai and B = Br + i*Bi, each K fragment computes
  //   P = Ar*Br,  Q = Ai*Bi,  R = (Ar + Ai)*(Br + Bi)
  // and accumulates
  //   real(A*B) = P - Q
  //   imag(A*B) = Ar*Bi + Ai*Br = R - P - Q
  METAL_FUNC void mma(
      const threadgroup complex64_t* As,
      const threadgroup complex64_t* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;
    threadgroup const float* As_f =
        reinterpret_cast<threadgroup const float*>(As);
    threadgroup const float* Bs_f =
        reinterpret_cast<threadgroup const float*>(Bs);

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      // Real (+0) and imaginary (+1) parts of this A fragment
      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);

      simdgroup_barrier(mem_flags::mem_none);

      // Real (+0) and imaginary (+1) parts of this B fragment
      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);

      simdgroup_barrier(mem_flags::mem_none);

      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
      // NOTE(review): tile_matmad is called with the output tile also passed
      // as the addend. This relies on freshly constructed MMATiles being
      // zero, as the Ctile_r / Ctile_i accumulators also do. Confirm in
      // MMATile.
      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;

      tile_matmad(P, Ar, Br, P);
      tile_matmad(Q, Ai, Bi, Q);

      // Form (Ar + Ai) and (Br + Bi) in place; the plain Ar and Br are not
      // needed again in this iteration.
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
        Ar.elems()[i] += Ai.elems()[i];
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
        Br.elems()[i] += Bi.elems()[i];

      tile_matmad(R, Ar, Br, R);

      // C_r += P - Q ; C_i += R - P - Q
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
        const auto p = P.elems()[i];
        const auto q = Q.elems()[i];
        const auto r = R.elems()[i];
        Ctile_r.elems()[i] += (p - q);
        Ctile_i.elems()[i] += (r - p - q);
      }

      // Progress to next simdgroup tile
      As_f += tile_stride_a_f;
      Bs_f += tile_stride_b_f;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  // Recombine (Ctile_r, Ctile_i) into complex64_t values, pass each through
  // the static Epilogue::apply, and write them to D (row stride ldd) with no
  // bounds checking.
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        thread const auto& r = Ctile_r.frag_at(i, j);
        thread const auto& im = Ctile_i.frag_at(i, j);
        int off = (i * TM_stride) * ldd + (j * TN_stride);
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
        }
      }
    }
  }

  // Like store_result(D, ldd), but only writes elements whose row lies in
  // [start.y, stop.y) and whose column lies in [start.x, stop.x).
  // start and stop are short2(column, row) relative to the block origin.
  METAL_FUNC void
  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
    D += sm * ldd + sn;
    start -= short2(sn, sm);
    stop -= short2(sn, sm);

    // The window ends at or before this thread's origin on some axis.
    if (stop.y <= 0 || stop.x <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; ++i) {
      const int row = i * TM_stride;
      if (row >= start.y && row < stop.y) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; ++j) {
          const int off = row * ldd + (j * TN_stride);
          thread const auto& r = Ctile_r.frag_at(i, j);
          thread const auto& im = Ctile_i.frag_at(i, j);

          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
            const int col = j * TN_stride + k;
            if (col >= start.x && col < stop.x) {
              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
            }
          }
        }
      }
    }
  }

  // Like store_result(D, ldd), but clipped to a destination tile of extent
  // dst_tile_dims = short2(columns, rows) relative to the block origin.
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);
    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; j++) {
          int off = (i * TM_stride) * ldd + (j * TN_stride);
          thread const auto& r = Ctile_r.frag_at(i, j);
          thread const auto& im = Ctile_i.frag_at(i, j);
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
            }
          }
        }
      }
    }
  }

  /* Apply epilogue */
  // Apply a unary epilogue object to every accumulated complex value in
  // place, splitting the result back into Ctile_r and Ctile_i.
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
      complex64_t out = epilogue_op.apply(
          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
      Ctile_r.elems()[i] = out.real;
      Ctile_i.elems()[i] = out.imag;
    }
  }

  /* Apply epilogue */
  // Combine every accumulated complex value, in place, with the matching
  // element of C (row stride ldc, column stride fdc) via
  // epilogue_op.apply(accumulator, c). No bounds checking is performed.
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in Cr, Ci
        thread auto& r = Ctile_r.frag_at(i, j);
        thread auto& im = Ctile_i.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
          complex64_t out = epilogue_op.apply(
              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
          r[k] = out.real;
          im[k] = out.imag;
        }
      }
    }
  }

  /* Apply epilogue */
  // Bounds-checked variant of apply_epilogue(C, ldc, fdc, op). Elements of C
  // outside the destination tile of extent dst_tile_dims
  // (short2(columns, rows)) are never read; complex 0 is used in their place.
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      // Row check is invariant across j and k; evaluate it once per row.
      const bool row_in_bounds = (i * TM_stride) < dst_tile_dims.y;

      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in Cr, Ci
        thread auto& r = Ctile_r.frag_at(i, j);
        thread auto& im = Ctile_i.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Read C, substituting 0 for elements outside the destination tile
        complex64_t tmp[kelems];

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if (row_in_bounds && (j * TN_stride + k) < dst_tile_dims.x) {
            tmp[k] = C[offset_c + k * fdc];
          } else {
            tmp[k] = complex64_t(0.0f, 0.0f);
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
          r[k] = out.real;
          im[k] = out.imag;
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  // Write D = epilogue_op.apply(accumulator, C) element-wise, with no bounds
  // checking. C has row stride ldc and column stride fdc; D has row stride
  // ldd and unit column stride. The accumulators are not modified.
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in Cr, Ci
        thread const auto& r = Ctile_r.frag_at(i, j);
        thread const auto& im = Ctile_i.frag_at(i, j);
        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int off_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[off_d + k] =
              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
        }
      }
    }
  }

  // Bounds-checked variant of store_result(D, ldd, C, ldc, fdc, op): only
  // elements inside the destination tile of extent dst_tile_dims
  // (short2(columns, rows)) are read from C and written to D.
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in Cr, Ci
          thread const auto& r = Ctile_r.frag_at(i, j);
          thread const auto& im = Ctile_i.frag_at(i, j);
          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int off_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[off_d + k] = epilogue_op.apply(
                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};
|
|
1144
|
+
|
|
1145
|
+
} // namespace steel
|
|
1146
|
+
} // namespace mlx
|