mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h
@@ -0,0 +1,1146 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <metal_simdgroup>
+#include <metal_simdgroup_matrix>
+#include <metal_stdlib>
+
+#include "mlx/backend/metal/kernels/steel/defines.h"
+#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
+#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// MMA helper
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <typename T, int kFragRows_, int kFragCols_>
+struct BaseMMAFrag {
+  static_assert(
+      kFragRows_ == 8,
+      "Only 8 x 8 fragment matrices are currently supported");
+  static_assert(
+      kFragCols_ == 8,
+      "Only 8 x 8 fragment matrices are currently supported");
+};
+
+template <typename T>
+struct BaseMMAFrag<T, 8, 8> {
+  STEEL_CONST int kFragRows = 8;
+  STEEL_CONST int kFragCols = 8;
+
+  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
+
+  STEEL_CONST int kElemRows = 1;
+  STEEL_CONST int kElemCols = 2;
+
+  static_assert(
+      kElemRows * kElemCols == kElemsPerFrag,
+      "MMAFrag shape is not consistent with MMAFrag size");
+
+  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
+  typedef metal::vec<T, kElemsPerFrag> frag_type;
+
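+  // Each of the 32 lanes in a simdgroup owns a kElemRows x kElemCols (1 x 2)
+  // slice of the 8 x 8 fragment; get_coord maps the lane id to the
+  // (column, row) origin of that slice.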
+  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
+      [[thread_index_in_simdgroup]]) {
+    const short qid = simd_lane_id / 4;
+    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
+    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
+    return short2{fn, fm};
+  }
+
+  template <typename SrcPtrType, typename StrX, typename StrY>
+  METAL_FUNC static constexpr void
+  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
+      }
+    }
+  }
+
+  template <
+      typename SrcPtrType,
+      typename StrX,
+      typename StrY,
+      typename LimX,
+      typename LimY,
+      typename OffX,
+      typename OffY>
+  METAL_FUNC static constexpr void load_safe(
+      thread frag_type& dst,
+      SrcPtrType src,
+      StrX str_x,
+      StrY str_y,
+      LimX lim_x,
+      LimY lim_y,
+      OffX off_x = Int<0>{},
+      OffY off_y = Int<0>{}) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
+          dst[i * kElemCols + j] =
+              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
+        } else {
+          dst[i * kElemCols + j] = T(0);
+        }
+      }
+    }
+  }
+
+  template <typename DstPtrType, typename StrX, typename StrY>
+  METAL_FUNC static constexpr void
+  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
+    using U = pointer_element_t<DstPtrType>;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
+      }
+    }
+  }
+
+  template <
+      typename DstPtrType,
+      typename StrX,
+      typename StrY,
+      typename LimX,
+      typename LimY,
+      typename OffX,
+      typename OffY>
+  METAL_FUNC static constexpr void store_safe(
+      const thread frag_type& src,
+      DstPtrType dst,
+      StrX str_x,
+      StrY str_y,
+      LimX lim_x,
+      LimY lim_y,
+      OffX off_x = Int<0>{},
+      OffY off_y = Int<0>{}) {
+    using U = pointer_element_t<DstPtrType>;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
+          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
+              static_cast<U>(src[i * kElemCols + j]);
+        }
+      }
+    }
+  }
+
+  template <
+      typename DstPtrType,
+      typename StrX,
+      typename StrY,
+      typename StartX,
+      typename StopX,
+      typename StartY,
+      typename StopY,
+      typename OffX,
+      typename OffY>
+  METAL_FUNC static constexpr void store_slice(
+      const thread frag_type& src,
+      DstPtrType dst,
+      StrX str_x,
+      StrY str_y,
+      StartX start_x,
+      StopX stop_x,
+      StartY start_y,
+      StopY stop_y,
+      OffX off_x = Int<0>{},
+      OffY off_y = Int<0>{}) {
+    using U = pointer_element_t<DstPtrType>;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
+            (off_y + j) < stop_y && (off_y + j) >= start_y) {
+          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
+              static_cast<U>(src[i * kElemCols + j]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC static constexpr void mma(
+      thread frag_type& D,
+      thread frag_type& A,
+      thread frag_type& B,
+      thread frag_type& C) {
+    mat_type D_mat;
+    mat_type A_mat;
+    mat_type B_mat;
+    mat_type C_mat;
+
+    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
+    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
+    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
+
+    mma(D_mat, A_mat, B_mat, C_mat);
+
+    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
+  }
+
+  METAL_FUNC static constexpr void mma(
+      thread mat_type& D,
+      thread mat_type& A,
+      thread mat_type& B,
+      thread mat_type& C) {
+    simdgroup_multiply_accumulate(D, A, B, C);
+  }
+};
+
+template <
+    typename T,
+    int kTileRows_,
+    int kTileCols_,
+    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
+struct MMATile {
+  using MMAFrag_t = MMAFrag_;
+  using elem_type = T;
+  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
+  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
+  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
+
+  STEEL_CONST int kTileRows = kTileRows_;
+  STEEL_CONST int kTileCols = kTileCols_;
+
+  STEEL_CONST int kRows = kTileRows * kFragRows;
+  STEEL_CONST int kCols = kTileCols * kFragCols;
+
+  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
+  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
+
+  typedef typename MMAFrag_t::mat_type mat_type;
+  typedef typename MMAFrag_t::frag_type frag_type;
+
+  frag_type val_frags[kNumFrags] = {frag_type(0)};
+
+  METAL_FUNC MMATile() thread {}
+
+  METAL_FUNC constexpr void clear() {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kNumFrags; ++i) {
+      val_frags[i] = frag_type(0);
+    }
+  }
+
+  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
+    return val_frags[i * kTileCols + j];
+  }
+
+  METAL_FUNC constexpr const thread frag_type& frag_at(
+      const short i,
+      const short j) const {
+    return val_frags[i * kTileCols + j];
+  }
+
+  METAL_FUNC mat_type mat_at(const short i, const short j) {
+    mat_type val_mat;
+    STEEL_PRAGMA_UNROLL
+    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
+      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
+    }
+    return val_mat;
+  }
+
+  METAL_FUNC thread elem_type* elems() {
+    return reinterpret_cast<thread elem_type*>(val_frags);
+  }
+
+  METAL_FUNC const thread elem_type* elems() const {
+    return reinterpret_cast<const thread elem_type*>(val_frags);
+  }
+
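+  // In the load/store methods below, w_x and w_y are the number of
+  // simdgroups tiling the row and column dimension: fragments owned by one
+  // simdgroup are interleaved, spaced w_x * kFragRows rows and
+  // w_y * kFragCols columns apart.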
+  template <typename U, int w_x, int w_y, int str_x, int str_y>
+  METAL_FUNC void load(const threadgroup U* src) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::load(
+            frag_at(i, j),
+            &(
+                src[(i * kFragRows) * w_x * str_x +
+                    (j * kFragCols) * w_y * str_y]),
+            Int<str_x>{},
+            Int<str_y>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y, int str_x, int str_y>
+  METAL_FUNC void store(threadgroup U* dst) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store(
+            frag_at(i, j),
+            &(
+                dst[(i * kFragRows) * w_x * str_x +
+                    (j * kFragCols) * w_y * str_y]),
+            Int<str_x>{},
+            Int<str_y>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void load(const device U* src, const int ld) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::load(
+            frag_at(i, j),
+            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
+            ld,
+            Int<1>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void store(device U* dst, const int ld) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store(
+            frag_at(i, j),
+            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
+            ld,
+            Int<1>{});
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void
+  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (int j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::load_safe(
+            frag_at(i, j),
+            src,
+            ld,
+            Int<1>{},
+            src_tile_dims.y,
+            src_tile_dims.x,
+            (i * kFragRows) * w_x,
+            (j * kFragCols) * w_y);
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void
+  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (int j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store_safe(
+            frag_at(i, j),
+            dst,
+            ld,
+            Int<1>{},
+            dst_tile_dims.y,
+            dst_tile_dims.x,
+            (i * kFragRows) * w_x,
+            (j * kFragCols) * w_y);
+      }
+    }
+  }
+
+  template <typename U, int w_x, int w_y>
+  METAL_FUNC void store_slice(
+      device U* dst,
+      const int ld,
+      const short2 start,
+      const short2 stop) const {
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (int j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::store_slice(
+            frag_at(i, j),
+            dst,
+            ld,
+            Int<1>{},
+            start.y,
+            stop.y,
+            start.x,
+            stop.x,
+            (i * kFragRows) * w_x,
+            (j * kFragCols) * w_y);
+      }
+    }
+  }
+};
+
+template <typename T, typename U, int M, int N, int K>
+METAL_FUNC void tile_matmad(
+    thread MMATile<T, M, N>& D,
+    thread MMATile<U, M, K>& A,
+    thread MMATile<U, K, N>& B,
+    thread MMATile<T, M, N>& C) {
+  STEEL_PRAGMA_UNROLL
+  for (short m = 0; m < M; ++m) {
+    STEEL_PRAGMA_UNROLL
+    for (short n = 0; n < N; ++n) {
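+      // Serpentine ordering: odd rows of the tile visit columns in reverse,
+      // which tends to keep the most recently used B fragments in registers.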
+      short n_serp = (m % 2) ? (N - 1 - n) : n;
+      STEEL_PRAGMA_UNROLL
+      for (short k = 0; k < K; ++k) {
+        MMATile<T, M, N>::MMAFrag_t::mma(
+            D.frag_at(m, n_serp),
+            A.frag_at(m, k),
+            B.frag_at(k, n_serp),
+            C.frag_at(m, n_serp));
+      }
+    }
+  }
+}
+
+template <typename InT>
+struct TransformNone<complex64_t, InT> {
+  static METAL_FUNC complex64_t apply(complex64_t x) {
+    return x;
+  }
+  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
+    return x;
+  }
+};
+
+template <
+    typename T,
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    short lda_tgp,
+    short ldb_tgp,
+    typename AccumType = float,
+    typename Epilogue = TransformNone<U, AccumType>>
+struct BlockMMA {
+  // MMAFrag size
+  STEEL_CONST short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  // Warp tile simdgroup matrix strides along M
+  STEEL_CONST short TM_stride = kFragSize * WM;
+  // Warp tile simdgroup matrix strides along N
+  STEEL_CONST short TN_stride = kFragSize * WN;
+
+  // Warp tile size along M
+  STEEL_CONST short TM = BM / (kFragSize * WM);
+  // Warp tile size along N
+  STEEL_CONST short TN = BN / (kFragSize * WN);
+
+  // Threadgroup A strides
+  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
+  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
+
+  // Threadgroup B strides
+  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
+  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
+
+  // Threadgroup strides along K
+  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
+  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
+
+  // Simdgroup matrices
+  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
+  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
+
+  // Offsets within threadgroup
+  short sm;
+  short sn;
+
+  short As_offset;
+  short Bs_offset;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
+    // Determine thread position in simdgroup matrix
+    short tm = kFragSize * (simd_group_id / WN);
+    short tn = kFragSize * (simd_group_id % WN);
+
+    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+    sm = simd_coord.y;
+    sn = simd_coord.x;
+
+    // Determine thread and simdgroup offset
+    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
+    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
+
+    sm += tm;
+    sn += tn;
+  }
+
+  /* (BM, BK) X (BK, BN) multiply accumulate function */
+  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
+    // Adjust for simdgroup and thread location
+    As += As_offset;
+    Bs += Bs_offset;
+
+    // Iterate over BK in blocks of kFragSize
+    STEEL_PRAGMA_UNROLL
+    for (short kk = 0; kk < BK; kk += kFragSize) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      tile_matmad(Ctile, Atile, Btile, Ctile);
+
+      // Progress to next simdgroup tile
+      As += tile_stride_a;
+      Bs += tile_stride_b;
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) {
+    // Apply epilogue
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
+    }
+
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+
+    Ctile.template store<U, WM, WN>(D, ldd);
+  }
+
+  METAL_FUNC void
+  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
+    // Apply epilogue
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
+    }
+
+    D += sm * ldd + sn;
+    start -= short2(sn, sm);
+    stop -= short2(sn, sm);
+
+    // TODO: Check the start as well
+    if (stop.y <= 0 || stop.x <= 0) {
+      return;
+    }
+
+    Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
+    // Apply epilogue
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
+    }
+
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
+      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = Ctile.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
+          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = Ctile.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
+
+        // Read C
+        U c_elems[kelems] = {0};
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          if ((j * TN_stride + k) < dst_tile_dims.x) {
+            c_elems[k] = C[offset_c + k * fdc];
+          }
+        }
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+
+    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread const auto& accum = Ctile.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void store_result_safe(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in C
+          thread const auto& accum = Ctile.frag_at(i, j);
+          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < kelems; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[offset_d + k] =
+                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+template <
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    short lda_tgp,
+    short ldb_tgp,
+    typename AccumType,
+    typename Epilogue>
+struct BlockMMA<
+    complex64_t,
+    U,
+    BM,
+    BN,
+    BK,
+    WM,
+    WN,
+    transpose_a,
+    transpose_b,
+    lda_tgp,
+    ldb_tgp,
+    AccumType,
+    Epilogue> {
+  static_assert(
+      metal::is_same_v<AccumType, float>,
+      "BlockMMA<complex64_t,...> expects float accumulators");
+  static_assert(
+      metal::is_same_v<U, complex64_t>,
+      "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
+  // MMAFrag size
+  STEEL_CONST short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  // Warp tile simdgroup matrix strides along M
+  STEEL_CONST short TM_stride = kFragSize * WM;
+  // Warp tile simdgroup matrix strides along N
+  STEEL_CONST short TN_stride = kFragSize * WN;
+
+  // Warp tile size along M
+  STEEL_CONST short TM = BM / (kFragSize * WM);
+  // Warp tile size along N
+  STEEL_CONST short TN = BN / (kFragSize * WN);
+
+  // Threadgroup A strides
+  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
+  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
+
+  // Threadgroup B strides
+  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
+  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
+
+  // Threadgroup strides along K
+  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
+  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
+
+  // When indexing complex as float[2]
+  STEEL_CONST short A_str_m_f = A_str_m * 2;
+  STEEL_CONST short A_str_k_f = A_str_k * 2;
+  STEEL_CONST short B_str_k_f = B_str_k * 2;
+  STEEL_CONST short B_str_n_f = B_str_n * 2;
+  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
+  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
+
+  // Accumulators (real/imag)
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
+
+  // Offsets within threadgroup
+  short sm, sn;
+  short As_offset, Bs_offset;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
+    // Determine thread position in simdgroup matrix
+    short tm = kFragSize * (simd_group_id / WN);
+    short tn = kFragSize * (simd_group_id % WN);
+
+    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+    sm = simd_coord.y;
+    sn = simd_coord.x;
+
+    // Determine thread and simdgroup offset
+    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
+    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
+
+    sm += tm;
+    sn += tn;
+  }
+
+  /* Karatsuba MMA: 3 real MMAs per K-chunk */
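+  // (Ar + i Ai)(Br + i Bi) = (P - Q) + i (R - P - Q)
+  // with P = Ar*Br, Q = Ai*Bi, R = (Ar + Ai)*(Br + Bi).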
+  METAL_FUNC void mma(
+      const threadgroup complex64_t* As,
+      const threadgroup complex64_t* Bs) {
+    // Adjust for simdgroup and thread location
+    As += As_offset;
+    Bs += Bs_offset;
+    threadgroup const float* As_f =
+        reinterpret_cast<threadgroup const float*>(As);
+    threadgroup const float* Bs_f =
+        reinterpret_cast<threadgroup const float*>(Bs);
+
+    // Iterate over BK in blocks of kFragSize
+    STEEL_PRAGMA_UNROLL
+    for (short kk = 0; kk < BK; kk += kFragSize) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
+      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
+      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
+      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
+      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
+      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
+
+      tile_matmad(P, Ar, Br, P);
+      tile_matmad(Q, Ai, Bi, Q);
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
+        Ar.elems()[i] += Ai.elems()[i];
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
+        Br.elems()[i] += Bi.elems()[i];
+
+      tile_matmad(R, Ar, Br, R);
+
+      // C_r += P - Q ; C_i += R - P - Q
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
+        const auto p = P.elems()[i];
+        const auto q = Q.elems()[i];
+        const auto r = R.elems()[i];
+        Ctile_r.elems()[i] += (p - q);
+        Ctile_i.elems()[i] += (r - p - q);
+      }
+
+      // Progress to next simdgroup tile
+      As_f += tile_stride_a_f;
+      Bs_f += tile_stride_b_f;
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) {
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off = (i * TM_stride) * ldd + (j * TN_stride);
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
+    D += sm * ldd + sn;
+    start -= short2(sn, sm);
+    stop -= short2(sn, sm);
+
+    if (stop.y <= 0 || stop.x <= 0)
+      return;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; ++i) {
+      const int row = i * TM_stride;
+      if (row >= start.y && row < stop.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; ++j) {
+          const int off = row * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
+            const int col = j * TN_stride + k;
+            if (col >= start.x && col < stop.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
+    D += sm * ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; j++) {
+          int off = (i * TM_stride) * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
+      complex64_t out = epilogue_op.apply(
+          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
+      Ctile_r.elems()[i] = out.real;
+      Ctile_i.elems()[i] = out.imag;
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          complex64_t out = epilogue_op.apply(
+              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+        complex64_t tmp[kelems];
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          if ((j * TN_stride + k) < dst_tile_dims.x &&
+              (i * TM_stride) < dst_tile_dims.y) {
+            tmp[k] = C[offset_c + k * fdc];
+          } else {
+            tmp[k] = complex64_t(0.0f, 0.0f);
+          }
+        }
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+        int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          D[off_d + k] =
+              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void store_result_safe(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in Cr, Ci
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+          int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < kelems; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off_d + k] = epilogue_op.apply(
+                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+} // namespace steel
+} // namespace mlx
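
Note: the complex64_t specialization above trades the usual four real multiplies per complex product for three (plus extra additions), via the Karatsuba identity annotated in mma(). A minimal host-side C++ sketch of that identity, for reference only (not part of the package; names are illustrative):

    #include <cassert>
    #include <cmath>
    #include <complex>

    // Three real products per complex multiply, mirroring the P/Q/R
    // accumulation in BlockMMA<complex64_t, ...>::mma.
    std::complex<float> karatsuba_mul(std::complex<float> a,
                                      std::complex<float> b) {
      float p = a.real() * b.real();                            // P = Ar*Br
      float q = a.imag() * b.imag();                            // Q = Ai*Bi
      float r = (a.real() + a.imag()) * (b.real() + b.imag());  // R
      return {p - q, r - p - q};                                // (Cr, Ci)
    }

    int main() {
      std::complex<float> a{2.0f, -3.0f}, b{0.5f, 4.0f};
      std::complex<float> want = a * b;  // 13 + 6.5i
      std::complex<float> got = karatsuba_mul(a, b);
      assert(std::fabs(want.real() - got.real()) < 1e-5f);
      assert(std::fabs(want.imag() - got.imag()) < 1e-5f);
      return 0;
    }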