mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,750 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_simdgroup>
|
|
6
|
+
#include <metal_simdgroup_matrix>
|
|
7
|
+
#include <metal_stdlib>
|
|
8
|
+
|
|
9
|
+
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
|
10
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
11
|
+
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
|
12
|
+
|
|
13
|
+
using namespace metal;
|
|
14
|
+
|
|
15
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
16
|
+
// MMA helper
|
|
17
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
18
|
+
|
|
19
|
+
namespace mlx {
|
|
20
|
+
namespace steel {
|
|
21
|
+
|
|
22
|
+
// Pairs a row extent with a column extent. Each extent has its own type, so
// a compile-time constant (e.g. Int<N>) and a runtime integer can be mixed.
template <typename RInt, typename CInt>
struct Shape2D {
  RInt r; // row extent
  CInt c; // column extent

  Shape2D(RInt rows, CInt cols) : r(rows), c(cols) {}
};
|
|
29
|
+
|
|
30
|
+
// Aggregate bundling a shape description with a layout description; both
// are opaque type parameters and are stored by value.
template <typename Shape, typename Layout>
struct Layout2D {
  Shape shape;   // extents
  Layout layout; // how those extents are laid out
};
|
|
35
|
+
|
|
36
|
+
// Primary template for an MMA fragment. Only the 8 x 8 partial
// specialization is implemented; instantiating any other fragment size
// trips the assertion below at compile time.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8 && kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};
|
|
45
|
+
|
|
46
|
+
// Specialization for 8 x 8 fragments (the only supported size).
//
// kElemsPerFrag = (8 * 8) / 32 = 2: the fragment's 64 elements are spread
// over the 32 lanes of a simdgroup. Each lane holds a 1 x 2 piece of the
// fragment (kElemRows = 1 row, kElemCols = 2 columns). get_coord() gives
// where that piece sits inside the fragment.
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  // Elements of one fragment held by each lane.
  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  // Shape of a lane's piece of the fragment.
  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  // Simdgroup-wide 8 x 8 matrix of T.
  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  // One lane's kElemsPerFrag elements of a fragment.
  typedef metal::vec<T, kElemsPerFrag> frag_type;
  // Vectors sized to a lane's row count / column count.
  typedef metal::vec<T, kElemRows> row_frag_type;
  typedef metal::vec<T, kElemCols> col_frag_type;

  // The same matrix / per-lane vector shapes for another element type U,
  // so mma() can take operands whose element type differs from T.
  template <typename U>
  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;

  template <typename U>
  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

  // Returns {x, y} = {column, row} of this lane's first element inside the
  // fragment. The lane's kElemCols elements are columns x and x + 1 of row y
  // (x is always even). In terms of lane-id bits:
  //   row    y = 4 * bit4 + 2 * bit2 + bit1
  //   column x = 4 * bit3 + 2 * bit0
  // so the four lanes that share a row differ only in lane bits 0 and 3.
  // NOTE(review): presumably mirrors the element order of
  // simdgroup_matrix::thread_elements() used in mma(); confirm against the
  // Metal spec.
  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                               [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  // Copies this lane's kElemRows x kElemCols elements into dst, converting
  // each one to T. Element (i, j) is read from src[i * str_x + j * str_y].
  // No lane offset is applied here: src must already point at this lane's
  // first element.
  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  // Bounds-checked load. src is first advanced by (off_x, off_y). Element
  // (i, j) is read only when off_x + i < lim_x and off_y + j < lim_y;
  // otherwise it is set to T(0).
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    src += off_x * str_x + off_y * str_y;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[i * kElemCols + j] = static_cast<T>(src[0]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
        // Walk src to element (i, j + 1). The pointer moves even for skipped
        // elements, but it is only dereferenced when in bounds.
        src += str_y;
      }
      // Undo the column walk and step to the next row.
      src -= kElemCols * str_y;
      src += str_x;
    }
  }

