mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,1084 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_simdgroup>
|
|
6
|
+
#include <metal_simdgroup_matrix>
|
|
7
|
+
#include <metal_stdlib>
|
|
8
|
+
|
|
9
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
10
|
+
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
|
11
|
+
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
|
12
|
+
|
|
13
|
+
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
|
|
14
|
+
|
|
15
|
+
using namespace metal;
|
|
16
|
+
|
|
17
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
18
|
+
// MMA helper
|
|
19
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
20
|
+
|
|
21
|
+
namespace mlx {
|
|
22
|
+
namespace steel {
|
|
23
|
+
|
|
24
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
25
|
+
// NAX Steel with new tiles
|
|
26
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
27
|
+
|
|
28
|
+
struct BaseNAXFrag {
  // One fragment is a kFragRows x kFragCols tile spread over the 32 lanes of
  // a simdgroup; every lane holds kElemsPerFrag elements of it.
  STEEL_CONST short kFragRows = 16;
  STEEL_CONST short kFragCols = 16;

  STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;

  // A lane's elements form kElemRows rows of kElemCols consecutive columns;
  // its two rows are kElemRowsJump rows apart (see get_coord / load).
  STEEL_CONST short kElemRows = 2;
  STEEL_CONST short kElemCols = 4;

  STEEL_CONST short kElemRowsJump = 8;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  // Per-lane vector holding this lane's share of one fragment, stored
  // row-major as [row i][col j] -> index i * kElemCols + j.
  template <typename U>
  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

  // Returns {col, row} (x = column, y = row) of the calling lane's first
  // element inside the fragment:
  //   row = 4 * bit4(lane) + bits1..2(lane)        -> 0..7
  //   col = 4 * (2 * bit3(lane) + bit0(lane))      -> 0, 4, 8 or 12
  METAL_FUNC static short2 get_coord() {
    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
    const short qid = simd_lane_id >> 2;
    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));
    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;
    return short2{fn, fm};
  }

  // Same as get_coord(), but for the lane's element number idx (0..7):
  // elements 0..3 sit on the lane's first row, 4..7 on the row 8 below it,
  // and idx % 4 is the column offset within the lane's group of four.
  METAL_FUNC static short2 get_coord(short idx) {
    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
    const short qid = simd_lane_id >> 2;
    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;
    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;
    return short2{fn, fm};
  }

  // Loads this lane's share of a fragment whose top-left corner is at
  // (off_x, off_y); element (r, c) is read from src[r * str_x + c * str_y]
  // and converted to T. No bounds checks. A compile-time unit column stride
  // (Int<1>) takes the contiguous-column path.
  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if constexpr (metal::is_same_v<StrY, Int<1>>) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[r * str_x + (c + j) * str_y]);
        }
      }
    }
  }

  // Like load(), but a lane row r >= lim_x is not read and is zero-filled.
  // Columns are not bounds-checked.
  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_rows(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if (r < lim_x) {
        if constexpr (metal::is_same_v<StrY, Int<1>>) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[i * kElemCols + j] =
                static_cast<T>(src[r * str_x + (c + j)]
                                   ? src[r * str_x + (c + j) * str_y]
                                   : src[r * str_x + (c + j) * str_y]);
          }
        }
      } else {
        // Zero only this row's elements: clearing the whole vector would
        // discard the row(s) already loaded in earlier iterations.
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Like load(), but every element with r >= lim_x or c >= lim_y is
  // zero-filled instead of read.
  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_safe(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if (r < lim_x && (c + j) < lim_y) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[r * str_x + (c + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Writes this lane's share of a fragment to dst at (off_x, off_y),
  // converting to dst's element type. No bounds checks.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if constexpr (metal::is_same_v<StrY, Int<1>>) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[r * str_x + (c + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Like store(), but rows r >= lim_x are skipped. Columns are not checked.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_rows(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if (r < lim_x) {
        if constexpr (metal::is_same_v<StrY, Int<1>>) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[r * str_x + (c + j) * str_y] =
                static_cast<U>(src[i * kElemCols + j]);
          }
        }
      }
    }
  }

  // Like store(), but elements with r >= lim_x or c >= lim_y are skipped.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_safe(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if (r < lim_x && (c + j) < lim_y) {
          dst[r * str_x + (c + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Stores only the elements whose absolute row (off_x + lane row) lies in
  // [start_x, stop_x) and absolute column (off_y + lane column) lies in
  // [start_y, stop_y). Loop indices are compile-time constants supplied by
  // const_for_loop; each early `return` skips one row / one element.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_slice(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = get_coord();

    const_for_loop<0, kElemRows, 1>([&](auto idx_row) {
      const auto r = off_x + idx_row * Int<kElemRowsJump>{};
      if (r >= stop_x - sc.y || r < start_x - sc.y) {
        return;
      }

      const_for_loop<0, kElemCols, 1>([&](auto idx_col) {
        const auto c = off_y + idx_col;
        if (c >= stop_y - sc.x || c < start_y - sc.x) {
          return;
        }

        const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;
        dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =
            static_cast<U>(src[src_idx]);
      });
    });
  }

  // Reduces each of this lane's kElemRows fragment rows with Op across the
  // full fragment width and folds the result into reduced_vals[i], so
  // reduced_vals must already hold Op's identity or a running value. The four
  // in-lane values are combined first; per get_coord(), the lanes sharing a
  // row differ only in bits 0 and 3, so xor-shuffles by 1 and 8 finish it.
  template <typename Op, typename T>
  METAL_FUNC static constexpr void row_reduce(
      thread const dtype_frag_t<T>& inp_vals,
      thread T* reduced_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      T thr_reduce = Op::apply(
          Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),
          Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));

      T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
      qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

      T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
      sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

      reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);
    }
  }

  // Replaces every element of lane row i with Op::apply(element, row_vals[i]).
  template <typename Op, typename T>
  METAL_FUNC static constexpr void row_bin_op(
      thread dtype_frag_t<T>& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};
