mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
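The wheel ships the full MLX C++ headers under mlx/include/mlx, the compiled libmlx.so, and CMake package files (MLXConfig.cmake, MLXTargets.cmake, extension.cmake), so it can back native extensions as well as the Python API. As a minimal sketch of that C++ surface, assuming include and link paths are wired up via the shipped MLXConfig.cmake (the build wiring here is illustrative, not taken from this diff), the addmm op below is the operation the fused steel GEMM kernels in this diff implement:

    // Hypothetical consumer of the headers listed above; the ops used
    // (mx::ones, mx::zeros, mx::addmm, mx::eval) come from mlx/ops.h
    // and mlx/transforms.h.
    #include <mlx/mlx.h>

    namespace mx = mlx::core;

    int main() {
      mx::array a = mx::ones({4, 8});
      mx::array b = mx::ones({8, 4});
      mx::array c = mx::zeros({4, 4});

      // addmm computes alpha * (a @ b) + beta * c, matching the
      // use_out_source / do_axpby paths in the kernels below.
      mx::array d = mx::addmm(c, a, b, /* alpha = */ 1.0f, /* beta = */ 1.0f);
      mx::eval(d);
      return 0;
    }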
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h
@@ -0,0 +1,346 @@
+ // Copyright © 2024 Apple Inc.
+
+ using namespace mlx::steel;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // GEMM kernels
+ ///////////////////////////////////////////////////////////////////////////////
+
+ constant bool has_batch [[function_constant(10)]];
+
+ constant bool use_out_source [[function_constant(100)]];
+ constant bool do_axpby [[function_constant(110)]];
+
+ constant bool align_M [[function_constant(200)]];
+ constant bool align_N [[function_constant(201)]];
+ constant bool align_K [[function_constant(202)]];
+
+ // clang-format off
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     const device T* C [[buffer(2), function_constant(use_out_source)]],
+     device T* D [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
+     const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
+     const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
+   // Pacifying compiler
+   (void)lid;
+
+   using gemm_kernel = GEMMKernel<
+       T,
+       T,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       true,
+       true,
+       AccumType>;
+
+   using loader_a_t = typename gemm_kernel::loader_a_t;
+   using loader_b_t = typename gemm_kernel::loader_b_t;
+   using mma_t = typename gemm_kernel::mma_t;
+
+   // Find block
+   const int tid_y = ((tid.y) << params->swizzle_log) +
+       ((tid.x) & ((1 << params->swizzle_log) - 1));
+   const int tid_x = (tid.x) >> params->swizzle_log;
+
+   // Exit early if out of bounds
+   if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+     return;
+   }
+
+   // Adjust for batch
+   if (has_batch) {
+     const constant auto* A_bstrides = batch_strides;
+     const constant auto* B_bstrides = batch_strides + params->batch_ndim;
+
+     ulong2 batch_offsets = elem_to_loc_broadcast(
+         tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+
+     A += batch_offsets.x;
+     B += batch_offsets.y;
+
+     if (use_out_source) {
+       const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
+       C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
+     }
+   } else {
+     A += params->batch_stride_a * tid.z;
+     B += params->batch_stride_b * tid.z;
+
+     if (use_out_source) {
+       C += addmm_params->batch_stride_c * tid.z;
+     }
+   }
+
+   D += params->batch_stride_d * tid.z;
+
+   // Prepare threadgroup memory
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   threadgroup_barrier(mem_flags::mem_none);
+
+   // Find block in A, B, C
+   const int c_row = tid_y * BM;
+   const int c_col = tid_x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   D += c_row_long * params->ldd + c_col_long;
+
+   if (use_out_source) {
+     C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
+   }
+
+   // Prepare threadgroup mma operation
+   thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup loading operations
+   thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+   thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup bounds
+   const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
+   const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
+
+   // Prepare iterations
+   int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+   // Do unaligned K iterations first
+   if (!align_K) {
+     const int k_last = params->gemm_k_iterations_aligned * BK;
+     const int k_remain = params->K - k_last;
+     const size_t k_jump_a =
+         transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
+     const size_t k_jump_b =
+         transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
+
+     // Move loader source ahead to end
+     loader_a.src += k_jump_a;
+     loader_b.src += k_jump_b;
+
+     // Load tile
+     const short2 tile_dims_A =
+         transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+     const short2 tile_dims_B =
+         transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+     loader_a.load_safe(tile_dims_A);
+     loader_b.load_safe(tile_dims_B);
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     // Do matmul
+     mma_op.mma(As, Bs);
+
+     // Reset source back to start
+     loader_a.src -= k_jump_a;
+     loader_b.src -= k_jump_b;
+   }
+
+   const TransformAdd<AccumType, AccumType> epilogue_op_add(
+       addmm_params->alpha, addmm_params->beta);
+   const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
+       addmm_params->alpha, addmm_params->beta);
+
+   ///////////////////////////////////////////////////////////////////////////////
+   // MNK aligned loop
+   if (align_M && align_N) {
+     // Do gemm
+     for (int k = 0; k < gemm_k_iterations; k++) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+       // Load elements into threadgroup
+       loader_a.load_unsafe();
+       loader_b.load_unsafe();
+
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Multiply and accumulate threadgroup elements
+       mma_op.mma(As, Bs);
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+     }
+
+     threadgroup_barrier(mem_flags::mem_none);
+
+     // Do epilogue
+     if (use_out_source) {
+       if (do_axpby) {
+         mma_op.apply_epilogue(
+             C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
+       } else {
+         mma_op.apply_epilogue(
+             C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
+       }
+     }
+
+     // Store results to device memory
+     return mma_op.store_result(D, params->ldd);
+
+   }
+   ///////////////////////////////////////////////////////////////////////////////
+   // MN unaligned loop
+   else { // Loop over K - unaligned case
+     const int leftover_bk = 0;
+
+     if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+       // Do gemm
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           leftover_bk,
+           LoopAlignment<true, true, true>{});
+
+       // Do epilogue
+       if (use_out_source) {
+         if (do_axpby) {
+           mma_op.apply_epilogue(
+               C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
+         } else {
+           mma_op.apply_epilogue(
+               C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
+         }
+       }
+
+       // Store results to device memory
+       return mma_op.store_result(D, params->ldd);
+
+     } else if (align_N || tgp_bn == BN) {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           leftover_bk,
+           LoopAlignment<false, true, true>{});
+
+       // Do epilogue
+       if (use_out_source) {
+         if (do_axpby) {
+           mma_op.apply_epilogue_safe(
+               C,
+               addmm_params->ldc,
+               addmm_params->fdc,
+               short2(tgp_bn, tgp_bm),
+               epilogue_op_axpby);
+         } else {
+           mma_op.apply_epilogue_safe(
+               C,
+               addmm_params->ldc,
+               addmm_params->fdc,
+               short2(tgp_bn, tgp_bm),
+               epilogue_op_add);
+         }
+       }
+
+       // Store results to device memory
+       return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
+
+     } else if (align_M || tgp_bm == BM) {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           leftover_bk,
+           LoopAlignment<true, false, true>{});
+
+       // Do epilogue
+       if (use_out_source) {
+         if (do_axpby) {
+           mma_op.apply_epilogue_safe(
+               C,
+               addmm_params->ldc,
+               addmm_params->fdc,
+               short2(tgp_bn, tgp_bm),
+               epilogue_op_axpby);
+         } else {
+           mma_op.apply_epilogue_safe(
+               C,
+               addmm_params->ldc,
+               addmm_params->fdc,
+               short2(tgp_bn, tgp_bm),
+               epilogue_op_add);
+         }
+       }
+
+       // Store results to device memory
+       return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
+
+     } else {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           leftover_bk,
+           LoopAlignment<false, false, true>{});
+
+       // Do epilogue
+       if (use_out_source) {
+         if (do_axpby) {
+           mma_op.apply_epilogue_safe(
+               C,
+               addmm_params->ldc,
+               addmm_params->fdc,
+               short2(tgp_bn, tgp_bm),
+               epilogue_op_axpby);
+         } else {
+           mma_op.apply_epilogue_safe(
+               C,
+               addmm_params->ldc,
+               addmm_params->fdc,
+               short2(tgp_bn, tgp_bm),
+               epilogue_op_add);
+         }
+       }
+
+       // Store results to device memory
+       return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+   }
+ }
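Both kernels start with the same swizzle remap: tid.y is shifted left by swizzle_log and the low bits of tid.x are folded in, so runs of 2^swizzle_log consecutive threadgroups land on vertically adjacent output tiles and reuse the same columns of B. A host-side replica of that index math (plain C++; the helper name is hypothetical, the arithmetic mirrors the kernel's tid_x/tid_y locals) makes the traversal order easy to check:

    #include <cstdio>

    // Replicates the kernel's tile-swizzle arithmetic on the host.
    // swizzle_log corresponds to GEMMParams::swizzle_log.
    void swizzled_tile(int tid_x, int tid_y, int swizzle_log, int* row, int* col) {
      *row = (tid_y << swizzle_log) + (tid_x & ((1 << swizzle_log) - 1));
      *col = tid_x >> swizzle_log;
    }

    int main() {
      // With swizzle_log = 1, threadgroups cover tiles in 2-tall column bands:
      // for tid.y = 0, tid.x = 0,1,2,3 maps to tiles (0,0), (1,0), (0,1), (1,1).
      for (int y = 0; y < 2; y++) {
        for (int x = 0; x < 4; x++) {
          int r, c;
          swizzled_tile(x, y, 1, &r, &c);
          std::printf("tid=(%d,%d) -> tile(row=%d, col=%d)\n", x, y, r, c);
        }
      }
      return 0;
    }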
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h
@@ -0,0 +1,207 @@
+ // Copyright © 2025 Apple Inc.
+
+ using namespace mlx::steel;
+
+ constant bool has_batch [[function_constant(10)]];
+
+ constant bool use_out_source [[function_constant(100)]];
+ constant bool do_axpby [[function_constant(110)]];
+
+ constant bool align_M [[function_constant(200)]];
+ constant bool align_N [[function_constant(201)]];
+ constant bool align_K [[function_constant(202)]];
+
+ // clang-format off
+ template <
+     bool kAlignedM,
+     bool kAlignedN,
+     typename NAXTile_t,
+     typename T>
+ void gemm_epilogue(
+     thread NAXTile_t& Dtile,
+     const device T* C,
+     const constant GEMMParams* params,
+     const constant GEMMAddMMParams* addmm_params,
+     const short sgp_sm,
+     const short sgp_sn) { // clang-format on
+
+   (void)params;
+
+   constexpr short UM = NAXTile_t::kSubTileRows;
+   constexpr short UN = NAXTile_t::kSubTileCols;
+   using CSubTile = NAXSubTile<T, UM, UN>;
+
+   using V = typename NAXTile_t::elem_type;
+
+   constexpr short TM = NAXTile_t::kTileRows;
+   constexpr short TN = NAXTile_t::kTileCols;
+   constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile;
+
+   STEEL_PRAGMA_UNROLL
+   for (short mm = 0; mm < TM; mm++) {
+     STEEL_PRAGMA_UNROLL
+     for (short nn = 0; nn < TN; nn++) {
+       const short m = mm * UM;
+       const short n = nn * UN;
+
+       CSubTile CTile;
+
+       if constexpr (kAlignedM && kAlignedN) {
+         CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n);
+       } else {
+         CTile.load_safe(
+             C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
+       }
+
+       auto delems = Dtile.subtile_at(mm, nn).elems();
+       auto celems = CTile.elems();
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < kElemsPerSubTile; i++) {
+         if (do_axpby) {
+           delems[i] = addmm_params->alpha * delems[i] +
+               addmm_params->beta * static_cast<V>(celems[i]);
+         } else {
+           delems[i] += static_cast<V>(celems[i]);
+         }
+       }
+     }
+   }
+ }
+
+ // clang-format off
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     const device T* C [[buffer(2), function_constant(use_out_source)]],
+     device T* D [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
+     const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
+     const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
+   // Find block
+   const int tid_y = ((tid.y) << params->swizzle_log) +
+       ((tid.x) & ((1 << params->swizzle_log) - 1));
+   const int tid_x = (tid.x) >> params->swizzle_log;
+
+   // Exit early if out of bounds
+   if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+     return;
+   }
+
+   // Adjust for batch
+   if (has_batch) {
+     const constant auto* A_bstrides = batch_strides;
+     const constant auto* B_bstrides = batch_strides + params->batch_ndim;
+
+     ulong2 batch_offsets = elem_to_loc_broadcast(
+         tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+
+     A += batch_offsets.x;
+     B += batch_offsets.y;
+
+     if (use_out_source) {
+       const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
+       C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
+     }
+   } else {
+     A += params->batch_stride_a * tid.z;
+     B += params->batch_stride_b * tid.z;
+
+     if (use_out_source) {
+       C += addmm_params->batch_stride_c * tid.z;
+     }
+   }
+
+   D += params->batch_stride_d * tid.z;
+
+   // Prepare threadgroup memory
+   threadgroup_barrier(mem_flags::mem_none);
+
+   // Find block in A, B, C
+   const int c_row = tid_y * BM;
+   const int c_col = tid_x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   D += c_row_long * params->ldd + c_col_long;
+
+   if (use_out_source) {
+     C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
+   }
+
+   constexpr short UM = 16;
+   constexpr short UN = 32;
+   constexpr short UK = 16;
+   constexpr short SM = BM / WM;
+   constexpr short SN = BN / WN;
+   constexpr short SK = 32;
+
+   constexpr short TM = SM / UM;
+   constexpr short TN = SN / UN;
+
+   const short tm = SM * (simd_group_id / WN);
+   const short tn = SN * (simd_group_id % WN);
+
+   const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
+   const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
+
+   const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
+   const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
+
+   A += transpose_a ? tm : (tm * params->lda);
+   B += transpose_b ? (tn * params->ldb) : tn;
+   D += tm * params->ldd + tn;
+
+   if (use_out_source) {
+     C += tm * addmm_params->ldc + tn * addmm_params->fdc;
+   }
+
+   using DSubTile = NAXSubTile<AccumType, UM, UN>;
+   NAXTile<AccumType, TM, TN, DSubTile> Dtile;
+
+   dispatch_bool(align_K, [&](auto kAlignedK) {
+     dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
+       dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
+         Dtile = gemm_loop<
+             T,
+             SM,
+             SN,
+             SK,
+             BK,
+             transpose_a,
+             transpose_b,
+             kAlignedM.value,
+             kAlignedN.value,
+             kAlignedK.value,
+             UM,
+             UN,
+             UK,
+             AccumType>(A, B, params, sgp_sm, sgp_sn);
+         if (use_out_source) {
+           gemm_epilogue<kAlignedM.value, kAlignedN.value>(
+               Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
+         }
+         if constexpr (kAlignedM && kAlignedN) {
+           Dtile.store(D, int(params->ldd));
+         } else {
+           Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
+         }
+       });
+     });
+   });
+ }
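Where the 2024 kernel above spells out four LoopAlignment branches by hand, the NAX variant folds the same specialization into nested dispatch_bool calls that lift runtime function constants into compile-time template arguments. The helper itself lives in the steel utility headers elsewhere in this wheel, not in this diff; a minimal sketch of the pattern, using std::integral_constant in host C++ for illustration:

    #include <cstdio>
    #include <type_traits>
    #include <utility>

    // Sketch of the dispatch_bool pattern: turn a runtime bool into a
    // std::integral_constant so the callee can branch with `if constexpr`
    // and only instantiate the path it needs.
    template <typename F>
    void dispatch_bool(bool cond, F&& f) {
      if (cond) {
        std::forward<F>(f)(std::true_type{});
      } else {
        std::forward<F>(f)(std::false_type{});
      }
    }

    int main() {
      bool align_m = true, align_n = false;
      // Nested dispatches instantiate one body per (align_M, align_N)
      // combination, mirroring how the kernel above picks aligned vs.
      // bounds-checked tile loads and stores.
      dispatch_bool(align_m, [&](auto kAlignedM) {
        dispatch_bool(align_n, [&](auto kAlignedN) {
          if constexpr (kAlignedM && kAlignedN) {
            std::puts("fast path: unchecked loads/stores");
          } else {
            std::puts("safe path: bounds-checked loads/stores");
          }
        });
      });
      return 0;
    }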