mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h
@@ -0,0 +1,266 @@
+ // Copyright © 2025 Apple Inc.
+
+ using namespace mlx::steel;
+
+ constant bool segments_contiguous [[function_constant(199)]];
+ constant bool align_M [[function_constant(200)]];
+ constant bool align_N [[function_constant(201)]];
+
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     const device uint32_t* segments [[buffer(2)]],
+     device T* C [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]]) {
+   using gemm_kernel = GEMMKernel<
+       T,
+       T,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       true,
+       true,
+       AccumType>;
+
+   using loader_a_t = typename gemm_kernel::loader_a_t;
+   using loader_b_t = typename gemm_kernel::loader_b_t;
+   using mma_t = typename gemm_kernel::mma_t;
+
+   if (params->tiles_n <= static_cast<int>(tid.x) ||
+       params->tiles_m <= static_cast<int>(tid.y)) {
+     return;
+   }
+
+   // Prepare threadgroup memory
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   // Find the block in A, B, C
+   const int c_row = tid.y * BM;
+   const int c_col = tid.x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   // Prepare threadgroup bounds
+   const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
+   const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
+
+   // Move the pointers to the output tile
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   C += c_row_long * params->ldd + c_col_long;
+
+   // Move the pointers to the start of the segment
+   uint32_t k_start, k_end;
+   if (segments_contiguous) {
+     k_start = segments[2 * tid.z];
+     k_end = segments[2 * tid.z + 1];
+   } else {
+     // Besides contiguous (start, end) pairs (above), we also accept a
+     // boundary layout where the start of the next segment is the end of
+     // the previous one, i.e. the last two strides of segments are both 1.
+     k_start = segments[tid.z];
+     k_end = segments[tid.z + 1];
+   }
+   A += transpose_a ? k_start * params->lda : k_start;
+   B += transpose_b ? k_start : k_start * params->ldb;
+   C += tid.z * params->batch_stride_d;
+
+   // Prepare threadgroup mma operation
+   thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup loading operations
+   thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+   thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+   // Matrix level alignment so only check K
+   if (align_M && align_N) {
+     uint32_t k = k_start + BK;
+     for (; k <= k_end; k += BK) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Load elements into threadgroup
+       loader_a.load_unsafe();
+       loader_b.load_unsafe();
+
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Multiply and accumulate threadgroup elements
+       mma_op.mma(As, Bs);
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+     }
+     short k_remain = BK - short(k - k_end);
+     const short2 tile_dims_A =
+         transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+     const short2 tile_dims_B =
+         transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+     if (k_remain > 0) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+       loader_a.load_safe(tile_dims_A);
+       loader_b.load_safe(tile_dims_B);
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+       mma_op.mma(As, Bs);
+     }
+     mma_op.store_result(C, params->ldd);
+   } else {
+     // Tile aligned do the same as above
+     if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+       uint32_t k = k_start + BK;
+       for (; k <= k_end; k += BK) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Load elements into threadgroup
+         loader_a.load_unsafe();
+         loader_b.load_unsafe();
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+
+         // Prepare for next iteration
+         loader_a.next();
+         loader_b.next();
+       }
+       short k_remain = BK - short(k - k_end);
+       const short2 tile_dims_A =
+           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+       const short2 tile_dims_B =
+           transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+       if (k_remain > 0) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         loader_a.load_safe(tile_dims_A);
+         loader_b.load_safe(tile_dims_B);
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         mma_op.mma(As, Bs);
+       }
+       mma_op.store_result(C, params->ldd);
+     }
+
+     // Tile partially aligned check rows
+     else if (align_N || tgp_bn == BN) {
+       uint32_t k = k_start + BK;
+       for (; k <= k_end; k += BK) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Load elements into threadgroup
+         loader_a.load_safe(
+             transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
+         loader_b.load_unsafe();
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+
+         // Prepare for next iteration
+         loader_a.next();
+         loader_b.next();
+       }
+       short k_remain = BK - short(k - k_end);
+       const short2 tile_dims_A =
+           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+       const short2 tile_dims_B =
+           transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+       if (k_remain > 0) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         loader_a.load_safe(tile_dims_A);
+         loader_b.load_safe(tile_dims_B);
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         mma_op.mma(As, Bs);
+       }
+       mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+
+     // Tile partially aligned check cols
+     else if (align_M || tgp_bm == BM) {
+       uint32_t k = k_start + BK;
+       for (; k <= k_end; k += BK) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Load elements into threadgroup
+         loader_a.load_unsafe();
+         loader_b.load_safe(
+             transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+
+         // Prepare for next iteration
+         loader_a.next();
+         loader_b.next();
+       }
+       short k_remain = BK - short(k - k_end);
+       const short2 tile_dims_A =
+           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+       const short2 tile_dims_B =
+           transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+       if (k_remain > 0) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         loader_a.load_safe(tile_dims_A);
+         loader_b.load_safe(tile_dims_B);
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         mma_op.mma(As, Bs);
+       }
+       mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+
+     // Nothing aligned so check both rows and cols
+     else {
+       uint32_t k = k_start + BK;
+       for (; k <= k_end; k += BK) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Load elements into threadgroup
+         loader_a.load_safe(
+             transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
+         loader_b.load_safe(
+             transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+
+         // Prepare for next iteration
+         loader_a.next();
+         loader_b.next();
+       }
+       short k_remain = BK - short(k - k_end);
+       const short2 tile_dims_A =
+           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+       const short2 tile_dims_B =
+           transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+       if (k_remain > 0) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         loader_a.load_safe(tile_dims_A);
+         loader_b.load_safe(tile_dims_B);
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+         mma_op.mma(As, Bs);
+       }
+       mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+   }
+ }
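
Editor's note: the segmented_mm kernel above reads each segment's K-range from the segments buffer in one of two layouts, selected by the segments_contiguous function constant: explicit (start, end) pairs, or a single boundary array shared between consecutive segments. A minimal host-side C++ sketch of that lookup, using a hypothetical helper name that is not part of MLX:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Mirrors the kernel's segment lookup for batch index z.
// contiguous == true:  segments = [s0, e0, s1, e1, ...] (start/end pairs).
// contiguous == false: segments = [b0, b1, b2, ...] and segment z spans [b[z], b[z+1]).
std::pair<uint32_t, uint32_t> segment_bounds(
    const std::vector<uint32_t>& segments, size_t z, bool contiguous) {
  if (contiguous) {
    return {segments[2 * z], segments[2 * z + 1]};
  }
  return {segments[z], segments[z + 1]};
}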
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h
@@ -0,0 +1,227 @@
+ // Copyright © 2024 Apple Inc.
+
+ using namespace mlx::steel;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // GEMM kernels
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename T,
+     typename U,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     bool MN_aligned,
+     bool K_aligned>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     device U* C [[buffer(2)]],
+     const constant GEMMSpiltKParams* params [[buffer(3)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]]) {
+   (void)lid;
+
+   using gemm_kernel = GEMMKernel<
+       T,
+       U,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       MN_aligned,
+       K_aligned>;
+   using loader_a_t = typename gemm_kernel::loader_a_t;
+   using loader_b_t = typename gemm_kernel::loader_b_t;
+   using mma_t = typename gemm_kernel::mma_t;
+
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   const int tid_x = tid.x;
+   const int tid_y = tid.y;
+   const int tid_z = tid.z;
+
+   if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+     return;
+   }
+
+   // Find block in A, B, C
+   const int c_row = tid_y * BM;
+   const int c_col = tid_x * BN;
+   const int k_start = params->split_k_partition_size * tid_z;
+
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+   const size_t k_start_long = size_t(k_start);
+
+   A += transpose_a ? (c_row_long + k_start_long * params->lda)
+                    : (k_start_long + c_row_long * params->lda);
+   B += transpose_b ? (k_start_long + c_col_long * params->ldb)
+                    : (c_col_long + k_start_long * params->ldb);
+   C += (size_t(params->split_k_partition_stride) * tid_z) +
+       (c_row_long * params->ldc + c_col_long);
+
+   // Prepare threadgroup loading operations
+   thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+   thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup mma operation
+   thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+   int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+   short tgp_bm = min(BM, params->M - c_row);
+   short tgp_bn = min(BN, params->N - c_col);
+   short leftover_bk = params->K % BK;
+
+   if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+     gemm_kernel::gemm_loop(
+         As,
+         Bs,
+         gemm_k_iterations,
+         loader_a,
+         loader_b,
+         mma_op,
+         tgp_bm,
+         tgp_bn,
+         leftover_bk,
+         LoopAlignment<true, true, true>{});
+   } else if (tgp_bn == BN) {
+     gemm_kernel::gemm_loop(
+         As,
+         Bs,
+         gemm_k_iterations,
+         loader_a,
+         loader_b,
+         mma_op,
+         tgp_bm,
+         tgp_bn,
+         leftover_bk,
+         LoopAlignment<false, true, true>{});
+   } else if (tgp_bm == BM) {
+     gemm_kernel::gemm_loop(
+         As,
+         Bs,
+         gemm_k_iterations,
+         loader_a,
+         loader_b,
+         mma_op,
+         tgp_bm,
+         tgp_bn,
+         leftover_bk,
+         LoopAlignment<true, false, true>{});
+   } else {
+     gemm_kernel::gemm_loop(
+         As,
+         Bs,
+         gemm_k_iterations,
+         loader_a,
+         loader_b,
+         mma_op,
+         tgp_bm,
+         tgp_bn,
+         leftover_bk,
+         LoopAlignment<false, false, true>{});
+   }
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   if ((tid_z + 1) == (params->split_k_partitions)) {
+     int gemm_k_iter_remaining =
+         (params->K - (k_start + params->split_k_partition_size)) / BK;
+     if (!K_aligned || gemm_k_iter_remaining > 0)
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iter_remaining,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           leftover_bk,
+           LoopAlignment<false, false, K_aligned>{});
+   }
+
+   if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+     mma_op.store_result(C, params->ldc);
+   } else {
+     mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+   }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Split k accumulation kernel
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename AccT,
+     typename OutT,
+     typename Epilogue = TransformNone<OutT, AccT>>
+ [[kernel]] void gemm_splitk_accum(
+     const device AccT* C_split [[buffer(0)]],
+     device OutT* D [[buffer(1)]],
+     const constant int& k_partitions [[buffer(2)]],
+     const constant int& partition_stride [[buffer(3)]],
+     const constant int& ldd [[buffer(4)]],
+     uint2 gid [[thread_position_in_grid]]) {
+   // Adjust D and C
+   D += gid.x + gid.y * size_t(ldd);
+   C_split += gid.x + gid.y * size_t(ldd);
+
+   size_t offset = 0;
+   AccT out = 0;
+
+   for (int i = 0; i < k_partitions; i++) {
+     out += C_split[offset];
+     offset += partition_stride;
+   }
+
+   // Write output
+   D[0] = Epilogue::apply(out);
+ }
+
+ template <
+     typename AccT,
+     typename OutT,
+     typename Epilogue = TransformAxpby<OutT, AccT>>
+ [[kernel]] void gemm_splitk_accum_axpby(
+     const device AccT* C_split [[buffer(0)]],
+     device OutT* D [[buffer(1)]],
+     const constant int& k_partitions [[buffer(2)]],
+     const constant int& partition_stride [[buffer(3)]],
+     const constant int& ldd [[buffer(4)]],
+     const device OutT* C [[buffer(5)]],
+     const constant int& ldc [[buffer(6)]],
+     const constant int& fdc [[buffer(7)]],
+     const constant float& alpha [[buffer(8)]],
+     const constant float& beta [[buffer(9)]],
+     uint2 gid [[thread_position_in_grid]]) {
+   // Adjust D and C
+   C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
+   D += gid.x + gid.y * size_t(ldd);
+   C_split += gid.x + gid.y * size_t(ldd);
+
+   size_t offset = 0;
+   AccT out = 0;
+
+   for (int i = 0; i < k_partitions; i++) {
+     out += C_split[offset];
+     offset += partition_stride;
+   }
+
+   // Write output
+   Epilogue op(alpha, beta);
+   D[0] = op.apply(out, *C);
+ }
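
Editor's note: in the split-k scheme above, each threadgroup along the z grid dimension computes a partial GEMM over one K partition and writes it into C_split; gemm_splitk_accum then sums k_partitions partial tiles spaced partition_stride elements apart and applies the epilogue per output element. A plain C++ reference of that accumulation step, using illustrative names rather than the MLX API and assuming an identity epilogue:

#include <cstddef>
#include <vector>

// One output element per (x, y), matching the kernel's per-thread work.
// C_split holds k_partitions partial results laid out back to back, with
// partition_stride elements between consecutive partitions.
void splitk_accum_reference(
    const std::vector<float>& C_split,
    std::vector<float>& D,
    int k_partitions,
    size_t partition_stride,
    size_t ldd,
    size_t rows,
    size_t cols) {
  for (size_t y = 0; y < rows; ++y) {
    for (size_t x = 0; x < cols; ++x) {
      float out = 0.0f;
      size_t offset = x + y * ldd;
      for (int i = 0; i < k_partitions; ++i) {
        out += C_split[offset];
        offset += partition_stride;
      }
      D[x + y * ldd] = out;  // Identity epilogue; the axpby variant would blend in C.
    }
  }
}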
mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h
@@ -0,0 +1,137 @@
+ // Copyright © 2024 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/metal/kernels/steel/defines.h"
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Loading helper
+ ///////////////////////////////////////////////////////////////////////////////
+
+ namespace mlx {
+ namespace steel {
+
+ template <
+     typename T,
+     short BROWS,
+     short BCOLS,
+     short dst_ld,
+     short reduction_dim,
+     short tgp_size,
+     short alignment = 1,
+     short n_reads = (BCOLS * BROWS) / (tgp_size),
+     short TCOLS = BCOLS / n_reads,
+     short TROWS = tgp_size / TCOLS>
+ struct BlockLoader {
+   STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
+   STEEL_CONST short vec_size = n_reads;
+
+   // Leading dimension for src
+   const int src_ld;
+   const int tile_stride;
+
+   // Thread location indices
+   const short thread_idx;
+   const short bi;
+   const short bj;
+
+   // threadgroup and device memory
+   threadgroup T* dst;
+   const device T* src;
+
+   struct alignas(alignment * sizeof(T)) ReadVector {
+     uint8_t v[sizeof(T) * vec_size];
+   };
+
+   /* Constructor */
+   METAL_FUNC BlockLoader(
+       const device T* src_,
+       const int src_ld_,
+       threadgroup T* dst_,
+       ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+       ushort simd_lane_id [[thread_index_in_simdgroup]])
+       : src_ld(src_ld_),
+         tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
+         thread_idx(simd_group_id * 32 + simd_lane_id),
+         bi(thread_idx / TCOLS),
+         bj(vec_size * (thread_idx % TCOLS)),
+         dst(dst_ + bi * dst_ld + bj),
+         src(src_ + bi * src_ld + bj) {}
+
+   /* Apply operation to threadgroup without bound checking */
+   template <typename UnaryOp>
+   METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < BROWS; i += TROWS) {
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < vec_size; j++) {
+         dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
+       }
+     }
+   }
+
+   /* Load from device memory into threadgroup memory - without bound checking */
+   METAL_FUNC void load_unsafe() const {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < BROWS; i += TROWS) {
+       *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
+           *((const device ReadVector*)(&src[i * src_ld]));
+     }
+   }
+
+   /* Load from device memory into threadgroup memory - with bound checking */
+   METAL_FUNC void load_safe(short2 src_tile_dim) const {
+     src_tile_dim = src_tile_dim - short2(bj, bi);
+
+     // Skip loading if thread has no valid reads
+     if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < BROWS; i += TROWS) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < vec_size; j++) {
+           dst[i * dst_ld + j] = T(0);
+         }
+       }
+       return;
+     }
+
+     // Use fast thread memory for bound checks
+     bool tmp_idx[vec_size];
+     T tmp_val[vec_size];
+
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < BROWS; i += TROWS) {
+       // Make sure tmp_idx only contains valid indices
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < vec_size; j++) {
+         tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
+       }
+
+       // Read valid indices into tmp_val
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < vec_size; j++) {
+         tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
+       }
+
+       // Zero out unneeded values
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < vec_size; j++) {
+         tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
+       }
+
+       // Copy values to threadgroup memory
+       STEEL_PRAGMA_UNROLL
+       for (short j = 0; j < vec_size; j++) {
+         dst[i * dst_ld + j] = tmp_val[j];
+       }
+     }
+   }
+
+   /* Iteration helper */
+   METAL_FUNC void next() {
+     src += tile_stride;
+   }
+ };
+
+ } // namespace steel
+ } // namespace mlx
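
Editor's note: BlockLoader above assigns each of the tgp_size threads a contiguous run of n_reads (= vec_size) elements per tile row; bi and bj locate that run inside the BROWS x BCOLS tile, and tile_stride advances the source pointer to the next tile along the reduction dimension. A standalone C++ sketch of the same index arithmetic, for illustration only:

// Recomputes the per-thread mapping used by BlockLoader's constructor.
struct LoaderMapping {
  short thread_idx; // simd_group_id * 32 + simd_lane_id
  short bi;         // tile row of this thread's first element
  short bj;         // tile column of this thread's first element
};

LoaderMapping map_thread(
    short simd_group_id, short simd_lane_id,
    short BROWS, short BCOLS, short tgp_size) {
  short n_reads = (BCOLS * BROWS) / tgp_size; // elements per thread per row (vec_size)
  short TCOLS = BCOLS / n_reads;              // threads spanning one tile row
  short thread_idx = simd_group_id * 32 + simd_lane_id;
  return {
      thread_idx,
      static_cast<short>(thread_idx / TCOLS),
      static_cast<short>(n_reads * (thread_idx % TCOLS))};
}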