mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h
@@ -0,0 +1,1076 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include <metal_simdgroup>
+ #include <metal_simdgroup_matrix>
+ #include <metal_stdlib>
+
+ #include "mlx/backend/metal/kernels/steel/defines.h"
+ #include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
+
+ #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
+
+ using namespace metal;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // MMA helper
+ ///////////////////////////////////////////////////////////////////////////////
+
+ namespace mlx {
+ namespace steel {
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // NAX Steel with new tiles
+ ///////////////////////////////////////////////////////////////////////////////
+
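+ // BaseNAXFrag describes how one 16x16 MMA fragment is distributed over the
+ // 32 lanes of a simdgroup: each lane holds kElemsPerFrag = (16 * 16) / 32 = 8
+ // elements, laid out as kElemRows = 2 rows of kElemCols = 4 contiguous
+ // columns, with the two rows kElemRowsJump = 8 apart.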
+ struct BaseNAXFrag {
+   STEEL_CONST short kFragRows = 16;
+   STEEL_CONST short kFragCols = 16;
+
+   STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;
+
+   STEEL_CONST short kElemRows = 2;
+   STEEL_CONST short kElemCols = 4;
+
+   STEEL_CONST short kElemRowsJump = 8;
+
+   static_assert(
+       kElemRows * kElemCols == kElemsPerFrag,
+       "MMAFrag shape is not consistent with MMAFrag size");
+
+   template <typename U>
+   using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
+
+   METAL_FUNC static short2 get_coord() {
+     const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
+     const short qid = simd_lane_id >> 2;
+     const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));
+     const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;
+     return short2{fn, fm};
+   }
+
+   METAL_FUNC static short2 get_coord(short idx) {
+     const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
+     const short qid = simd_lane_id >> 2;
+     const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;
+     const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;
+     return short2{fn, fm};
+   }
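+   // Worked example for get_coord(): lane 5 has qid = 5 >> 2 = 1, so
+   // fm = ((1 & 4) | ((5 >> 1) & 3)) = 2 and fn = ((1 & 2) | (5 & 1)) * 4 = 4,
+   // i.e. that lane's first element sits at (row 2, column 4) of the fragment.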
+
+   template <
+       typename T,
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void load(
+       thread dtype_frag_t<T>& dst,
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     const short2 sc = short2{0, 0}; // get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if constexpr (metal::is_same_v<StrY, Int<1>>) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);
+         }
+       } else {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[i * kElemCols + j] =
+               static_cast<T>(src[r * str_x + (c + j) * str_y]);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void load_rows(
+       thread dtype_frag_t<T>& dst,
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     const short2 sc = short2{0, 0}; // get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if (r < lim_x) {
+         if constexpr (metal::is_same_v<StrY, Int<1>>) {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);
+           }
+         } else {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[i * kElemCols + j] =
+                 static_cast<T>(src[r * str_x + (c + j) * str_y]);
+           }
+         }
+
+       } else {
+         dst = dtype_frag_t<T>(0);
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void load_safe(
+       thread dtype_frag_t<T>& dst,
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     const short2 sc = short2{0, 0}; // get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kElemCols; j++) {
+         if (r < lim_x && (c + j) < lim_y) {
+           dst[i * kElemCols + j] =
+               static_cast<T>(src[r * str_x + (c + j) * str_y]);
+         } else {
+           dst[i * kElemCols + j] = T(0);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = short2{0, 0}; // get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if constexpr (metal::is_same_v<StrY, Int<1>>) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
+         }
+       } else {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < kElemCols; j++) {
+           dst[r * str_x + (c + j) * str_y] =
+               static_cast<U>(src[i * kElemCols + j]);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store_rows(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = short2{0, 0}; // get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       if (r < lim_x) {
+         if constexpr (metal::is_same_v<StrY, Int<1>>) {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
+           }
+         } else {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < kElemCols; j++) {
+             dst[r * str_x + (c + j) * str_y] =
+                 static_cast<U>(src[i * kElemCols + j]);
+           }
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store_safe(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = short2{0, 0}; // get_coord();
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       const auto r = off_x + i * kElemRowsJump + sc.y;
+       const auto c = off_y + sc.x;
+
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kElemCols; j++) {
+         if (r < lim_x && (c + j) < lim_y) {
+           dst[r * str_x + (c + j) * str_y] =
+               static_cast<U>(src[i * kElemCols + j]);
+         }
+       }
+     }
+   }
+
+   template <
+       typename T,
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename StartX,
+       typename StopX,
+       typename StartY,
+       typename StopY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC static constexpr void store_slice(
+       const thread dtype_frag_t<T>& src,
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       StartX start_x,
+       StopX stop_x,
+       StartY start_y,
+       StopY stop_y,
+       OffX off_x = Int<0>{},
+       OffY off_y = Int<0>{}) {
+     using U = pointer_element_t<DstPtrType>;
+
+     const short2 sc = short2{0, 0}; // get_coord();
+
+     const_for_loop<0, kElemRows, 1>([&](auto idx_row) {
+       const auto r = off_x + idx_row * Int<kElemRowsJump>{};
+       if (r >= stop_x - sc.y || r < start_x - sc.y) {
+         return;
+       }
+
+       const_for_loop<0, kElemCols, 1>([&](auto idx_col) {
+         const auto c = off_y + idx_col;
+         if (c >= stop_y - sc.x || c < start_y - sc.x) {
+           return;
+         }
+
+         const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;
+         dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =
+             static_cast<U>(src[src_idx]);
+       });
+     });
+   }
+
+   template <typename Op, typename T>
+   METAL_FUNC static constexpr void row_reduce(
+       thread const dtype_frag_t<T>& inp_vals,
+       thread T* reduced_vals) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       T thr_reduce = Op::apply(
+           Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),
+           Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));
+
+       T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
+       qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
+
+       T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
+       sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
+
+       reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);
+     }
+   }
+
+   template <typename Op, typename T>
+   METAL_FUNC static constexpr void row_bin_op(
+       thread dtype_frag_t<T>& inp_vals,
+       thread T* row_vals) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kElemRows; i++) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kElemCols; j++) {
+         inp_vals[i * kElemCols + j] =
+             Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
+       }
+     }
+   }
+ };
+
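+ // NAXSubTile is a kRows x kCols block of registers built from 16x16
+ // BaseNAXFrag fragments; its loads, stores, reductions, and row-wise binary
+ // ops are forwarded fragment by fragment.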
+ template <
+     typename T,
+     short kRows_,
+     short kCols_,
+     typename NAXFrag_ = BaseNAXFrag>
+ struct NAXSubTile {
+   using NAXFrag_t = NAXFrag_;
+   STEEL_CONST short kRows = kRows_;
+   STEEL_CONST short kCols = kCols_;
+
+   STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;
+   STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;
+   STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;
+
+   STEEL_CONST short kSubTileRows = kRows / kFragRows;
+   STEEL_CONST short kSubTileCols = kCols / kFragCols;
+
+   STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols;
+   STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag;
+
+   STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows;
+   STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols;
+
+   STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;
+   STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;
+   STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;
+
+   using frag_type = typename NAXFrag_t::template dtype_frag_t<T>;
+
+   frag_type val_frags[kNumFrags];
+
+   METAL_FUNC constexpr void clear() {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kNumFrags; ++i) {
+       val_frags[i] = frag_type(0);
+     }
+   }
+
+   METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   METAL_FUNC constexpr const thread frag_type& frag_at(
+       const short i,
+       const short j) const {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   template <int i, int j>
+   METAL_FUNC constexpr thread frag_type& frag_at() {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   template <int i, int j>
+   METAL_FUNC constexpr const thread frag_type& frag_at() const {
+     return val_frags[i * kSubTileCols + j];
+   }
+
+   METAL_FUNC thread T* elems() {
+     return reinterpret_cast<thread T*>(val_frags);
+   }
+
+   METAL_FUNC const thread T* elems() const {
+     return reinterpret_cast<const thread T*>(val_frags);
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
+     thread T* vptr = (thread T*)(&vals);
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::template row_reduce<Op>(
+             frag_at(i, j), &vptr[i * kFragThrRows]);
+       }
+     }
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
+     thread T* vptr = (thread T*)(&vals);
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::template row_bin_op<Op>(
+             frag_at(i, j), &vptr[i * kFragThrRows]);
+       }
+     }
+   }
+
+   template <
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void load(
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::load(
+             frag_at(i, j),
+             src,
+             str_x,
+             str_y,
+             off_x + i * kFragRows,
+             off_y + j * kFragCols);
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       OffX off_x = {},
+       OffY off_y = {}) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::store(
+             frag_at(i, j),
+             dst,
+             str_x,
+             str_y,
+             off_x + i * kFragRows,
+             off_y + j * kFragCols);
+       }
+     }
+   }
+
+   template <
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void load_rows(
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::load_rows(
+             frag_at(i, j),
+             src,
+             str_x,
+             str_y,
+             lim_x,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename SrcPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void load_safe(
+       SrcPtrType src,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::load_safe(
+             frag_at(i, j),
+             src,
+             str_x,
+             str_y,
+             lim_x,
+             lim_y,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store_rows(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       OffX off_x = {},
+       OffY off_y = {}) const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::store_rows(
+             frag_at(i, j),
+             dst,
+             str_x,
+             str_y,
+             lim_x,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename LimX,
+       typename LimY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store_safe(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       LimX lim_x,
+       LimY lim_y,
+       OffX off_x = {},
+       OffY off_y = {}) const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kSubTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kSubTileCols; ++j) {
+         NAXFrag_t::store_safe(
+             frag_at(i, j),
+             dst,
+             str_x,
+             str_y,
+             lim_x,
+             lim_y,
+             off_x + (i * kFragRows),
+             off_y + (j * kFragCols));
+       }
+     }
+   }
+
+   template <
+       typename DstPtrType,
+       typename StrX,
+       typename StrY,
+       typename StartX,
+       typename StopX,
+       typename StartY,
+       typename StopY,
+       typename OffX = Int<0>,
+       typename OffY = Int<0>>
+   METAL_FUNC constexpr void store_slice(
+       DstPtrType dst,
+       StrX str_x,
+       StrY str_y,
+       StartX start_x,
+       StopX stop_x,
+       StartY start_y,
+       StopY stop_y,
+       OffX off_x = Int<0>{},
+       OffY off_y = Int<0>{}) const {
+     const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) {
+       const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) {
+         NAXFrag_t::store_slice(
+             frag_at<idx_row.value, idx_col.value>(),
+             dst,
+             str_x,
+             str_y,
+             start_x,
+             stop_x,
+             start_y,
+             stop_y,
+             off_x + idx_row * Int<kFragRows>{},
+             off_y + idx_col * Int<kFragCols>{});
+       });
+     });
+   }
+ };
+
+ template <
+     short RC,
+     short CC,
+     short RA,
+     short CA,
+     short RB,
+     short CB,
+     typename CType,
+     typename AType,
+     typename BType,
+     bool transpose_a,
+     bool transpose_b,
+     typename NAXFrag_t = BaseNAXFrag>
+ METAL_FUNC void subtile_matmad_nax(
+     thread NAXSubTile<CType, RC, CC, NAXFrag_t>& C,
+     thread NAXSubTile<AType, RA, CA, NAXFrag_t>& A,
+     metal::bool_constant<transpose_a>,
+     thread NAXSubTile<BType, RB, CB, NAXFrag_t>& B,
+     metal::bool_constant<transpose_b>) {
+   // Static checks
+   constexpr short FMa = transpose_a ? CA : RA;
+   constexpr short FMc = RC;
+   static_assert(FMa == FMc, "NAX matmul: M dimensions do not match");
+
+   constexpr short FNb = transpose_b ? RB : CB;
+   constexpr short FNc = CC;
+   static_assert(FNb == FNc, "NAX matmul: N dimensions do not match");
+
+   constexpr short FKa = transpose_a ? RA : CA;
+   constexpr short FKb = transpose_b ? CB : RB;
+   static_assert(FKa == FKb, "NAX matmul: K dimensions do not match");
+
+   constexpr short FM = FMc;
+   constexpr short FN = FNc;
+   constexpr short FK = FKa;
+
+   constexpr int TM = FM / 16;
+   constexpr int TN = FN / 16;
+   constexpr int TK = FK / 16;
+
+   constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(
+       FM,
+       FN,
+       FK,
+       transpose_a,
+       transpose_b,
+       true,
+       mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
+
+   mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;
+
+   auto ct_a =
+       gemm_op.template get_left_input_cooperative_tensor<AType, BType, CType>();
+   auto ct_b =
+       gemm_op
+           .template get_right_input_cooperative_tensor<AType, BType, CType>();
+   auto ct_c = gemm_op.template get_destination_cooperative_tensor<
+       decltype(ct_a),
+       decltype(ct_b),
+       CType>();
+
+   STEEL_PRAGMA_UNROLL
+   for (short mm = 0; mm < TM; mm++) {
+     STEEL_PRAGMA_UNROLL
+     for (short kk = 0; kk < TK; kk++) {
+       const short fi = transpose_a ? kk : mm;
+       const short fj = transpose_a ? mm : kk;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < 8; i++) {
+         ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i];
+       }
+     }
+   }
+
+   STEEL_PRAGMA_UNROLL
+   for (short nn = 0; nn < TN; nn++) {
+     STEEL_PRAGMA_UNROLL
+     for (short kk = 0; kk < TK; kk++) {
+       const short fi = transpose_b ? nn : kk;
+       const short fj = transpose_b ? kk : nn;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < 8; i++) {
+         ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i];
+       }
+     }
+   }
+
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < ct_c.get_capacity(); i++) {
+     ct_c[i] = C.elems()[i];
+   }
+
+   gemm_op.run(ct_a, ct_b, ct_c);
+
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < ct_c.get_capacity(); i++) {
+     C.elems()[i] = ct_c[i];
+   }
+ }
+
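+ // Usage sketch for subtile_matmad_nax (illustrative, assuming 16x16 subtiles
+ // in half precision with a float accumulator):
+ //
+ //   NAXSubTile<float, 16, 16> C;
+ //   NAXSubTile<half, 16, 16> A;
+ //   NAXSubTile<half, 16, 16> B;
+ //   C.clear();
+ //   subtile_matmad_nax(
+ //       C, A, metal::bool_constant<false>{}, B, metal::bool_constant<false>{});
+ //   // C now holds C + A * B, accumulated through the MPP matmul2d op.
+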
+ template <typename T, short kTileRows_, short kTileCols_, class NAXSubTile_>
+ struct NAXTile {
+   using NAXSubTile_t = NAXSubTile_;
+   using elem_type = T;
+   STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows;
+   STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols;
+   STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile;
+
+   STEEL_CONST short kTileRows = kTileRows_;
+   STEEL_CONST short kTileCols = kTileCols_;
+
+   STEEL_CONST short kRows = kTileRows * kSubTileRows;
+   STEEL_CONST short kCols = kTileCols * kSubTileCols;
+
+   STEEL_CONST short kSubTiles = kTileRows * kTileCols;
+   STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile;
+
+   STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread;
+   STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread;
+
+   STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread;
+   STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread;
+
+   NAXSubTile_t val_subtiles[kSubTiles];
+
+   METAL_FUNC NAXTile() thread {}
+
+   METAL_FUNC constexpr void clear() {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kSubTiles; ++i) {
+       val_subtiles[i].clear();
+     }
+   }
+
+   METAL_FUNC constexpr thread NAXSubTile_t& subtile_at(
+       const short i,
+       const short j) {
+     return val_subtiles[i * kTileCols + j];
+   }
+
+   METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at(
+       const short i,
+       const short j) const {
+     return val_subtiles[i * kTileCols + j];
+   }
+
+   template <int i, int j>
+   METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const {
+     return val_subtiles[i * kTileCols + j];
+   }
+
+   METAL_FUNC thread elem_type* elems() {
+     return reinterpret_cast<thread elem_type*>(val_subtiles[0].elems());
+   }
+
+   METAL_FUNC const thread elem_type* elems() const {
+     return reinterpret_cast<const thread elem_type*>(val_subtiles[0].elems());
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
+     auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).template row_reduce<Op>(sub_rows[i]);
+       }
+     }
+   }
+
+   template <typename Op>
+   METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
+     auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).template row_bin_op<Op>(sub_rows[i]);
+       }
+     }
+   }
+
+   template <typename U, int str_x, int str_y>
+   METAL_FUNC void load(const threadgroup U* src) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load(
+             src,
+             Int<str_x>{},
+             Int<str_y>{},
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U, int str_x, int str_y>
+   METAL_FUNC void store(threadgroup U* dst) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store(
+             dst,
+             Int<str_x>{},
+             Int<str_y>{},
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void load(const device U* src, const int ld) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load(
+             &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{});
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void store(device U* dst, const int ld) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store(
+             &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{});
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void
+   load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load_safe(
+             src,
+             ld,
+             Int<1>{},
+             src_tile_dims.y,
+             src_tile_dims.x,
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void
+   load_rows(const device U* src, const int ld, const short n_rows) {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).load_rows(
+             &src[(i * kSubTileRows) * ld + (j * kSubTileCols)],
+             ld,
+             Int<1>{},
+             n_rows - i * kSubTileRows);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void
+   store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store_safe(
+             dst,
+             ld,
+             Int<1>{},
+             dst_tile_dims.y,
+             dst_tile_dims.x,
+             i * kSubTileRows,
+             j * kSubTileCols);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows)
+       const {
+     STEEL_PRAGMA_UNROLL
+     for (int i = 0; i < kTileRows; ++i) {
+       STEEL_PRAGMA_UNROLL
+       for (int j = 0; j < kTileCols; ++j) {
+         subtile_at(i, j).store_rows(
+             &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)],
+             ld,
+             Int<1>{},
+             n_rows - i * kSubTileRows);
+       }
+     }
+   }
+
+   template <typename U>
+   METAL_FUNC void store_slice(
+       device U* dst,
+       const int ld,
+       const short2 start,
+       const short2 stop) const {
+     const_for_loop<0, kTileRows, 1>([&](auto idx_row) {
+       const_for_loop<0, kTileCols, 1>([&](auto idx_col) {
+         subtile_at<idx_row.value, idx_col.value>().store_slice(
+             dst,
+             ld,
+             Int<1>{},
+             start.y,
+             stop.y,
+             start.x,
+             stop.x,
+             idx_row * Int<kSubTileRows>{},
+             idx_col * Int<kSubTileCols>{});
+       });
+     });
+   }
+ };
+
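+ // Usage sketch for NAXTile (illustrative): a 32x32 register tile built from
+ // four 16x16 subtiles, staged through device memory with leading dimension ld.
+ //
+ //   using Sub = NAXSubTile<float, 16, 16>;
+ //   NAXTile<float, 2, 2, Sub> tile;
+ //   tile.load(src_ptr, ld);   // device float* src_ptr, assumed in bounds
+ //   tile.store(dst_ptr, ld);
+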
+ template <
+     class CTile,
+     class ATile,
+     class BTile,
+     bool transpose_a,
+     bool transpose_b>
+ METAL_FUNC void tile_matmad_nax(
+     thread CTile& C,
+     thread ATile& A,
+     metal::bool_constant<transpose_a>,
+     thread BTile& B,
+     metal::bool_constant<transpose_b>) {
+   // Static checks
+   constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;
+   constexpr short TMc = CTile::kTileRows;
+   static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match");
+
+   constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows;
+   constexpr short FMc = CTile::kSubTileRows;
+   static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match");
+
+   constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;
+   constexpr short TNc = CTile::kTileCols;
+   static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match");
+
+   constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols;
+   constexpr short FNc = CTile::kSubTileCols;
+   static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match");
+
+   constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;
+   constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows;
+   static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match");
+
+   constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols;
+   constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows;
+   static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match");
+
+   constexpr short TM = TMc;
+   constexpr short TN = TNc;
+   constexpr short TK = TKa;
+
+   // Do matmul here
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < TM; ++i) {
+     STEEL_PRAGMA_UNROLL
+     for (short j = 0; j < TN; ++j) {
+       STEEL_PRAGMA_UNROLL
+       for (short k = 0; k < TK; ++k) {
+         const short ra = transpose_a ? k : i;
+         const short ca = transpose_a ? i : k;
+         const short rb = transpose_b ? j : k;
+         const short cb = transpose_b ? k : j;
+
+         subtile_matmad_nax(
+             C.subtile_at(i, j),
+             A.subtile_at(ra, ca),
+             metal::bool_constant<transpose_a>{},
+             B.subtile_at(rb, cb),
+             metal::bool_constant<transpose_b>{});
+       }
+     }
+   }
+ }
+
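+ // Putting it together (illustrative sketch): with the tiles above, one
+ // simdgroup can compute a 32x32 output block as C += A * B, e.g.
+ //
+ //   A_tile.load(A_ptr, lda);
+ //   B_tile.load(B_ptr, ldb);
+ //   tile_matmad_nax(
+ //       C_tile, A_tile, metal::bool_constant<false>{},
+ //       B_tile, metal::bool_constant<false>{});
+ //
+ // where A_ptr, B_ptr, lda, and ldb are assumed device pointers and leading
+ // dimensions supplied by the surrounding kernel.
+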
+ } // namespace steel
+ } // namespace mlx