mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h
@@ -0,0 +1,750 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <metal_simdgroup>
+#include <metal_simdgroup_matrix>
+#include <metal_stdlib>
+
+#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
+#include "mlx/backend/metal/kernels/steel/defines.h"
+#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// MMA helper
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <typename RInt, typename CInt>
+struct Shape2D {
+  RInt r;
+  CInt c;
+
+  Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
+};
+
+template <typename Shape, typename Layout>
+struct Layout2D {
+  Shape shape;
+  Layout layout;
+};
+
+template <typename T, int kFragRows_, int kFragCols_>
+struct BaseMMAFrag {
+  static_assert(
+      kFragRows_ == 8,
+      "Only 8 x 8 fragment matrices are currently supported");
+  static_assert(
+      kFragCols_ == 8,
+      "Only 8 x 8 fragment matrices are currently supported");
+};
+
+template <typename T>
+struct BaseMMAFrag<T, 8, 8> {
+  STEEL_CONST int kFragRows = 8;
+  STEEL_CONST int kFragCols = 8;
+
+  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
+
+  STEEL_CONST int kElemRows = 1;
+  STEEL_CONST int kElemCols = 2;
+
+  static_assert(
+      kElemRows * kElemCols == kElemsPerFrag,
+      "MMAFrag shape is not consistent with MMAFrag size");
+
+  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
+  typedef metal::vec<T, kElemsPerFrag> frag_type;
+  typedef metal::vec<T, kElemRows> row_frag_type;
+  typedef metal::vec<T, kElemCols> col_frag_type;
+
+  template <typename U>
+  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
+
+  template <typename U>
+  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
+
+  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
+      [[thread_index_in_simdgroup]]) {
+    const short qid = simd_lane_id / 4;
+    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
+    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
+    return short2{fn, fm};
+  }
+
+  template <typename SrcPtrType, typename StrX, typename StrY>
+  METAL_FUNC static constexpr void
+  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
+      }
+    }
+  }
+
+  template <
+      typename SrcPtrType,
+      typename StrX,
+      typename StrY,
+      typename LimX,
+      typename LimY,
+      typename OffX,
+      typename OffY>
+  METAL_FUNC static constexpr void load_safe(
+      thread frag_type& dst,
+      SrcPtrType src,
+      StrX str_x,
+      StrY str_y,
+      LimX lim_x,
+      LimY lim_y,
+      OffX off_x = Int<0>{},
+      OffY off_y = Int<0>{}) {
+    src += off_x * str_x + off_y * str_y;
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
+          dst[i * kElemCols + j] = static_cast<T>(src[0]);
+        } else {
+          dst[i * kElemCols + j] = T(0);
+        }
+        src += str_y;
+      }
+      src -= kElemCols * str_y;
+      src += str_x;
+    }
+  }
+
+  template <typename DstPtrType, typename StrX, typename StrY>
+  METAL_FUNC static constexpr void
+  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
+    using U = pointer_element_t<DstPtrType>;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
+      }
+    }
+  }
+
+  template <
+      typename DstPtrType,
+      typename StrX,
+      typename StrY,
+      typename LimX,
+      typename LimY,
+      typename OffX,
+      typename OffY>
+  METAL_FUNC static constexpr void store_safe(
+      const thread frag_type& src,
+      DstPtrType dst,
+      StrX str_x,
+      StrY str_y,
+      LimX lim_x,
+      LimY lim_y,
+      OffX off_x = Int<0>{},
+      OffY off_y = Int<0>{}) {
+    using U = pointer_element_t<DstPtrType>;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
+          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
+              static_cast<U>(src[i * kElemCols + j]);
+        }
+      }
+    }
+  }
+
+  template <typename Atype, typename Btype, typename Ctype>
+  METAL_FUNC static constexpr void mma(
+      thread frag_type& D,
+      thread dtype_frag_t<Atype>& A,
+      thread dtype_frag_t<Btype>& B,
+      thread dtype_frag_t<Ctype>& C) {
+    mat_type D_mat;
+    dtype_mat_t<Atype> A_mat;
+    dtype_mat_t<Btype> B_mat;
+    dtype_mat_t<Ctype> C_mat;
+
+    reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
+    reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
+    reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
+
+    mma(D_mat, A_mat, B_mat, C_mat);
+
+    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
+  }
+
+  template <typename Atype, typename Btype, typename Ctype>
+  METAL_FUNC static constexpr void mma(
+      thread mat_type& D,
+      thread dtype_mat_t<Atype>& A,
+      thread dtype_mat_t<Btype>& B,
+      thread dtype_mat_t<Ctype>& C) {
+    simdgroup_multiply_accumulate(D, A, B, C);
+  }
+
+  template <typename Op>
+  METAL_FUNC static constexpr void row_reduce(
+      thread const frag_type& inp_vals,
+      thread T* reduced_vals) {
+    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
+
+    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
+    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
+
+    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
+    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
+
+    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
+  }
+
+  template <typename Op>
+  METAL_FUNC static constexpr void row_bin_op(
+      thread frag_type& inp_vals,
+      thread T* row_vals) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        inp_vals[i * kElemCols + j] =
+            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
+      }
+    }
+  }
+};
+
+template <
+    typename T,
+    int kTileRows_,
+    int kTileCols_,
+    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
+struct MMATile {
+  using MMAFrag_t = MMAFrag_;
+  using elem_type = T;
+  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
+  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
+  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
+
+  STEEL_CONST int kTileRows = kTileRows_;
+  STEEL_CONST int kTileCols = kTileCols_;
+
+  STEEL_CONST int kRows = kTileRows * kFragRows;
+  STEEL_CONST int kCols = kTileCols * kFragCols;
+
+  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
+  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
+
+  STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
+  STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
+
+  typedef typename MMAFrag_t::mat_type mat_type;
+  typedef typename MMAFrag_t::frag_type frag_type;
+
+  frag_type val_frags[kNumFrags]; // = {frag_type(0)};
+
+  METAL_FUNC MMATile() thread {}
+
+  METAL_FUNC constexpr void clear() {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kNumFrags; ++i) {
+      val_frags[i] = frag_type(0);
+    }
+  }
+
+  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
+    return val_frags[i * kTileCols + j];
+  }
+
+  METAL_FUNC constexpr const thread frag_type& frag_at(
+      const short i,
+      const short j) const {
+    return val_frags[i * kTileCols + j];
+  }
+
+  METAL_FUNC mat_type mat_at(const short i, const short j) {
+    mat_type val_mat;
+    STEEL_PRAGMA_UNROLL
+    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
+      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
+    }
+    return val_mat;
+  }
+
+  METAL_FUNC thread elem_type* elems() {
+    return reinterpret_cast<thread elem_type*>(val_frags);
+  }
+
+  METAL_FUNC const thread elem_type* elems() const {
+    return reinterpret_cast<const thread elem_type*>(val_frags);
+  }
+
+  template <typename Op>
+  METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::template row_reduce<Op>(
+            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
+      }
+    }
+  }
+
+  template <typename Op>
+  METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::template row_bin_op<Op>(
+            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y, int str_x, int str_y>
+  METAL_FUNC void load(const threadgroup U* src) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::load(
+            frag_at(i, j),
+            &(
+                src[(i * kFragRows) * w_x * str_x +
+                    (j * kFragCols) * w_y * str_y]),
+            Int<str_x>{},
+            Int<str_y>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y, int str_x, int str_y>
+  METAL_FUNC void store(threadgroup U* dst) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store(
+            frag_at(i, j),
+            &(
+                dst[(i * kFragRows) * w_x * str_x +
+                    (j * kFragCols) * w_y * str_y]),
+            Int<str_x>{},
+            Int<str_y>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void load(const device U* src, const int ld) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::load(
+            frag_at(i, j),
+            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
+            ld,
+            Int<1>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void store(device U* dst, const int ld) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store(
+            frag_at(i, j),
+            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
+            ld,
+            Int<1>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void
+  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (int j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::load_safe(
+            frag_at(i, j),
+            src,
+            ld,
+            Int<1>{},
+            src_tile_dims.y,
+            src_tile_dims.x,
+            (i * kFragRows) * w_x,
+            (j * kFragCols) * w_y);
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void
+  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (int j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store_safe(
+            frag_at(i, j),
+            dst,
+            ld,
+            Int<1>{},
+            dst_tile_dims.y,
+            dst_tile_dims.x,
+            (i * kFragRows) * w_x,
+            (j * kFragCols) * w_y);
+      }
+    }
+  }
+};
+
+template <
+    typename Dtype,
+    typename Atype,
+    typename Btype,
+    typename Ctype,
+    int M,
+    int N,
+    int K,
+    class MMAFragD,
+    class MMAFragA,
+    class MMAFragB,
+    class MMAFragC>
+METAL_FUNC void tile_matmad(
+    thread MMATile<Dtype, M, N, MMAFragD>& D,
+    thread MMATile<Atype, M, K, MMAFragA>& A,
+    thread MMATile<Btype, K, N, MMAFragB>& B,
+    thread MMATile<Ctype, M, N, MMAFragC>& C) {
+  STEEL_PRAGMA_UNROLL
+  for (short m = 0; m < M; ++m) {
+    STEEL_PRAGMA_UNROLL
+    for (short n = 0; n < N; ++n) {
+      short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
+      short n_serp = (m % 2) ? (N - 1 - n) : n;
+
+      STEEL_PRAGMA_UNROLL
+      for (short k = 0; k < K; ++k) {
+        MMAFragD::mma(
+            D.frag_at(m_serp, n_serp),
+            A.frag_at(m_serp, k),
+            B.frag_at(k, n_serp),
+            C.frag_at(m_serp, n_serp));
+      }
+    }
+  }
+}
+
+template <
+    typename T,
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    short lda_tgp,
+    short ldb_tgp,
+    typename AccumType = float,
+    typename Epilogue = TransformNone<U, AccumType>>
+struct BlockMMA {
+  // MMAFrag size
+  STEEL_CONST short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  // Warp tile simdgroup matrix strides along M
+  STEEL_CONST short TM_stride = kFragSize * WM;
+  // Warp tile simdgroup matrix strides along N
+  STEEL_CONST short TN_stride = kFragSize * WN;
+
+  // Warp tile size along M
+  STEEL_CONST short TM = BM / TM_stride;
+  // Warp tile size along N
+  STEEL_CONST short TN = BN / TN_stride;
+
+  // Threadgroup A strides
+  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
+  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
+
+  // Threadgroup B strides
+  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
+  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
+
+  // Threadgroup strides along K
+  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
+  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
+
+  // Simdgroup matrices
+  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
+  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
+
+  // Offsets within threadgroup
+  short sm;
+  short sn;
+
+  short As_offset;
+  short Bs_offset;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
+    // Determine thread position in simdgroup matrix
+    short tm = kFragSize * (simd_group_id / WN);
+    short tn = kFragSize * (simd_group_id % WN);
+
+    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+    sm = simd_coord.y;
+    sn = simd_coord.x;
+
+    // Determine thread and simdgroup offset
+    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
+    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
+
+    sm += tm;
+    sn += tn;
+  }
+
+  /* (BM, BK) X (BK, BN) multiply accumulate function */
+  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
+    // Adjust for simdgroup and thread location
+    As += As_offset;
+    Bs += Bs_offset;
+
+    // Iterate over BK in blocks of kFragSize
+    STEEL_PRAGMA_UNROLL
+    for (short kk = 0; kk < BK; kk += kFragSize) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      tile_matmad(Ctile, Atile, Btile, Ctile);
+
+      // Progress to next simdgroup tile
+      As += tile_stride_a;
+      Bs += tile_stride_b;
+    }
+  }
+
+  /* Store results from simdgroup_matrix into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) {
+    // Apply epilogue
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
+    }
+
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+
+    Ctile.template store<U, WM, WN>(D, ldd);
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
+    // Apply epilogue
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
+    }
+
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = Ctile.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
+          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = Ctile.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
+
+        // Read C
+        U c_elems[kelems] = {0};
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          if ((j * TN_stride + k) < dst_tile_dims.x) {
+            c_elems[k] = C[offset_c + k * fdc];
+          }
+        }
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix into device memory */
+  METAL_FUNC void store_result(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+
+    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread const auto& accum = Ctile.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void store_result_safe(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in C
+          thread const auto& accum = Ctile.frag_at(i, j);
+          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < kelems; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[offset_d + k] =
+                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+} // namespace steel
+} // namespace mlx
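
For orientation, below is a minimal sketch of how these helpers compose into a complete kernel: a toy single-threadgroup GEMM that stages row-major operands into threadgroup memory and drives BlockMMA. It is illustrative only and not part of the packaged sources; the kernel name, buffer layout, and tile configuration (BM = BN = 32, BK = 16, WM = WN = 2, i.e. a 128-thread dispatch) are assumptions made for the sketch.

// Hypothetical usage sketch (not in the wheel): one threadgroup computes
// D = A @ B using MMATile/BlockMMA from the header above.
#include "mlx/backend/metal/kernels/steel/attn/mma.h"

using namespace mlx::steel;

// Assumed tile configuration: WM * WN = 4 simdgroups of 32 threads each.
enum { BM = 32, BN = 32, BK = 16, WM = 2, WN = 2 };

[[kernel]] void tile_gemm_demo(
    const device float* A [[buffer(0)]], // BM x BK, row-major
    const device float* B [[buffer(1)]], // BK x BN, row-major
    device float* D [[buffer(2)]],       // BM x BN, row-major
    uint lid [[thread_index_in_threadgroup]],
    ushort simd_gid [[simdgroup_index_in_threadgroup]],
    ushort simd_lid [[thread_index_in_simdgroup]]) {
  threadgroup float As[BM * BK];
  threadgroup float Bs[BK * BN];

  // Stage both operands into threadgroup memory with all 128 threads.
  for (uint i = lid; i < uint(BM * BK); i += WM * WN * 32) {
    As[i] = A[i];
  }
  for (uint i = lid; i < uint(BK * BN); i += WM * WN * 32) {
    Bs[i] = B[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Row-major operands: threadgroup leading dims are BK for A, BN for B.
  BlockMMA<float, float, BM, BN, BK, WM, WN, false, false, BK, BN> mma_op(
      simd_gid, simd_lid);
  mma_op.Ctile.clear(); // val_frags is not zero-initialized by default
  mma_op.mma(As, Bs); // accumulate over BK in steps of the 8-wide fragment
  mma_op.store_result(D, BN); // identity epilogue (TransformNone), then store
}

Dispatched with a threadgroup size of 128, each of the four simdgroups accumulates TM x TN = 2 x 2 fragments of 8 x 8, interleaved across the output tile with stride kFragSize * WM, which is exactly the decomposition BlockMMA's constructor derives from simd_group_id and simd_lane_id.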