mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h
@@ -0,0 +1,1084 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include <metal_simdgroup>
+ #include <metal_simdgroup_matrix>
+ #include <metal_stdlib>
+
+ #include "mlx/backend/metal/kernels/steel/defines.h"
+ #include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
+ #include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
+
+ #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
+
+ using namespace metal;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // MMA helper
+ ///////////////////////////////////////////////////////////////////////////////
+
+ namespace mlx {
+ namespace steel {
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // NAX Steel with new tiles
+ ///////////////////////////////////////////////////////////////////////////////
+
+ struct BaseNAXFrag {
+   STEEL_CONST short kFragRows = 16;
+   STEEL_CONST short kFragCols = 16;
+
+   STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;
+
+   STEEL_CONST short kElemRows = 2;
+   STEEL_CONST short kElemCols = 4;
+
+   STEEL_CONST short kElemRowsJump = 8;
+
+   static_assert(
+       kElemRows * kElemCols == kElemsPerFrag,
+       "MMAFrag shape is not consistent with MMAFrag size");
+
+   template <typename U>
+   using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
+
+   METAL_FUNC static short2 get_coord() {
+     const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
+     const short qid = simd_lane_id >> 2;
+     const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));
+     const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;
+     return short2{fn, fm};
+   }
+
+   METAL_FUNC static short2 get_coord(short idx) {
+     const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
+     const short qid = simd_lane_id >> 2;
+     const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;
+     const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;
+     return short2{fn, fm};
+   }
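
[Editor's note] The two get_coord() overloads above encode how a 16x16 NAX fragment is scattered over the 32 lanes of a simdgroup: each lane owns kElemRows x kElemCols = 2x4 elements, with its two rows kElemRowsJump = 8 apart. A host-side C++ sanity check (editorial, not part of the package; the bit arithmetic is copied from get_coord(short idx)) confirms that 32 lanes x 8 elements tile the fragment exactly once:

    #include <cassert>
    #include <cstdio>

    int main() {
      bool seen[16][16] = {};
      for (int lane = 0; lane < 32; ++lane) {
        const int qid = lane >> 2;
        for (int idx = 0; idx < 8; ++idx) {
          // Same bit manipulation as BaseNAXFrag::get_coord(short idx).
          const int fm = ((qid & 4) | ((lane >> 1) & 3)) + (idx >> 2) * 8; // row
          const int fn = ((qid & 2) | (lane & 1)) * 4 + idx % 4;          // column
          assert(!seen[fm][fn]);
          seen[fm][fn] = true;
        }
      }
      std::printf("all 256 fragment elements covered exactly once\n");
    }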
+
+   template <
+       typename T,
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void load(
+       thread dtype_frag_t<T>& dst,
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     const short2 sc = get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if constexpr (metal::is_same_v<StrY, Int<1>>) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);
+         }
+       } else {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[i * kElemCols + j] =
+               static_cast<T>(src[r * str_x + (c + j) * str_y]);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void load_rows(
+       thread dtype_frag_t<T>& dst,
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     const short2 sc = get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if (r < lim_x) {
+         if constexpr (metal::is_same_v<StrY, Int<1>>) {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);
+           }
+         } else {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[i * kElemCols + j] =
+                 static_cast<T>(src[r * str_x + (c + j) * str_y]);
+           }
+         }
+
+       } else {
+         dst = dtype_frag_t<T>(0);
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void load_safe(
+       thread dtype_frag_t<T>& dst,
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     const short2 sc = get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kElemCols; j++) {
+         if (r < lim_x && (c + j) < lim_y) {
+           dst[i * kElemCols + j] =
+               static_cast<T>(src[r * str_x + (c + j) * str_y]);
+         } else {
+           dst[i * kElemCols + j] = T(0);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if constexpr (metal::is_same_v<StrY, Int<1>>) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
+         }
+       } else {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[r * str_x + (c + j) * str_y] =
+               static_cast<U>(src[i * kElemCols + j]);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store_rows(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if (r < lim_x) {
+         if constexpr (metal::is_same_v<StrY, Int<1>>) {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
+           }
+         } else {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[r * str_x + (c + j) * str_y] =
+                 static_cast<U>(src[i * kElemCols + j]);
+           }
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store_safe(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kElemCols; j++) {
+         if (r < lim_x && (c + j) < lim_y) {
+           dst[r * str_x + (c + j) * str_y] =
+               static_cast<U>(src[i * kElemCols + j]);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename StartX,
+       typename StopX,
+       typename StartY,
+       typename StopY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store_slice(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       StartX start_x,
+       StopX stop_x,
+       StartY start_y,
+       StopY stop_y,
+       OffX off_x = Int<0>{},
+       OffY off_y = Int<0>{}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = get_coord();
+
+     const_for_loop<0, kElemRows, 1>([&](auto idx_row) {
+       const auto r = off_x + idx_row * Int<kElemRowsJump>{};
+       if (r >= stop_x - sc.y || r < start_x - sc.y) {
+         return;
+       }
+
+       const_for_loop<0, kElemCols, 1>([&](auto idx_col) {
+         const auto c = off_y + idx_col;
+         if (c >= stop_y - sc.x || c < start_y - sc.x) {
+           return;
+         }
+
+         const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;
+         dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =
+             static_cast<U>(src[src_idx]);
+       });
+     });
+   }
+
+   template <typename Op, typename T>
+   METAL_FUNC static constexpr void row_reduce(
+       thread const dtype_frag_t<T>& inp_vals,
+       thread T* reduced_vals) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       T thr_reduce = Op::apply(
+           Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),
+           Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));
+
+       T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
+       qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
+
+       T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
+       sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
+
+       reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);
+     }
+   }
+
+   template <typename Op, typename T>
+   METAL_FUNC static constexpr void row_bin_op(
+       thread dtype_frag_t<T>& inp_vals,
+       thread T* row_vals) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kElemCols; j++) {
+         inp_vals[i * kElemCols + j] =
+             Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
+       }
+     }
+   }
+ };
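
[Editor's note] row_reduce() above folds a fragment row in three steps: each lane reduces its own four values, then combines with the lane one bit-0 flip away, then with the lane one bit-3 flip away. Per get_coord(), the four lanes lane ^ {0, 1, 8, 9} hold the same fragment row, so each of them ends up with the full row reduction. A small host-side model of the two simd_shuffle_xor steps (editorial, not part of the package):

    #include <cassert>

    int main() {
      float v[32], s1[32], s2[32];
      for (int lane = 0; lane < 32; ++lane) v[lane] = float(lane * lane % 7);
      // simd_shuffle_xor(x, 1): exchange with lane ^ 1, then combine.
      for (int lane = 0; lane < 32; ++lane) s1[lane] = v[lane] + v[lane ^ 1];
      // simd_shuffle_xor(x, 8): exchange with lane ^ 8, then combine.
      for (int lane = 0; lane < 32; ++lane) s2[lane] = s1[lane] + s1[lane ^ 8];
      // Every lane now holds the sum over its row group {lane, ^1, ^8, ^9}.
      for (int lane = 0; lane < 32; ++lane) {
        assert(s2[lane] == v[lane] + v[lane ^ 1] + v[lane ^ 8] + v[lane ^ 9]);
      }
    }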
+
+ template <
+     typename T,
+     short kRows_,
+     short kCols_,
+     typename NAXFrag_t = BaseNAXFrag>
+ struct NAXSubTile {
+   STEEL_CONST short kRows = kRows_;
+   STEEL_CONST short kCols = kCols_;
+
+   STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;
+   STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;
+   STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;
+
+   STEEL_CONST short kSubTileRows = kRows / kFragRows;
+   STEEL_CONST short kSubTileCols = kCols / kFragCols;
+
+   STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols;
+   STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag;
+
+   STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows;
+   STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols;
+
+   STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;
+   STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;
+   STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;
+
+   using frag_type = typename NAXFrag_t::template dtype_frag_t<T>;
+
+   frag_type val_frags[kNumFrags];
+
+   METAL_FUNC constexpr void clear() {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kNumFrags; ++i) {
+       val_frags[i] = frag_type(0);
+     }
+   }
+
+   METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   METAL_FUNC constexpr const thread frag_type& frag_at(
+       const short i,
+       const short j) const {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   template <int i, int j>
+   METAL_FUNC constexpr thread frag_type& frag_at() {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   template <int i, int j>
+   METAL_FUNC constexpr const thread frag_type& frag_at() const {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   METAL_FUNC thread T* elems() {
+     return reinterpret_cast<thread T*>(val_frags);
+   }
+
+   METAL_FUNC const thread T* elems() const {
+     return reinterpret_cast<const thread T*>(val_frags);
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::template row_reduce<Op>(
+             frag_at(i, j), &vals[i * kFragThrRows]);
+       }
+     }
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::template row_bin_op<Op>(
+             frag_at(i, j), &vals[i * kFragThrRows]);
+       }
+     }
+   }
+
+   template <
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void load(
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::load(
+             frag_at(i, j),
+             src,
+             str_x,
+             str_y,
+             off_x + i * kFragRows,
+             off_y + j * kFragCols);
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::store(
+             frag_at(i, j),
+             dst,
+             str_x,
+             str_y,
+             off_x + i * kFragRows,
+             off_y + j * kFragCols);
+       }
+     }
+   }
+
+   template <
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void load_rows(
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::load_rows(
+             frag_at(i, j),
+             src,
+             str_x,
+             str_y,
+             lim_x,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void load_safe(
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::load_safe(
+             frag_at(i, j),
+             src,
+             str_x,
+             str_y,
+             lim_x,
+             lim_y,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store_safe(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::store_safe(
+             frag_at(i, j),
+             dst,
+             str_x,
+             str_y,
+             lim_x,
+             lim_y,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store_rows(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::store_safe(
+             frag_at(i, j),
+             dst,
+             str_x,
+             str_y,
+             lim_x,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename StartX,
+       typename StopX,
+       typename StartY,
+       typename StopY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store_slice(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       StartX start_x,
+       StopX stop_x,
+       StartY start_y,
+       StopY stop_y,
+       OffX off_x = Int<0>{},
+       OffY off_y = Int<0>{}) const {
+     const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) {
+       const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) {
+         NAXFrag_t::store_slice(
+             frag_at<idx_row.value, idx_col.value>(),
+             dst,
+             str_x,
+             str_y,
+             start_x,
+             stop_x,
+             start_y,
+             stop_y,
+             off_x + idx_row * Int<kFragRows>{},
+             off_y + idx_col * Int<kFragCols>{});
+       });
+     });
+   }
+ };
+
+ template <
+     short RC,
+     short CC,
+     short RA,
+     short CA,
+     short RB,
+     short CB,
+     typename CType,
+     typename AType,
+     typename BType,
+     bool transpose_a,
+     bool transpose_b,
+     typename NAXFrag_t = BaseNAXFrag>
+ METAL_FUNC void subtile_matmad_nax(
+     thread NAXSubTile<CType, RC, CC, NAXFrag_t>& C,
+     thread NAXSubTile<AType, RA, CA, NAXFrag_t>& A,
+     metal::bool_constant<transpose_a>,
+     thread NAXSubTile<BType, RB, CB, NAXFrag_t>& B,
+     metal::bool_constant<transpose_b>) {
+   // Static checks
+   constexpr short FMa = transpose_a ? CA : RA;
+   constexpr short FMc = RC;
+   static_assert(FMa == FMc, "NAX matmul: M dimensions do not match");
+
+   constexpr short FNb = transpose_b ? RB : CB;
+   constexpr short FNc = CC;
+   static_assert(FNb == FNc, "NAX matmul: N dimensions do not match");
+
+   constexpr short FKa = transpose_a ? RA : CA;
+   constexpr short FKb = transpose_b ? CB : RB;
+   static_assert(FKa == FKb, "NAX matmul: N dimensions do not match");
+
+   constexpr short FM = FMc;
+   constexpr short FN = FNc;
+   constexpr short FK = FKa;
+
+   constexpr int TM = FM / 16;
+   constexpr int TN = FN / 16;
+   constexpr int TK = FK / 16;
+
+   // Create Matmul descriptor
+   constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(
+       FM,
+       FN,
+       FK,
+       transpose_a,
+       transpose_b,
+       true,
+       mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
+
+   // Create matmul op
+   mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;
+
+   // Create matmul operands in registers
+   auto ct_a =
+       gemm_op.template get_left_input_cooperative_tensor<AType, BType, CType>();
+   auto ct_b =
+       gemm_op
+           .template get_right_input_cooperative_tensor<AType, BType, CType>();
+
+   // Create matmul output in register
+   auto ct_c = gemm_op.template get_destination_cooperative_tensor<
+       decltype(ct_a),
+       decltype(ct_b),
+       CType>();
+
+   // Load A in to left operand registers
+   STEEL_PRAGMA_UNROLL
+   for (short mm = 0; mm < TM; mm++) {
+     STEEL_PRAGMA_UNROLL
+     for (short kk = 0; kk < TK; kk++) {
+       const short fi = transpose_a ? kk : mm;
+       const short fj = transpose_a ? mm : kk;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < 8; i++) {
+         ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i];
+       }
+     }
+   }
+
+   // Load B into right operand registers
+   STEEL_PRAGMA_UNROLL
+   for (short nn = 0; nn < TN; nn++) {
+     STEEL_PRAGMA_UNROLL
+     for (short kk = 0; kk < TK; kk++) {
+       const short fi = transpose_b ? nn : kk;
+       const short fj = transpose_b ? kk : nn;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < 8; i++) {
+         ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i];
+       }
+     }
+   }
+
+   // Load C into output registers (op handles accumulation)
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < ct_c.get_capacity(); i++) {
+     ct_c[i] = C.elems()[i];
+   }
+
+   // Do matmul
+   gemm_op.run(ct_a, ct_b, ct_c);
+
+   // Copy out results
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < ct_c.get_capacity(); i++) {
+     C.elems()[i] = ct_c[i];
+   }
+ }
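
[Editor's note] subtile_matmad_nax() stages the A/B/C register fragments into MetalPerformancePrimitives cooperative tensors, runs a matmul2d configured for multiply_accumulate, and copies the result back. Numerically it is a fused accumulate, C += op(A) * op(B); a scalar reference for the non-transposed case (host C++, editorial, not part of the package):

    // C (MxN) += A (MxK) * B (KxN): the semantics the mpp op implements.
    template <int M, int N, int K>
    void subtile_matmad_ref(float (&C)[M][N],
                            const float (&A)[M][K],
                            const float (&B)[K][N]) {
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float acc = C[m][n]; // multiply_accumulate: start from existing C
          for (int k = 0; k < K; ++k) {
            acc += A[m][k] * B[k][n];
          }
          C[m][n] = acc;
        }
      }
    }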
+
+ template <typename T, short kTileRows_, short kTileCols_, class NAXSubTile_>
+ struct NAXTile {
+   using NAXSubTile_t = NAXSubTile_;
+   using elem_type = T;
+   STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows;
+   STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols;
+   STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile;
+
+   STEEL_CONST short kTileRows = kTileRows_;
+   STEEL_CONST short kTileCols = kTileCols_;
+
+   STEEL_CONST short kRows = kTileRows * kSubTileRows;
+   STEEL_CONST short kCols = kTileCols * kSubTileCols;
+
+   STEEL_CONST short kSubTiles = kTileRows * kTileCols;
+   STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile;
+
+   STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread;
+   STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread;
+
+   STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread;
+   STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread;
+
+   NAXSubTile_t val_subtiles[kSubTiles];
+
+   METAL_FUNC NAXTile() thread {}
+
+   METAL_FUNC constexpr void clear() {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTiles; ++i) {
+       val_subtiles[i].clear();
+     }
+   }
+
+   METAL_FUNC constexpr thread NAXSubTile_t& subtile_at(
+       const short i,
+       const short j) {
+     return val_subtiles[i * kTileCols + j];
+   }
+
+   METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at(
+       const short i,
+       const short j) const {
+     return val_subtiles[i * kTileCols + j];
+   }
+
+   template <int i, int j>
+   METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const {
+     return val_subtiles[i * kTileCols + j];
+   }
+
+   METAL_FUNC thread elem_type* elems() {
+     return reinterpret_cast<thread elem_type*>(val_subtiles[0].elems());
+   }
+
+   METAL_FUNC const thread elem_type* elems() const {
+     return reinterpret_cast<const thread elem_type*>(val_subtiles[0].elems());
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
+     auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).template row_reduce<Op>(sub_rows[i]);
+       }
+     }
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
+     auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).template row_bin_op<Op>(sub_rows[i]);
+       }
+     }
+   }
+
+   template <typename U, int str_x, int str_y>
+   METAL_FUNC void load(const threadgroup U* src) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load(
+             src,
+             Int<str_x>{},
+             Int<str_y>{},
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U, int str_x, int str_y>
+   METAL_FUNC void store(threadgroup U* dst) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store(
+             dst,
+             Int<str_x>{},
+             Int<str_y>{},
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void load(const device U* src, const int ld) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load(
+             &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{});
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void store(device U* dst, const int ld) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store(
+             &dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{});
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void
+   load_rows(const device U* src, const int ld, const short n_rows) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load_rows(
+             &src[(i * kSubTileRows) * ld + (j * kSubTileCols)],
+             ld,
+             Int<1>{},
+             n_rows - i * kSubTileRows);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void
+   load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load_safe(
+             src,
+             ld,
+             Int<1>{},
+             src_tile_dims.y,
+             src_tile_dims.x,
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows)
+       const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store_rows(
+             &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)],
+             ld,
+             Int<1>{},
+             n_rows - i * kSubTileRows);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void
+   store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store_safe(
+             dst,
+             ld,
+             Int<1>{},
+             dst_tile_dims.y,
+             dst_tile_dims.x,
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void store_slice(
+       device U* dst,
+       const int ld,
+       const short2 start,
+       const short2 stop) const {
+     const_for_loop<0, kTileRows, 1>([&](auto idx_row) {
+       const_for_loop<0, kTileCols, 1>([&](auto idx_col) {
+         subtile_at<idx_row.value, idx_col.value>().store_slice(
+             dst,
+             ld,
+             Int<1>{},
+             start.y,
+             stop.y,
+             start.x,
+             stop.x,
+             idx_row * Int<kSubTileRows>{},
+             idx_col * Int<kSubTileCols>{});
+       });
+     });
+   }
+ };
+
+ template <
+     class CTile,
+     class ATile,
+     class BTile,
+     bool transpose_a,
+     bool transpose_b>
+ METAL_FUNC void tile_matmad_nax(
+     thread CTile& C,
+     thread ATile& A,
+     metal::bool_constant<transpose_a>,
+     thread BTile& B,
+     metal::bool_constant<transpose_b>) {
+   // Static checks
+   constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;
+   constexpr short TMc = CTile::kTileRows;
+   static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match");
+
+   constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows;
+   constexpr short FMc = CTile::kSubTileRows;
+   static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match");
+
+   constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;
+   constexpr short TNc = CTile::kTileCols;
+   static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match");
+
+   constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols;
+   constexpr short FNc = CTile::kSubTileCols;
+   static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match");
+
+   constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;
+   constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows;
+   static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match");
+
+   constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols;
+   constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows;
+   static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match");
+
+   constexpr short TM = TMc;
+   constexpr short TN = TNc;
+   constexpr short TK = TKa;
+
+   // Do matmul here
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < TM; ++i) {
+     STEEL_PRAGMA_UNROLL
+     for (short j = 0; j < TN; ++j) {
+       STEEL_PRAGMA_UNROLL
+       for (short k = 0; k < TK; ++k) {
+         const short ra = transpose_a ? k : i;
+         const short ca = transpose_a ? i : k;
+         const short rb = transpose_b ? j : k;
+         const short cb = transpose_b ? k : j;
+
+         subtile_matmad_nax(
+             C.subtile_at(i, j),
+             A.subtile_at(ra, ca),
+             metal::bool_constant<transpose_a>{},
+             B.subtile_at(rb, cb),
+             metal::bool_constant<transpose_b>{});
+       }
+     }
+   }
+ }
+
+ } // namespace steel
+ } // namespace mlx
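
[Editor's note] A hypothetical kernel fragment (editorial sketch, not from the package) showing how the pieces compose: NAXTile declares the register tiles, the device-memory load/store overloads move aligned blocks, and tile_matmad_nax loops subtile_matmad_nax over the K dimension. A_ptr, B_ptr, C_ptr, the leading dimensions lda/ldb/ldc, and K are assumed kernel arguments; operands are row-major.

    using namespace mlx::steel;

    NAXTile<half, 2, 1, NAXSubTile<half, 16, 16>> A;   // 32x16 operand tile
    NAXTile<half, 1, 2, NAXSubTile<half, 16, 16>> B;   // 16x32 operand tile
    NAXTile<float, 2, 2, NAXSubTile<float, 16, 16>> C; // 32x32 accumulator

    C.clear();
    for (int k = 0; k < K; k += 16) {
      A.load(A_ptr + k, lda);       // next 32x16 panel of A
      B.load(B_ptr + k * ldb, ldb); // next 16x32 panel of B
      tile_matmad_nax(
          C, A, metal::bool_constant<false>{}, B, metal::bool_constant<false>{});
    }
    C.store(C_ptr, ldc); // write the 32x32 result block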