  // Inverse of load(): writes this lane's elements to dst, converting each
  // one to dst's pointee type U. Element (i, j) goes to
  // dst[i * str_x + j * str_y].
  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  // Bounds-checked store. Element (i, j) is written to
  // dst[(off_x + i) * str_x + (off_y + j) * str_y] only when off_x + i < lim_x
  // and off_y + j < lim_y. Out-of-bounds elements are skipped, not zeroed.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // D = A * B + C on one fragment, with each operand given as this lane's
  // per-lane vector. The vectors are copied into simdgroup matrices, the
  // matrix overload below performs the multiply-accumulate, and the result
  // is copied back into D. C is copied before D is written, so D may alias C.
  template <typename Atype, typename Btype, typename Ctype>
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread dtype_frag_t<Atype>& A,
      thread dtype_frag_t<Btype>& B,
      thread dtype_frag_t<Ctype>& C) {
    mat_type D_mat;
    dtype_mat_t<Atype> A_mat;
    dtype_mat_t<Btype> B_mat;
    dtype_mat_t<Ctype> C_mat;

    reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  // D = A * B + C on simdgroup matrices via simdgroup_multiply_accumulate.
  // Templated on each operand's element type.
  template <typename Atype, typename Btype, typename Ctype>
  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread dtype_mat_t<Atype>& A,
      thread dtype_mat_t<Btype>& B,
      thread dtype_mat_t<Ctype>& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }

  // Reduces the fragment row this lane belongs to with Op, then folds the
  // result into reduced_vals[0] as reduced_vals[0] = Op(reduced_vals[0], r).
  // reduced_vals[0] is read, so the caller must initialize it.
  // Steps: combine this lane's two columns; exchange with lane ^ 1 (the
  // adjacent column pair); then exchange with lane ^ 8 (the other half of
  // the row). Per get_coord(), those are exactly the lanes that share this
  // row, so every lane in the row ends up with the result over all 8 columns.
  // NOTE(review): uses simd_shuffle_xor, so presumably all 32 lanes must
  // make this call together.
  template <typename Op>
  METAL_FUNC static constexpr void row_reduce(
      thread const frag_type& inp_vals,
      thread T* reduced_vals) {
    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);

    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
  }

  // Replaces each element of this lane's row i with Op(element, row_vals[i]).
  // With kElemRows = 1, both elements use row_vals[0].
  template <typename Op>
  METAL_FUNC static constexpr void row_bin_op(
      thread frag_type& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};