|
|
367
|
+
|
|
368
|
+
template <
|
|
369
|
+
typename T,
|
|
370
|
+
short kRows_,
|
|
371
|
+
short kCols_,
|
|
372
|
+
typename NAXFrag_t = BaseNAXFrag>
|
|
373
|
+
struct NAXSubTile {
|
|
374
|
+
STEEL_CONST short kRows = kRows_;
|
|
375
|
+
STEEL_CONST short kCols = kCols_;
|
|
376
|
+
|
|
377
|
+
STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;
|
|
378
|
+
STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;
|
|
379
|
+
STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;
|
|
380
|
+
|
|
381
|
+
STEEL_CONST short kSubTileRows = kRows / kFragRows;
|
|
382
|
+
STEEL_CONST short kSubTileCols = kCols / kFragCols;
|
|
383
|
+
|
|
384
|
+
STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols;
|
|
385
|
+
STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag;
|
|
386
|
+
|
|
387
|
+
STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows;
|
|
388
|
+
STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols;
|
|
389
|
+
|
|
390
|
+
STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;
|
|
391
|
+
STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;
|
|
392
|
+
STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;
|
|
393
|
+
|
|
394
|
+
using frag_type = typename NAXFrag_t::template dtype_frag_t<T>;
|
|
395
|
+
|
|
396
|
+
frag_type val_frags[kNumFrags];
|
|
397
|
+
|
|
398
|
+
METAL_FUNC constexpr void clear() {
|
|
399
|
+
STEEL_PRAGMA_UNROLL
|
|
400
|
+
for (short i = 0; i < kNumFrags; ++i) {
|
|
401
|
+
val_frags[i] = frag_type(0);
|
|
402
|
+
}
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
|
|
406
|
+
return val_frags[i * kSubTileCols + j];
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
METAL_FUNC constexpr const thread frag_type& frag_at(
|
|
410
|
+
const short i,
|
|
411
|
+
const short j) const {
|
|
412
|
+
return val_frags[i * kSubTileCols + j];
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
template <int i, int j>
|
|
416
|
+
METAL_FUNC constexpr thread frag_type& frag_at() {
|
|
417
|
+
return val_frags[i * kSubTileCols + j];
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
template <int i, int j>
|
|
421
|
+
METAL_FUNC constexpr const thread frag_type& frag_at() const {
|
|
422
|
+
return val_frags[i * kSubTileCols + j];
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
METAL_FUNC thread T* elems() {
|
|
426
|
+
return reinterpret_cast<thread T*>(val_frags);
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
METAL_FUNC const thread T* elems() const {
|
|
430
|
+
return reinterpret_cast<const thread T*>(val_frags);
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
template <typename Op>
|
|
434
|
+
METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
|
|
435
|
+
STEEL_PRAGMA_UNROLL
|
|
436
|
+
for (short i = 0; i < kSubTileRows; ++i) {
|
|
437
|
+
STEEL_PRAGMA_UNROLL
|
|
438
|
+
for (short j = 0; j < kSubTileCols; ++j) {
|
|
439
|
+
NAXFrag_t::template row_reduce<Op>(
|
|
440
|
+
frag_at(i, j), &vals[i * kFragThrRows]);
|
|
441
|
+
}
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
template <typename Op>
|
|
446
|
+
METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
|
|
447
|
+
STEEL_PRAGMA_UNROLL
|
|
448
|
+
for (short i = 0; i < kSubTileRows; ++i) {
|
|
449
|
+
STEEL_PRAGMA_UNROLL
|
|
450
|
+
for (short j = 0; j < kSubTileCols; ++j) {
|
|
451
|
+
NAXFrag_t::template row_bin_op<Op>(
|
|
452
|
+
frag_at(i, j), &vals[i * kFragThrRows]);
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
template <
|
|
458
|
+
typename SrcPtrType,
|
|
459
|
+
typename StrX,
|
|
460
|
+
typename StrY,
|
|
461
|
+
typename OffX = Int<0>,
|
|
462
|
+
typename OffY = Int<0>>
|
|
463
|
+
METAL_FUNC constexpr void load(
|
|
464
|
+
SrcPtrType src,
|
|
465
|
+
StrX str_x,
|
|
466
|
+
StrY str_y,
|
|
467
|
+
OffX off_x = {},
|
|
468
|
+
OffY off_y = {}) {
|
|
469
|
+
STEEL_PRAGMA_UNROLL
|
|
470
|
+
for (short i = 0; i < kSubTileRows; ++i) {
|
|
471
|
+
STEEL_PRAGMA_UNROLL
|
|
472
|
+
for (short j = 0; j < kSubTileCols; ++j) {
|
|
473
|
+
NAXFrag_t::load(
|
|
474
|
+
frag_at(i, j),
|
|
475
|
+
src,
|
|
476
|
+
str_x,
|
|
477
|
+
str_y,
|
|
478
|
+
off_x + i * kFragRows,
|
|
479
|
+
off_y + j * kFragCols);
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
template <
|
|
485
|
+
typename DstPtrType,
|
|
486
|
+
typename StrX,
|
|
487
|
+
typename StrY,
|
|
488
|
+
typename OffX = Int<0>,
|
|
489
|
+
typename OffY = Int<0>>
|
|
490
|
+
METAL_FUNC constexpr void store(
|
|
491
|
+
DstPtrType dst,
|
|
492
|
+
StrX str_x,
|
|
493
|
+
StrY str_y,
|
|
494
|
+
OffX off_x = {},
|
|
495
|
+
OffY off_y = {}) const {
|
|
496
|
+
STEEL_PRAGMA_UNROLL
|
|
497
|
+
for (short i = 0; i < kSubTileRows; ++i) {
|
|
498
|
+
STEEL_PRAGMA_UNROLL
|
|
499
|
+
for (short j = 0; j < kSubTileCols; ++j) {
|
|
500
|
+
NAXFrag_t::store(
|
|
501
|
+
frag_at(i, j),
|
|
502
|
+
dst,
|
|
503
|
+
str_x,
|
|
504
|
+
str_y,
|
|
505
|
+
off_x + i * kFragRows,
|
|
506
|
+
off_y + j * kFragCols);
|
|
507
|
+
}
|
|
508
|
+
}
|
|
509
|
+
}
|
|
510
|
+
|
|
511
|
+
template <
|
|
512
|
+
typename SrcPtrType,
|
|
513
|
+
typename StrX,
|
|
514
|
+
typename StrY,
|
|
515
|
+
typename LimX,
|
|
516
|
+
typename OffX = Int<0>,
|
|
517
|
+
typename OffY = Int<0>>
|
|
518
|
+
METAL_FUNC constexpr void load_rows(
|
|
519
|
+
SrcPtrType src,
|
|
520
|
+
StrX str_x,
|
|
521
|
+
StrY str_y,
|
|
522
|
+
LimX lim_x,
|
|
523
|
+
OffX off_x = {},
|
|
524
|
+
OffY off_y = {}) {
|
|
525
|
+
STEEL_PRAGMA_UNROLL
|
|
526
|
+
for (int i = 0; i < kSubTileRows; ++i) {
|
|
527
|
+
STEEL_PRAGMA_UNROLL
|
|
528
|
+
for (int j = 0; j < kSubTileCols; ++j) {
|
|
529
|
+
NAXFrag_t::load_rows(
|
|
530
|
+
frag_at(i, j),
|
|
531
|
+
src,
|
|
532
|
+
str_x,
|
|
533
|
+
str_y,
|
|
534
|
+
lim_x,
|
|
535
|
+
off_x + (i * kFragRows),
|
|
536
|
+
off_y + (j * kFragCols));
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
template <
|
|
542
|
+
typename SrcPtrType,
|
|
543
|
+
typename StrX,
|
|
544
|
+
typename StrY,
|
|
545
|
+
typename LimX,
|
|
546
|
+
typename LimY,
|
|
547
|
+
typename OffX = Int<0>,
|
|
548
|
+
typename OffY = Int<0>>
|
|
549
|
+
METAL_FUNC constexpr void load_safe(
|
|
550
|
+
SrcPtrType src,
|
|
551
|
+
StrX str_x,
|
|
552
|
+
StrY str_y,
|
|
553
|
+
LimX lim_x,
|
|
554
|
+
LimY lim_y,
|
|
555
|
+
OffX off_x = {},
|
|
556
|
+
OffY off_y = {}) {
|
|
557
|
+
STEEL_PRAGMA_UNROLL
|
|
558
|
+
for (int i = 0; i < kSubTileRows; ++i) {
|
|
559
|
+
STEEL_PRAGMA_UNROLL
|
|
560
|
+
for (int j = 0; j < kSubTileCols; ++j) {
|
|
561
|
+
NAXFrag_t::load_safe(
|
|
562
|
+
frag_at(i, j),
|
|
563
|
+
src,
|
|
564
|
+
str_x,
|
|
565
|
+
str_y,
|
|
566
|
+
lim_x,
|
|
567
|
+
lim_y,
|
|
568
|
+
off_x + (i * kFragRows),
|
|
569
|
+
off_y + (j * kFragCols));
|
|
570
|
+
}
|
|
571
|
+
}
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
template <
|
|
575
|
+
typename DstPtrType,
|
|
576
|
+
typename StrX,
|
|
577
|
+
typename StrY,
|
|
578
|
+
typename LimX,
|
|
579
|
+
typename LimY,
|
|
580
|
+
typename OffX = Int<0>,
|
|
581
|
+
typename OffY = Int<0>>
|
|
582
|
+
METAL_FUNC constexpr void store_safe(
|
|
583
|
+
DstPtrType dst,
|
|
584
|
+
StrX str_x,
|
|
585
|
+
StrY str_y,
|
|
586
|
+
LimX lim_x,
|
|
587
|
+
LimY lim_y,
|
|
588
|
+
OffX off_x = {},
|
|
589
|
+
OffY off_y = {}) const {
|
|
590
|
+
STEEL_PRAGMA_UNROLL
|
|
591
|
+
for (int i = 0; i < kSubTileRows; ++i) {
|
|
592
|
+
STEEL_PRAGMA_UNROLL
|
|
593
|
+
for (int j = 0; j < kSubTileCols; ++j) {
|
|
594
|
+
NAXFrag_t::store_safe(
|
|
595
|
+
frag_at(i, j),
|
|
596
|
+
dst,
|
|
597
|
+
str_x,
|
|
598
|
+
str_y,
|
|
599
|
+
lim_x,
|
|
600
|
+
lim_y,
|
|
601
|
+
off_x + (i * kFragRows),
|
|
602
|
+
off_y + (j * kFragCols));
|
|
603
|
+
}
|
|
604
|
+
}
|
|
605
|
+
}
|
|
606
|
+
|
|
607
|
+
template <
|
|
608
|
+
typename DstPtrType,
|
|
609
|
+
typename StrX,
|
|
610
|
+
typename StrY,
|
|
611
|
+
typename LimX,
|
|
612
|
+
typename OffX = Int<0>,
|
|
613
|
+
typename OffY = Int<0>>
|
|
614
|
+
METAL_FUNC constexpr void store_rows(
|
|
615
|
+
DstPtrType dst,
|
|
616
|
+
StrX str_x,
|
|
617
|
+
StrY str_y,
|
|
618
|
+
LimX lim_x,
|
|
619
|
+
OffX off_x = {},
|
|
620
|
+
OffY off_y = {}) const {
|
|
621
|
+
STEEL_PRAGMA_UNROLL
|
|
622
|
+
for (int i = 0; i < kSubTileRows; ++i) {
|
|
623
|
+
STEEL_PRAGMA_UNROLL
|
|
624
|
+
for (int j = 0; j < kSubTileCols; ++j) {
|
|
625
|
+
NAXFrag_t::store_safe(
|
|
626
|
+
frag_at(i, j),
|
|
627
|
+
dst,
|
|
628
|
+
str_x,
|
|
629
|
+
str_y,
|
|
630
|
+
lim_x,
|
|
631
|
+
off_x + (i * kFragRows),
|
|
632
|
+
off_y + (j * kFragCols));
|
|
633
|
+
}
|
|
634
|
+
}
|
|
635
|
+
}
|
|
636
|
+
|
|
637
|
+
template <
|
|
638
|
+
typename DstPtrType,
|
|
639
|
+
typename StrX,
|
|
640
|
+
typename StrY,
|
|
641
|
+
typename StartX,
|
|
642
|
+
typename StopX,
|
|
643
|
+
typename StartY,
|
|
644
|
+
typename StopY,
|
|
645
|
+
typename OffX = Int<0>,
|
|
646
|
+
typename OffY = Int<0>>
|
|
647
|
+
METAL_FUNC constexpr void store_slice(
|
|
648
|
+
DstPtrType dst,
|
|
649
|
+
StrX str_x,
|
|
650
|
+
StrY str_y,
|
|
651
|
+
StartX start_x,
|
|
652
|
+
StopX stop_x,
|
|
653
|
+
StartY start_y,
|
|
654
|
+
StopY stop_y,
|
|
655
|
+
OffX off_x = Int<0>{},
|
|
656
|
+
OffY off_y = Int<0>{}) const {
|
|
657
|
+
const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) {
|
|
658
|
+
const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) {
|
|
659
|
+
NAXFrag_t::store_slice(
|
|
660
|
+
frag_at<idx_row.value, idx_col.value>(),
|
|
661
|
+
dst,
|
|
662
|
+
str_x,
|
|
663
|
+
str_y,
|
|
664
|
+
start_x,
|
|
665
|
+
stop_x,
|
|
666
|
+
start_y,
|
|
667
|
+
stop_y,
|
|
668
|
+
off_x + idx_row * Int<kFragRows>{},
|
|
669
|
+
off_y + idx_col * Int<kFragCols>{});
|
|
670
|
+
});
|
|
671
|
+
});
|
|
672
|
+
}
|
|
673
|
+
};
|
|
674
|
+
|
|
675
|
+
template <
|
|
676
|
+
short RC,
|
|
677
|
+
short CC,
|
|
678
|
+
short RA,
|
|
679
|
+
short CA,
|
|
680
|
+
short RB,
|
|
681
|
+
short CB,
|
|
682
|
+
typename CType,
|
|
683
|
+
typename AType,
|
|
684
|
+
typename BType,
|
|
685
|
+
bool transpose_a,
|
|
686
|
+
bool transpose_b,
|
|
687
|
+
typename NAXFrag_t = BaseNAXFrag>
|
|
688
|
+
METAL_FUNC void subtile_matmad_nax(
|
|
689
|
+
thread NAXSubTile<CType, RC, CC, NAXFrag_t>& C,
|
|
690
|
+
thread NAXSubTile<AType, RA, CA, NAXFrag_t>& A,
|
|
691
|
+
metal::bool_constant<transpose_a>,
|
|
692
|
+
thread NAXSubTile<BType, RB, CB, NAXFrag_t>& B,
|
|
693
|
+
metal::bool_constant<transpose_b>) {
|
|
694
|
+
// Static checks
|
|
695
|
+
constexpr short FMa = transpose_a ? CA : RA;
|
|
696
|
+
constexpr short FMc = RC;
|
|
697
|
+
static_assert(FMa == FMc, "NAX matmul: M dimensions do not match");
|
|
698
|
+
|
|
699
|
+
constexpr short FNb = transpose_b ? RB : CB;
|
|
700
|
+
constexpr short FNc = CC;
|
|
701
|
+
static_assert(FNb == FNc, "NAX matmul: N dimensions do not match");
|
|
702
|
+
|
|
703
|
+
constexpr short FKa = transpose_a ? RA : CA;
|
|
704
|
+
constexpr short FKb = transpose_b ? CB : RB;
|
|
705
|
+
static_assert(FKa == FKb, "NAX matmul: N dimensions do not match");
|
|
706
|
+
|
|
707
|
+
constexpr short FM = FMc;
|
|
708
|
+
constexpr short FN = FNc;
|
|
709
|
+
constexpr short FK = FKa;
|
|
710
|
+
|
|
711
|
+
constexpr int TM = FM / 16;
|
|
712
|
+
constexpr int TN = FN / 16;
|
|
713
|
+
constexpr int TK = FK / 16;
|
|
714
|
+
|
|
715
|
+
// Create Matmul descriptor
|
|
716
|
+
constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(
|
|
717
|
+
FM,
|
|
718
|
+
FN,
|
|
719
|
+
FK,
|
|
720
|
+
transpose_a,
|
|
721
|
+
transpose_b,
|
|
722
|
+
true,
|
|
723
|
+
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
|
|
724
|
+
|
|
725
|
+
// Create matmul op
|
|
726
|
+
mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;
|
|
727
|
+
|
|
728
|
+
// Create matmul operands in registers
|
|
729
|
+
auto ct_a =
|
|
730
|
+
gemm_op.template get_left_input_cooperative_tensor<AType, BType, CType>();
|
|
731
|
+
auto ct_b =
|
|
732
|
+
gemm_op
|
|
733
|
+
.template get_right_input_cooperative_tensor<AType, BType, CType>();
|
|
734
|
+
|
|
735
|
+
// Create matmul output in register
|
|
736
|
+
auto ct_c = gemm_op.template get_destination_cooperative_tensor<
|
|
737
|
+
decltype(ct_a),
|
|
738
|
+
decltype(ct_b),
|
|
739
|
+
CType>();
|
|
740
|
+
|
|
741
|
+
// Load A in to left operand registers
|
|
742
|
+
STEEL_PRAGMA_UNROLL
|
|
743
|
+
for (short mm = 0; mm < TM; mm++) {
|
|
744
|
+
STEEL_PRAGMA_UNROLL
|
|
745
|
+
for (short kk = 0; kk < TK; kk++) {
|
|
746
|
+
const short fi = transpose_a ? kk : mm;
|
|
747
|
+
const short fj = transpose_a ? mm : kk;
|
|
748
|
+
|
|
749
|
+
STEEL_PRAGMA_UNROLL
|
|
750
|
+
for (short i = 0; i < 8; i++) {
|
|
751
|
+
ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i];
|
|
752
|
+
}
|
|
753
|
+
}
|
|
754
|
+
}
|
|
755
|
+
|
|
756
|
+
// Load B into right operand registers
|
|
757
|
+
STEEL_PRAGMA_UNROLL
|
|
758
|
+
for (short nn = 0; nn < TN; nn++) {
|
|
759
|
+
STEEL_PRAGMA_UNROLL
|
|
760
|
+
for (short kk = 0; kk < TK; kk++) {
|
|
761
|
+
const short fi = transpose_b ? nn : kk;
|
|
762
|
+
const short fj = transpose_b ? kk : nn;
|
|
763
|
+
|
|
764
|
+
STEEL_PRAGMA_UNROLL
|
|
765
|
+
for (short i = 0; i < 8; i++) {
|
|
766
|
+
ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i];
|
|
767
|
+
}
|
|
768
|
+
}
|
|
769
|
+
}
|
|
770
|
+
|
|
771
|
+
// Load C into output registers (op handles accumulation)
|
|
772
|
+
STEEL_PRAGMA_UNROLL
|
|
773
|
+
for (short i = 0; i < ct_c.get_capacity(); i++) {
|
|
774
|
+
ct_c[i] = C.elems()[i];
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
// Do matmul
|
|
778
|
+
gemm_op.run(ct_a, ct_b, ct_c);
|
|
779
|
+
|
|
780
|
+
// Copy out results
|
|
781
|
+
STEEL_PRAGMA_UNROLL
|
|
782
|
+
for (short i = 0; i < ct_c.get_capacity(); i++) {
|
|
783
|
+
C.elems()[i] = ct_c[i];
|
|
784
|
+
}
|
|
785
|
+
}
|
|
786
|
+
|
|
787
|
+
template <typename T, short kTileRows_, short kTileCols_, class NAXSubTile_>
|
|
788
|
+
struct NAXTile {
|
|
789
|
+
using NAXSubTile_t = NAXSubTile_;
|
|
790
|
+
using elem_type = T;
|
|
791
|
+
STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows;
|
|
792
|
+
STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols;
|
|
793
|
+
STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile;
|
|
794
|
+
|
|
795
|
+
STEEL_CONST short kTileRows = kTileRows_;
|
|
796
|
+
STEEL_CONST short kTileCols = kTileCols_;
|
|
797
|
+
|
|
798
|
+
STEEL_CONST short kRows = kTileRows * kSubTileRows;
|
|
799
|
+
STEEL_CONST short kCols = kTileCols * kSubTileCols;
|
|
800
|
+
|
|
801
|
+
STEEL_CONST short kSubTiles = kTileRows * kTileCols;
|
|
802
|
+
STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile;
|
|
803
|
+
|
|
804
|
+
STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread;
|
|
805
|
+
STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread;
|
|
806
|
+
|
|
807
|
+
STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread;
|
|
808
|
+
STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread;
|
|
809
|
+
|
|
810
|
+
NAXSubTile_t val_subtiles[kSubTiles];
|
|
811
|
+
|
|
812
|
+
METAL_FUNC NAXTile() thread {}
|
|
813
|
+
|
|
814
|
+
METAL_FUNC constexpr void clear() {
|
|
815
|
+
STEEL_PRAGMA_UNROLL
|
|
816
|
+
for (short i = 0; i < kSubTiles; ++i) {
|
|
817
|
+
val_subtiles[i].clear();
|
|
818
|
+
}
|
|
819
|
+
}
|
|
820
|
+
|
|
821
|
+
METAL_FUNC constexpr thread NAXSubTile_t& subtile_at(
|
|
822
|
+
const short i,
|
|
823
|
+
const short j) {
|
|
824
|
+
return val_subtiles[i * kTileCols + j];
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at(
|
|
828
|
+
const short i,
|
|
829
|
+
const short j) const {
|
|
830
|
+
return val_subtiles[i * kTileCols + j];
|
|
831
|
+
}
|
|
832
|
+
|
|
833
|
+
template <int i, int j>
|
|
834
|
+
METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const {
|
|
835
|
+
return val_subtiles[i * kTileCols + j];
|
|
836
|
+
}
|
|
837
|
+
|
|
838
|
+
METAL_FUNC thread elem_type* elems() {
|
|
839
|
+
return reinterpret_cast<thread elem_type*>(val_subtiles[0].elems());
|
|
840
|
+
}
|
|
841
|
+
|
|
842
|
+
METAL_FUNC const thread elem_type* elems() const {
|
|
843
|
+
return reinterpret_cast<const thread elem_type*>(val_subtiles[0].elems());
|
|
844
|
+
}
|
|
845
|
+
|
|
846
|
+
template <typename Op>
|
|
847
|
+
METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
|
|
848
|
+
auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
|
|
849
|
+
STEEL_PRAGMA_UNROLL
|
|
850
|
+
for (short i = 0; i < kTileRows; ++i) {
|
|
851
|
+
STEEL_PRAGMA_UNROLL
|
|
852
|
+
for (short j = 0; j < kTileCols; ++j) {
|
|
853
|
+
subtile_at(i, j).template row_reduce<Op>(sub_rows[i]);
|
|
854
|
+
}
|
|
855
|
+
}
|
|
856
|
+
}
|
|
857
|
+
|
|
858
|
+
template <typename Op>
|
|
859
|
+
METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
|
|
860
|
+
auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
|
|
861
|
+
STEEL_PRAGMA_UNROLL
|
|
862
|
+
for (short i = 0; i < kTileRows; ++i) {
|
|
863
|
+
STEEL_PRAGMA_UNROLL
|
|
864
|
+
for (short j = 0; j < kTileCols; ++j) {
|
|
865
|
+
subtile_at(i, j).template row_bin_op<Op>(sub_rows[i]);
|
|
866
|
+
}
|
|
867
|
+
}
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
template <typename U, int str_x, int str_y>
|
|
871
|
+
METAL_FUNC void load(const threadgroup U* src) {
|
|
872
|
+
STEEL_PRAGMA_UNROLL
|
|
873
|
+
for (short i = 0; i < kTileRows; ++i) {
|
|
874
|
+
STEEL_PRAGMA_UNROLL
|
|
875
|
+
for (short j = 0; j < kTileCols; ++j) {
|
|
876
|
+
subtile_at(i, j).load(
|
|
877
|
+
src,
|
|
878
|
+
Int<str_x>{},
|
|
879
|
+
Int<str_y>{},
|
|
880
|
+
i * kSubTileRows,
|
|
881
|
+
j * kSubTileCols);
|
|
882
|
+
}
|
|
883
|
+
}
|
|
884
|
+
}
|
|
885
|
+
|
|
886
|
+
template <typename U, int str_x, int str_y>
|
|
887
|
+
METAL_FUNC void store(threadgroup U* dst) const {
|
|
888
|
+
STEEL_PRAGMA_UNROLL
|
|
889
|
+
for (short i = 0; i < kTileRows; ++i) {
|
|
890
|
+
STEEL_PRAGMA_UNROLL
|
|
891
|
+
for (short j = 0; j < kTileCols; ++j) {
|
|
892
|
+
subtile_at(i, j).store(
|
|
893
|
+
dst,
|
|
894
|
+
Int<str_x>{},
|
|
895
|
+
Int<str_y>{},
|
|
896
|
+
i * kSubTileRows,
|
|
897
|
+
j * kSubTileCols);
|
|
898
|
+
}
|
|
899
|
+
}
|
|
900
|
+
}
|
|
901
|
+
|
|
902
|
+
template <typename U>
|
|
903
|
+
METAL_FUNC void load(const device U* src, const int ld) {
|
|
904
|
+
STEEL_PRAGMA_UNROLL
|
|
905
|
+
for (short i = 0; i < kTileRows; ++i) {
|
|
906
|
+
STEEL_PRAGMA_UNROLL
|
|
907
|
+
for (short j = 0; j < kTileCols; ++j) {
|
|
908
|
+
subtile_at(i, j).load(
|
|
909
|
+
&src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{});
|
|
910
|
+
}
|
|
911
|
+
}
|
|
912
|
+
}
|
|
913
|
+
|
|
914
|
+
template <typename U>
|
|
915
|
+
METAL_FUNC void store(device U* dst, const int ld) const {
|
|
916
|
+
STEEL_PRAGMA_UNROLL
|
|
917
|
+
for (short i = 0; i < kTileRows; ++i) {
|
|
918
|
+
STEEL_PRAGMA_UNROLL
|
|
919
|
+
for (short j = 0; j < kTileCols; ++j) {
|
|
920
|
+
subtile_at(i, j).store(
|
|
921
|
+
&dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{});
|
|
922
|
+
}
|
|
923
|
+
}
|
|
924
|
+
}
|
|
925
|
+
|
|
926
|
+
template <typename U>
|
|
927
|
+
METAL_FUNC void
|
|
928
|
+
load_rows(const device U* src, const int ld, const short n_rows) {
|
|
929
|
+
STEEL_PRAGMA_UNROLL
|
|
930
|
+
for (int i = 0; i < kTileRows; ++i) {
|
|
931
|
+
STEEL_PRAGMA_UNROLL
|
|
932
|
+
for (int j = 0; j < kTileCols; ++j) {
|
|
933
|
+
subtile_at(i, j).load_rows(
|
|
934
|
+
&src[(i * kSubTileRows) * ld + (j * kSubTileCols)],
|
|
935
|
+
ld,
|
|
936
|
+
Int<1>{},
|
|
937
|
+
n_rows - i * kSubTileRows);
|
|
938
|
+
}
|
|
939
|
+
}
|
|
940
|
+
}
|
|
941
|
+
|
|
942
|
+
template <typename U>
|
|
943
|
+
METAL_FUNC void
|
|
944
|
+
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
|
|
945
|
+
STEEL_PRAGMA_UNROLL
|
|
946
|
+
for (int i = 0; i < kTileRows; ++i) {
|
|
947
|
+
STEEL_PRAGMA_UNROLL
|
|
948
|
+
for (int j = 0; j < kTileCols; ++j) {
|
|
949
|
+
subtile_at(i, j).load_safe(
|
|
950
|
+
src,
|
|
951
|
+
ld,
|
|
952
|
+
Int<1>{},
|
|
953
|
+
src_tile_dims.y,
|
|
954
|
+
src_tile_dims.x,
|
|
955
|
+
i * kSubTileRows,
|
|
956
|
+
j * kSubTileCols);
|
|
957
|
+
}
|
|
958
|
+
}
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
template <typename U>
|
|
962
|
+
METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows)
|
|
963
|
+
const {
|
|
964
|
+
STEEL_PRAGMA_UNROLL
|
|
965
|
+
for (int i = 0; i < kTileRows; ++i) {
|
|
966
|
+
STEEL_PRAGMA_UNROLL
|
|
967
|
+
for (int j = 0; j < kTileCols; ++j) {
|
|
968
|
+
subtile_at(i, j).store_rows(
|
|
969
|
+
&dst[(i * kSubTileRows) * ld + (j * kSubTileCols)],
|
|
970
|
+
ld,
|
|
971
|
+
Int<1>{},
|
|
972
|
+
n_rows - i * kSubTileRows);
|
|
973
|
+
}
|
|
974
|
+
}
|
|
975
|
+
}
|
|
976
|
+
|
|
977
|
+
template <typename U>
|
|
978
|
+
METAL_FUNC void
|
|
979
|
+
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
|
|
980
|
+
STEEL_PRAGMA_UNROLL
|
|
981
|
+
for (int i = 0; i < kTileRows; ++i) {
|
|
982
|
+
STEEL_PRAGMA_UNROLL
|
|
983
|
+
for (int j = 0; j < kTileCols; ++j) {
|
|
984
|
+
subtile_at(i, j).store_safe(
|
|
985
|
+
dst,
|
|
986
|
+
ld,
|
|
987
|
+
Int<1>{},
|
|
988
|
+
dst_tile_dims.y,
|
|
989
|
+
dst_tile_dims.x,
|
|
990
|
+
i * kSubTileRows,
|
|
991
|
+
j * kSubTileCols);
|
|
992
|
+
}
|
|
993
|
+
}
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
template <typename U>
|
|
997
|
+
METAL_FUNC void store_slice(
|
|
998
|
+
device U* dst,
|
|
999
|
+
const int ld,
|
|
1000
|
+
const short2 start,
|
|
1001
|
+
const short2 stop) const {
|
|
1002
|
+
const_for_loop<0, kTileRows, 1>([&](auto idx_row) {
|
|
1003
|
+
const_for_loop<0, kTileCols, 1>([&](auto idx_col) {
|
|
1004
|
+
subtile_at<idx_row.value, idx_col.value>().store_slice(
|
|
1005
|
+
dst,
|
|
1006
|
+
ld,
|
|
1007
|
+
Int<1>{},
|
|
1008
|
+
start.y,
|
|
1009
|
+
stop.y,
|
|
1010
|
+
start.x,
|
|
1011
|
+
stop.x,
|
|
1012
|
+
idx_row * Int<kSubTileRows>{},
|
|
1013
|
+
idx_col * Int<kSubTileCols>{});
|
|
1014
|
+
});
|
|
1015
|
+
});
|
|
1016
|
+
}
|
|
1017
|
+
};
|
|
1018
|
+
|
|
1019
|
+
template <
|
|
1020
|
+
class CTile,
|
|
1021
|
+
class ATile,
|
|
1022
|
+
class BTile,
|
|
1023
|
+
bool transpose_a,
|
|
1024
|
+
bool transpose_b>
|
|
1025
|
+
METAL_FUNC void tile_matmad_nax(
|
|
1026
|
+
thread CTile& C,
|
|
1027
|
+
thread ATile& A,
|
|
1028
|
+
metal::bool_constant<transpose_a>,
|
|
1029
|
+
thread BTile& B,
|
|
1030
|
+
metal::bool_constant<transpose_b>) {
|
|
1031
|
+
// Static checks
|
|
1032
|
+
constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;
|
|
1033
|
+
constexpr short TMc = CTile::kTileRows;
|
|
1034
|
+
static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match");
|
|
1035
|
+
|
|
1036
|
+
constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows;
|
|
1037
|
+
constexpr short FMc = CTile::kSubTileRows;
|
|
1038
|
+
static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match");
|
|
1039
|
+
|
|
1040
|
+
constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;
|
|
1041
|
+
constexpr short TNc = CTile::kTileCols;
|
|
1042
|
+
static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match");
|
|
1043
|
+
|
|
1044
|
+
constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols;
|
|
1045
|
+
constexpr short FNc = CTile::kSubTileCols;
|
|
1046
|
+
static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match");
|
|
1047
|
+
|
|
1048
|
+
constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;
|
|
1049
|
+
constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows;
|
|
1050
|
+
static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match");
|
|
1051
|
+
|
|
1052
|
+
constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols;
|
|
1053
|
+
constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows;
|
|
1054
|
+
static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match");
|
|
1055
|
+
|
|
1056
|
+
constexpr short TM = TMc;
|
|
1057
|
+
constexpr short TN = TNc;
|
|
1058
|
+
constexpr short TK = TKa;
|
|
1059
|
+
|
|
1060
|
+
// Do matmul here
|
|
1061
|
+
STEEL_PRAGMA_UNROLL
|
|
1062
|
+
for (short i = 0; i < TM; ++i) {
|
|
1063
|
+
STEEL_PRAGMA_UNROLL
|
|
1064
|
+
for (short j = 0; j < TN; ++j) {
|
|
1065
|
+
STEEL_PRAGMA_UNROLL
|
|
1066
|
+
for (short k = 0; k < TK; ++k) {
|
|
1067
|
+
const short ra = transpose_a ? k : i;
|
|
1068
|
+
const short ca = transpose_a ? i : k;
|
|
1069
|
+
const short rb = transpose_b ? j : k;
|
|
1070
|
+
const short cb = transpose_b ? k : j;
|
|
1071
|
+
|
|
1072
|
+
subtile_matmad_nax(
|
|
1073
|
+
C.subtile_at(i, j),
|
|
1074
|
+
A.subtile_at(ra, ca),
|
|
1075
|
+
metal::bool_constant<transpose_a>{},
|
|
1076
|
+
B.subtile_at(rb, cb),
|
|
1077
|
+
metal::bool_constant<transpose_b>{});
|
|
1078
|
+
}
|
|
1079
|
+
}
|
|
1080
|
+
}
|
|
1081
|
+
}
|
|
1082
|
+
|
|
1083
|
+
} // namespace steel
|
|
1084
|
+
} // namespace mlx
|