mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,296 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/attn/loader.h"
6
+ #include "mlx/backend/metal/kernels/steel/attn/mma.h"
7
+ #include "mlx/backend/metal/kernels/steel/attn/params.h"
8
+ #include "mlx/backend/metal/kernels/steel/attn/transforms.h"
9
+ #include "mlx/backend/metal/kernels/steel/gemm/params.h"
10
+ #include "mlx/backend/metal/kernels/steel/utils.h"
11
+
12
+ using namespace metal;
13
+
14
+ ///////////////////////////////////////////////////////////////////////////////
15
+ // GEMM kernel class
16
+ ///////////////////////////////////////////////////////////////////////////////
17
+
18
+ namespace mlx {
19
+ namespace steel {
20
+
21
+ template <bool M_aligned, bool N_aligned, bool K_aligned>
22
+ struct LoopAlignment {};
23
+
24
+ template <
25
+ typename T,
26
+ typename U,
27
+ int BM,
28
+ int BN,
29
+ int BK,
30
+ int WM,
31
+ int WN,
32
+ bool transpose_a,
33
+ bool transpose_b,
34
+ bool MN_aligned,
35
+ bool K_aligned,
36
+ typename AccumType = typename AccumHelper<T>::accum_type,
37
+ typename Epilogue = TransformNone<U, AccumType>>
38
+ struct GEMMKernel {
39
+ STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
40
+ STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
41
+ STEEL_CONST short tgp_mem_size_a =
42
+ transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
43
+ STEEL_CONST short tgp_mem_size_b =
44
+ transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
45
+ STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
46
+
47
+ STEEL_CONST short tgp_size = WM * WN * 32;
48
+
49
+ using loader_a_t = BlockLoader<
50
+ T,
51
+ transpose_a ? BK : BM,
52
+ transpose_a ? BM : BK,
53
+ transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
54
+ !transpose_a,
55
+ tgp_size>;
56
+ using loader_b_t = BlockLoader<
57
+ T,
58
+ transpose_b ? BN : BK,
59
+ transpose_b ? BK : BN,
60
+ transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
61
+ transpose_b,
62
+ tgp_size>;
63
+ using mma_t = BlockMMA<
64
+ T,
65
+ U,
66
+ BM,
67
+ BN,
68
+ BK,
69
+ WM,
70
+ WN,
71
+ transpose_a,
72
+ transpose_b,
73
+ transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
74
+ transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
75
+ AccumType,
76
+ Epilogue>;
77
+
78
+ /* Main kernel function */
79
+ template <bool M_aligned, bool N_aligned, bool K_aligned_>
80
+ static METAL_FUNC void gemm_loop(
81
+ threadgroup T* As [[threadgroup(0)]],
82
+ threadgroup T* Bs [[threadgroup(1)]],
83
+ const int gemm_k_iterations,
84
+ thread loader_a_t& loader_a,
85
+ thread loader_b_t& loader_b,
86
+ thread mma_t& mma_op,
87
+ thread const short& tgp_bm,
88
+ thread const short& tgp_bn,
89
+ thread const short& lbk,
90
+ LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
91
+ // Appease the compiler
92
+ (void)l;
93
+
94
+ short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
95
+
96
+ short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
97
+
98
+ for (int k = 0; k < gemm_k_iterations; k++) {
99
+ threadgroup_barrier(mem_flags::mem_threadgroup);
100
+ // Load elements into threadgroup
101
+ if (M_aligned) {
102
+ loader_a.load_unsafe();
103
+ } else {
104
+ loader_a.load_safe(tile_dims_A);
105
+ }
106
+
107
+ if (N_aligned) {
108
+ loader_b.load_unsafe();
109
+ } else {
110
+ loader_b.load_safe(tile_dims_B);
111
+ }
112
+
113
+ threadgroup_barrier(mem_flags::mem_threadgroup);
114
+
115
+ // Multiply and accumulate threadgroup elements
116
+ mma_op.mma(As, Bs);
117
+
118
+ // Prepare for next iteration
119
+ loader_a.next();
120
+ loader_b.next();
121
+ }
122
+
123
+ if (!K_aligned_) {
124
+ threadgroup_barrier(mem_flags::mem_threadgroup);
125
+
126
+ short2 tile_dims_A_last =
127
+ transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
128
+ short2 tile_dims_B_last =
129
+ transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
130
+
131
+ loader_a.load_safe(tile_dims_A_last);
132
+ loader_b.load_safe(tile_dims_B_last);
133
+
134
+ threadgroup_barrier(mem_flags::mem_threadgroup);
135
+
136
+ mma_op.mma(As, Bs);
137
+ }
138
+ }
139
+
140
+ /* Main kernel function */
141
+ static METAL_FUNC void run(
142
+ const device T* A [[buffer(0)]],
143
+ const device T* B [[buffer(1)]],
144
+ device U* D [[buffer(2)]],
145
+ const constant GEMMParams* params [[buffer(3)]],
146
+ threadgroup T* As [[threadgroup(0)]],
147
+ threadgroup T* Bs [[threadgroup(1)]],
148
+ uint simd_lane_id [[thread_index_in_simdgroup]],
149
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
150
+ uint3 tid [[threadgroup_position_in_grid]],
151
+ uint3 lid [[thread_position_in_threadgroup]]) {
152
+ // Pacifying compiler
153
+ (void)lid;
154
+
155
+ const int tid_y = ((tid.y) << params->swizzle_log) +
156
+ ((tid.x) & ((1 << params->swizzle_log) - 1));
157
+ const int tid_x = (tid.x) >> params->swizzle_log;
158
+
159
+ if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
160
+ return;
161
+ }
162
+
163
+ threadgroup_barrier(mem_flags::mem_none);
164
+
165
+ // Find block in A, B, C
166
+ const int c_row = tid_y * BM;
167
+ const int c_col = tid_x * BN;
168
+ const size_t c_row_long = size_t(c_row);
169
+ const size_t c_col_long = size_t(c_col);
170
+
171
+ A += transpose_a ? c_row_long : c_row_long * params->lda;
172
+ B += transpose_b ? c_col_long * params->ldb : c_col_long;
173
+ D += c_row_long * params->ldd + c_col_long;
174
+
175
+ // Prepare threadgroup loading operations
176
+ thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
177
+ thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
178
+
179
+ // Prepare threadgroup mma operation
180
+ thread mma_t mma_op(simd_group_id, simd_lane_id);
181
+
182
+ int gemm_k_iterations = params->gemm_k_iterations_aligned;
183
+
184
+ ///////////////////////////////////////////////////////////////////////////////
185
+ // MNK aligned loop
186
+ if (MN_aligned) {
187
+ for (int k = 0; k < gemm_k_iterations; k++) {
188
+ threadgroup_barrier(mem_flags::mem_threadgroup);
189
+ // Load elements into threadgroup
190
+ loader_a.load_unsafe();
191
+ loader_b.load_unsafe();
192
+
193
+ threadgroup_barrier(mem_flags::mem_threadgroup);
194
+
195
+ // Multiply and accumulate threadgroup elements
196
+ mma_op.mma(As, Bs);
197
+
198
+ // Prepare for next iteration
199
+ loader_a.next();
200
+ loader_b.next();
201
+ }
202
+
203
+ threadgroup_barrier(mem_flags::mem_none);
204
+
205
+ // Loop tail
206
+ if (!K_aligned) {
207
+ int lbk = params->K - params->gemm_k_iterations_aligned * BK;
208
+ short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
209
+ short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
210
+
211
+ loader_a.load_safe(tile_dims_A);
212
+ loader_b.load_safe(tile_dims_B);
213
+
214
+ threadgroup_barrier(mem_flags::mem_threadgroup);
215
+
216
+ mma_op.mma(As, Bs);
217
+ }
218
+
219
+ // Store results to device memory
220
+ mma_op.store_result(D, params->ldd);
221
+ return;
222
+
223
+ }
224
+ ///////////////////////////////////////////////////////////////////////////////
225
+ // MN unaligned loop
226
+ else { // Loop over K - unaligned case
227
+ short tgp_bm = min(BM, params->M - c_row);
228
+ short tgp_bn = min(BN, params->N - c_col);
229
+ short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
230
+
231
+ if (tgp_bm == BM && tgp_bn == BN) {
232
+ gemm_loop<true, true, K_aligned>(
233
+ As,
234
+ Bs,
235
+ gemm_k_iterations,
236
+ loader_a,
237
+ loader_b,
238
+ mma_op,
239
+ tgp_bm,
240
+ tgp_bn,
241
+ leftover_bk);
242
+
243
+ mma_op.store_result(D, params->ldd);
244
+ return;
245
+
246
+ } else if (tgp_bn == BN) {
247
+ gemm_loop<false, true, K_aligned>(
248
+ As,
249
+ Bs,
250
+ gemm_k_iterations,
251
+ loader_a,
252
+ loader_b,
253
+ mma_op,
254
+ tgp_bm,
255
+ tgp_bn,
256
+ leftover_bk);
257
+
258
+ mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
259
+ return;
260
+
261
+ } else if (tgp_bm == BM) {
262
+ gemm_loop<true, false, K_aligned>(
263
+ As,
264
+ Bs,
265
+ gemm_k_iterations,
266
+ loader_a,
267
+ loader_b,
268
+ mma_op,
269
+ tgp_bm,
270
+ tgp_bn,
271
+ leftover_bk);
272
+
273
+ mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
274
+ return;
275
+
276
+ } else {
277
+ gemm_loop<false, false, K_aligned>(
278
+ As,
279
+ Bs,
280
+ gemm_k_iterations,
281
+ loader_a,
282
+ loader_b,
283
+ mma_op,
284
+ tgp_bm,
285
+ tgp_bn,
286
+ leftover_bk);
287
+
288
+ mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
289
+ return;
290
+ }
291
+ }
292
+ }
293
+ };
294
+
295
+ } // namespace steel
296
+ } // namespace mlx