|
|
228
|
+
|
|
229
|
+
// A kTileRows x kTileCols grid of MMAFrag_ fragments covering a kRows x kCols
// region. Each lane keeps its per-lane vector of every fragment in val_frags
// (thread address space), in row-major fragment order: fragment (i, j) is
// val_frags[i * kTileCols + j]. A lane therefore holds elements in
// kRowsPerThread rows and kColsPerThread columns of the region.
template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  // Tile shape in fragments.
  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  // Tile shape in elements.
  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  // Fragments per tile, and elements of the tile held by one lane.
  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  // Rows / columns of the tile in which one lane holds elements.
  STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
  STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;

  typedef typename MMAFrag_t::mat_type mat_type;
  typedef typename MMAFrag_t::frag_type frag_type;

  // Not zero-initialized: the constructor below leaves val_frags untouched,
  // so call clear() before accumulating into a fresh tile.
  frag_type val_frags[kNumFrags]; // = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  // Zeroes every element of the tile.
  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  // This lane's vector for fragment (i, j).
  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kTileCols + j];
  }

  // Returns a copy of fragment (i, j) packed into a simdgroup matrix.
  METAL_FUNC mat_type mat_at(const short i, const short j) {
    mat_type val_mat;
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
    }
    return val_mat;
  }

  // Flat view of this lane's kElemsPerTile elements, in val_frags order.
  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  // Row-wise reduction with Op. For each fragment row i, every fragment in
  // that row is folded into vals[i * kElemRows] via MMAFrag_t::row_reduce.
  // vals is accumulated into rather than overwritten, so the caller must
  // initialize it.
  template <typename Op>
  METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::template row_reduce<Op>(
            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
      }
    }
  }

  // Replaces every element in fragment row i with
  // Op(element, vals[i * kElemRows]).
  template <typename Op>
  METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::template row_bin_op<Op>(
            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
      }
    }
  }

  // Loads from threadgroup memory with compile-time element strides str_x
  // (between rows) and str_y (between columns). Fragment (i, j) starts
  // i * kFragRows * w_x rows and j * kFragCols * w_y columns past src, so
  // consecutive fragments of this tile are w_x (resp. w_y) fragments apart.
  // BlockMMA passes its WM / WN here, presumably so other simdgroups'
  // fragments interleave in between. src must already include this lane's
  // offset.
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(
                src[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  // Inverse of the threadgroup load above.
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(
                dst[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  // Loads from device memory with row stride ld and unit column stride.
  // w_x / w_y space the fragments the same way as in the threadgroup load.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  // Inverse of the device load above.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  // Bounds-checked device load. src_tile_dims.x is the number of valid
  // columns and src_tile_dims.y the number of valid rows, both counted from
  // src. Elements outside that region read as zero. The limits are passed
  // in (y, x) order to match the (row, column) strides (ld, 1).
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load_safe(
            frag_at(i, j),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  // Bounds-checked device store. Elements outside dst_tile_dims (x = valid
  // columns, y = valid rows, counted from dst) are not written.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_safe(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }
};
|
|
422
|
+
|
|
423
|
+
// Tile-level multiply-accumulate: D = A * B + C.
//
// A is an M x K grid of fragments, B is K x N, and C and D are M x N. For
// each output fragment (m, n):
//   - the first K step computes D(m, n) = A(m, 0) * B(0, n) + C(m, n);
//   - every later step computes D(m, n) = A(m, k) * B(k, n) + D(m, n).
// C is read only on the first step, so the sum over all K steps is correct
// both when D is the same tile as C (the usual in-place use) and when D is a
// separate tile. Reading C on every step would leave only the last step's
// product in a separate D.
//
// Within each fragment row, columns are visited in serpentine order
// (reversed on odd rows). This changes only the visiting order, not the
// result.
template <
    typename Dtype,
    typename Atype,
    typename Btype,
    typename Ctype,
    int M,
    int N,
    int K,
    class MMAFragD,
    class MMAFragA,
    class MMAFragB,
    class MMAFragC>
METAL_FUNC void tile_matmad(
    thread MMATile<Dtype, M, N, MMAFragD>& D,
    thread MMATile<Atype, M, K, MMAFragA>& A,
    thread MMATile<Btype, K, N, MMAFragB>& B,
    thread MMATile<Ctype, M, N, MMAFragC>& C) {
  STEEL_PRAGMA_UNROLL
  for (short m = 0; m < M; ++m) {
    STEEL_PRAGMA_UNROLL
    for (short n = 0; n < N; ++n) {
      const short n_serp = (m % 2) ? (N - 1 - n) : n;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < K; ++k) {
        if (k == 0) {
          // Seed the accumulation with the addend C.
          MMAFragD::mma(
              D.frag_at(m, n_serp),
              A.frag_at(m, k),
              B.frag_at(k, n_serp),
              C.frag_at(m, n_serp));
        } else {
          // Accumulate onto the partial sum already held in D. The fragment
          // mma copies its addend before writing D, so passing D twice is
          // safe.
          MMAFragD::mma(
              D.frag_at(m, n_serp),
              A.frag_at(m, k),
              B.frag_at(k, n_serp),
              D.frag_at(m, n_serp));
        }
      }
    }
  }
}
|
|
458
|
+
|
|
459
|
+
template <
|
|
460
|
+
typename T,
|
|
461
|
+
typename U,
|
|
462
|
+
int BM,
|
|
463
|
+
int BN,
|
|
464
|
+
int BK,
|
|
465
|
+
int WM,
|
|
466
|
+
int WN,
|
|
467
|
+
bool transpose_a,
|
|
468
|
+
bool transpose_b,
|
|
469
|
+
short lda_tgp,
|
|
470
|
+
short ldb_tgp,
|
|
471
|
+
typename AccumType = float,
|
|
472
|
+
typename Epilogue = TransformNone<U, AccumType>>
|
|
473
|
+
struct BlockMMA {
|
|
474
|
+
// MMAFrag size: accumulation runs on 8 x 8 fragments of AccumType.
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

// Distance along M, in elements, between consecutive fragments owned by one
// simdgroup (the WM simdgroup rows interleave their fragments along M).
STEEL_CONST short TM_stride = kFragSize * WM;
// Distance along N, in elements, between consecutive fragments owned by one
// simdgroup (the WN simdgroup columns interleave their fragments along N).
STEEL_CONST short TN_stride = kFragSize * WN;

// Number of fragments per simdgroup along M
STEEL_CONST short TM = BM / TM_stride;
// Number of fragments per simdgroup along N
STEEL_CONST short TN = BN / TN_stride;

// Element strides of the threadgroup A block along M and K
// (transpose_a swaps them)
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

// Element strides of the threadgroup B block along K and N
// (transpose_b swaps them)
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

// Pointer advance in As / Bs for one kFragSize step along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

// Per-lane fragment storage: a column of TM A fragments, a row of TN B
// fragments, and the TM x TN accumulator tile, all in AccumType
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

// This lane's first (row, column) inside the BM x BN block; set by the
// constructor
short sm;
short sn;

// Element offsets of this lane's first A / B element within the
// threadgroup blocks; set by the constructor
short As_offset;
short Bs_offset;
|
|
511
|
+
|
|
512
|
+
// Computes where this lane reads and writes within the threadgroup block.
// Simdgroups are arranged WM x WN in row-major order, each starting
// kFragSize elements after its neighbour. Within a simdgroup, the lane's
// position in its 8 x 8 fragment comes from get_coord(). The results are
// cached in As_offset / Bs_offset (offsets into the A / B threadgroup
// blocks) and sm / sn (the lane's row / column in the output block).
METAL_FUNC BlockMMA(
    ushort simd_group_id [[simdgroup_index_in_threadgroup]],
    ushort simd_lane_id [[thread_index_in_simdgroup]]) {
  // Origin of this simdgroup inside the block.
  const short sg_row = kFragSize * (simd_group_id / WN);
  const short sg_col = kFragSize * (simd_group_id % WN);

  // This lane's (row, column) inside its fragment; x is the column.
  const short2 lane_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
  const short lane_row = lane_coord.y;
  const short lane_col = lane_coord.x;

  // A is addressed as (M, K) and B as (K, N).
  As_offset = (sg_row + lane_row) * A_str_m + lane_col * A_str_k;
  Bs_offset = lane_row * B_str_k + (sg_col + lane_col) * B_str_n;

  sm = sg_row + lane_row;
  sn = sg_col + lane_col;
}
|
|
531
|
+
|
|
532
|
+
/* (BM, BK) X (BK, BN) multiply accumulate function */
// Ctile += As * Bs for one BK-deep slice. As and Bs point at the BM x BK
// and BK x BN threadgroup blocks, laid out per A_str_* / B_str_*. Each
// iteration loads this lane's column of A fragments and row of B fragments
// for one kFragSize step along K, then multiply-accumulates them into Ctile.
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
  // Adjust for simdgroup and thread location
  As += As_offset;
  Bs += Bs_offset;

  // Iterate over BK in blocks of kFragSize.
  // NOTE(review): the simdgroup_barrier(mem_none) calls carry no memory
  // flags; presumably they only keep the simdgroup's lanes in step around
  // the loads and MMA. Keep their placement unchanged.
  STEEL_PRAGMA_UNROLL
  for (short kk = 0; kk < BK; kk += kFragSize) {
    simdgroup_barrier(mem_flags::mem_none);

    Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

    simdgroup_barrier(mem_flags::mem_none);

    Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

    simdgroup_barrier(mem_flags::mem_none);

    // In-place accumulate: Ctile = Atile * Btile + Ctile.
    tile_matmad(Ctile, Atile, Btile, Ctile);

    // Progress to next simdgroup tile
    As += tile_stride_a;
    Bs += tile_stride_b;
  }
}
|
|
558
|
+
|
|
559
|
+
/* Store results from simdgroup_matrix results into device memory */
// Runs the class-level Epilogue::apply over every accumulator element in
// place, then writes this lane's fragments into D (row stride ldd), starting
// at row sm, column sn. No bounds checks are performed; store_result_safe is
// the variant for partial tiles.
METAL_FUNC void store_result(device U* D, const int ldd) {
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < decltype(Ctile)::kElemsPerTile; ++e) {
    thread auto& v = Ctile.elems()[e];
    v = Epilogue::apply(v);
  }

  // First output element owned by this lane
  device U* const D_lane = D + sm * ldd + sn;
  Ctile.template store<U, WM, WN>(D_lane, ldd);
}
|
|
572
|
+
|
|
573
|
+
// Bounds-checked store: runs Epilogue::apply over every accumulator element,
// then writes only the part of this lane's fragments that lies inside
// dst_tile_dims (x = columns, y = rows, measured from the threadgroup tile
// origin). Lanes whose origin is outside the tile write nothing.
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < decltype(Ctile)::kElemsPerTile; ++e) {
    thread auto& v = Ctile.elems()[e];
    v = Epilogue::apply(v);
  }

  // Remaining extent once this lane's (sn, sm) origin is removed
  dst_tile_dims -= short2(sn, sm);

  if (dst_tile_dims.x > 0 && dst_tile_dims.y > 0) {
    Ctile.template store_safe<U, WM, WN>(D + sm * ldd + sn, ldd, dst_tile_dims);
  }
}
|
|
590
|
+
|
|
591
|
+
/* Apply epilogue */
// Replaces every accumulator element x of Ctile with epilogue_op.apply(x).
// Unlike store_result, this uses the caller-supplied epilogue instance rather
// than the class-level Epilogue type.
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
  STEEL_PRAGMA_UNROLL
  for (short e = 0; e < decltype(Ctile)::kElemsPerTile; ++e) {
    thread auto& v = Ctile.elems()[e];
    v = epilogue_op.apply(v);
  }
}
|
|
600
|
+
|
|
601
|
+
/* Apply epilogue */
// Combines each accumulator element with the matching element of C through
// epilogue_op.apply(accum, c) and writes the result back into Ctile.
//   C  : device matrix addressed with row stride ldc and column stride fdc.
// No bounds checks are performed; apply_epilogue_safe is the variant for
// partial tiles.
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
    const device U* C,
    const int ldc,
    const int fdc,
    thread const BinaryEpilogue& epilogue_op) {
  // Point C at this lane's (sm, sn) origin
  C += (sm)*ldc + (sn)*fdc;

  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    // Start of the C row covered by fragment row i
    const device U* C_row = C + (i * TM_stride) * ldc;

    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      const device U* C_frag = C_row + (j * TN_stride) * fdc;
      thread auto& frag = Ctile.frag_at(i, j);

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
        frag[k] = epilogue_op.apply(frag[k], C_frag[k * fdc]);
      }
    }
  }
}
|
|
628
|
+
|
|
629
|
+
/* Apply epilogue */
// Bounds-checked variant of the binary apply_epilogue. Each accumulator
// element is combined with the matching element of C via
// epilogue_op.apply(accum, c), and the result is written back into Ctile.
//   C            : device matrix, row stride ldc, column stride fdc.
//   dst_tile_dims: valid (x = columns, y = rows) extent of the tile, measured
//                  from the threadgroup tile origin.
// Elements of C outside dst_tile_dims are never read; the epilogue receives 0
// for them instead. Both rows and columns are checked, as in
// store_result_safe, so partial tiles along M never read past the end of C.
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
    const device U* C,
    const int ldc,
    const int fdc,
    short2 dst_tile_dims,
    thread const BinaryEpilogue& epilogue_op) {
  // Adjust for simdgroup and thread location
  C += (sm)*ldc + (sn)*fdc;
  dst_tile_dims -= short2(sn, sm);

  // This lane's origin lies entirely outside the valid tile
  if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
    return;

  constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

  // Loop over all simdgroup tiles
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    // Fragment row i maps to row i * TM_stride relative to this lane; rows at
    // or past dst_tile_dims.y are outside C and must not be read
    const bool row_in_bounds = (i * TM_stride) < dst_tile_dims.y;

    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      // Get accumulated result and associated offset in C
      thread auto& accum = Ctile.frag_at(i, j);
      int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

      // Read C, defaulting out-of-range elements to 0
      U c_elems[kelems] = {0};

      if (row_in_bounds) {
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if ((j * TN_stride + k) < dst_tile_dims.x) {
            c_elems[k] = C[offset_c + k * fdc];
          }
        }
      }

      // Apply epilogue
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < kelems; k++) {
        accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
      }
    }
  }
}
|
|
673
|
+
|
|
674
|
+
/* Store results from simdgroup_matrix results into device memory */
// Writes epilogue_op.apply(accum, c) to D for every accumulator element,
// where c is the matching element of C. Ctile itself is not modified (const
// method).
//   D: device output, row stride ldd, unit column stride.
//   C: device input, row stride ldc, column stride fdc.
// No bounds checks are performed; store_result_safe is the variant for
// partial tiles.
METAL_FUNC void store_result(
    device U* D,
    const int ldd,
    const device U* C,
    const int ldc,
    const int fdc,
    thread const Epilogue& epilogue_op) const {
  // Move both pointers to this lane's (sm, sn) origin
  C += (sm)*ldc + (sn)*fdc;
  D += (sm)*ldd + sn;

  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    const int row = i * TM_stride;

    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      const int col = j * TN_stride;
      thread const auto& frag = Ctile.frag_at(i, j);
      device U* d_out = D + row * ldd + col;
      const device U* c_in = C + row * ldc + col * fdc;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
        d_out[k] = epilogue_op.apply(frag[k], c_in[k * fdc]);
      }
    }
  }
}
|
|
706
|
+
|
|
707
|
+
// Bounds-checked version of the C-combining store_result. For every
// accumulator element inside dst_tile_dims (x = columns, y = rows, measured
// from the threadgroup tile origin), writes epilogue_op.apply(accum, c) to D,
// where c is the matching element of C. Elements outside the tile are
// neither read from C nor written to D. Ctile is not modified (const method).
METAL_FUNC void store_result_safe(
    device U* D,
    const int ldd,
    const device U* C,
    const int ldc,
    const int fdc,
    short2 dst_tile_dims,
    thread const Epilogue& epilogue_op) const {
  // Move both pointers to this lane's (sm, sn) origin and shrink the extent
  C += (sm)*ldc + (sn)*fdc;
  D += (sm)*ldd + sn;
  dst_tile_dims -= short2(sn, sm);

  const bool lane_inside = dst_tile_dims.x > 0 && dst_tile_dims.y > 0;
  if (!lane_inside)
    return;

  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; i++) {
    const int row = i * TM_stride;
    if (row >= dst_tile_dims.y)
      continue;

    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; j++) {
      const int col = j * TN_stride;
      thread const auto& frag = Ctile.frag_at(i, j);
      device U* d_out = D + row * ldd + col;
      const device U* c_in = C + row * ldc + col * fdc;

      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
        if (col + k < dst_tile_dims.x) {
          d_out[k] = epilogue_op.apply(frag[k], c_in[k * fdc]);
        }
      }
    }
  }
}
|
|
747
|
+
};
|
|
748
|
+
|
|
749
|
+
} // namespace steel
|
|
750
|
+
} // namespace mlx
|