mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
@@ -0,0 +1,459 @@
+ // Copyright © 2024 Apple Inc.
+
+ using namespace mlx::steel;
+
+ constant bool has_batch [[function_constant(10)]];
+ constant bool align_M [[function_constant(200)]];
+ constant bool align_N [[function_constant(201)]];
+ constant bool align_K [[function_constant(202)]];
+
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     const device uint32_t* rhs_indices [[buffer(2)]],
+     device T* C [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]]) {
+   using gemm_kernel = GEMMKernel<
+       T,
+       T,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       true,
+       true,
+       AccumType>;
+
+   using loader_a_t = typename gemm_kernel::loader_a_t;
+   using loader_b_t = typename gemm_kernel::loader_b_t;
+   using mma_t = typename gemm_kernel::mma_t;
+
+   if (params->tiles_n <= static_cast<int>(tid.x) ||
+       params->tiles_m <= static_cast<int>(tid.y)) {
+     return;
+   }
+
+   // Prepare threadgroup memory
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   // Find the block in A, B, C
+   const int c_row = tid.y * BM;
+   const int c_col = tid.x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   // Prepare threadgroup bounds
+   const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
+   const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   C += c_row_long * params->ldd + c_col_long;
+
+   // Do as many matmuls as necessary
+   uint32_t index;
+   short offset;
+   uint32_t index_next = rhs_indices[c_row];
+   short offset_next = 0;
+   int n = 0;
+   while (n < tgp_bm) {
+     n++;
+     offset = offset_next;
+     index = index_next;
+     offset_next = tgp_bm;
+     for (; n < tgp_bm; n++) {
+       if (rhs_indices[c_row + n] != index) {
+         offset_next = n;
+         index_next = rhs_indices[c_row + n];
+         break;
+       }
+     }
+     threadgroup_barrier(mem_flags::mem_none);
+
+     // Prepare threadgroup mma operation
+     thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+     // Prepare threadgroup loading operations
+     thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+     thread loader_b_t loader_b(
+         B + index * params->batch_stride_b,
+         params->ldb,
+         Bs,
+         simd_group_id,
+         simd_lane_id);
+
+     // Prepare iterations
+     const int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+     // Do unaligned K iterations first
+     if (!align_K) {
+       const int k_last = params->gemm_k_iterations_aligned * BK;
+       const int k_remain = params->K - k_last;
+       const size_t k_jump_a =
+           transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
+       const size_t k_jump_b =
+           transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
+
+       // Move loader source ahead to end
+       loader_a.src += k_jump_a;
+       loader_b.src += k_jump_b;
+
+       // Load tile
+       const short2 tile_dims_A =
+           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+       const short2 tile_dims_B =
+           transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+       loader_a.load_safe(tile_dims_A);
+       loader_b.load_safe(tile_dims_B);
+
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Do matmul
+       mma_op.mma(As, Bs);
+
+       // Reset source back to start
+       loader_a.src -= k_jump_a;
+       loader_b.src -= k_jump_b;
+     }
+
+     // Matrix level aligned never check
+     if (align_M && align_N) {
+       for (int k = 0; k < gemm_k_iterations; k++) {
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Load elements into threadgroup
+         loader_a.load_unsafe();
+         loader_b.load_unsafe();
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+
+         // Prepare for next iteration
+         loader_a.next();
+         loader_b.next();
+       }
+
+       // Store results to device memory
+       if (offset_next - offset == BM) {
+         mma_op.store_result(C, params->ldd);
+       } else {
+         mma_op.store_result_slice(
+             C, params->ldd, short2(0, offset), short2(BN, offset_next));
+       }
+     } else {
+       const short lbk = 0;
+
+       // Tile aligned don't check
+       if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+         gemm_kernel::gemm_loop(
+             As,
+             Bs,
+             gemm_k_iterations,
+             loader_a,
+             loader_b,
+             mma_op,
+             tgp_bm,
+             tgp_bn,
+             lbk,
+             LoopAlignment<true, true, true>{});
+         if (offset_next - offset == BM) {
+           mma_op.store_result(C, params->ldd);
+         } else {
+           mma_op.store_result_slice(
+               C, params->ldd, short2(0, offset), short2(BN, offset_next));
+         }
+       }
+
+       // Tile partially aligned check rows
+       else if (align_N || tgp_bn == BN) {
+         gemm_kernel::gemm_loop(
+             As,
+             Bs,
+             gemm_k_iterations,
+             loader_a,
+             loader_b,
+             mma_op,
+             tgp_bm,
+             tgp_bn,
+             lbk,
+             LoopAlignment<false, true, true>{});
+         mma_op.store_result_slice(
+             C, params->ldd, short2(0, offset), short2(BN, offset_next));
+       }
+
+       // Tile partially aligned check cols
+       else if (align_M || tgp_bm == BM) {
+         gemm_kernel::gemm_loop(
+             As,
+             Bs,
+             gemm_k_iterations,
+             loader_a,
+             loader_b,
+             mma_op,
+             tgp_bm,
+             tgp_bn,
+             lbk,
+             LoopAlignment<true, false, true>{});
+         mma_op.store_result_slice(
+             C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
+       }
+
+       // Nothing aligned so check both rows and cols
+       else {
+         gemm_kernel::gemm_loop(
+             As,
+             Bs,
+             gemm_k_iterations,
+             loader_a,
+             loader_b,
+             mma_op,
+             tgp_bm,
+             tgp_bn,
+             lbk,
+             LoopAlignment<false, false, true>{});
+         mma_op.store_result_slice(
+             C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
+       }
+     }
+   }
+ }
+
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     const device uint32_t* lhs_indices [[buffer(2)]],
+     const device uint32_t* rhs_indices [[buffer(3)]],
+     device T* C [[buffer(4)]],
+     const constant GEMMParams* params [[buffer(5)]],
+     const constant int* indices_shape [[buffer(6)]],
+     const constant int64_t* lhs_strides [[buffer(7)]],
+     const constant int64_t* rhs_strides [[buffer(8)]],
+     const constant int& batch_ndim_a [[buffer(9)]],
+     const constant int* batch_shape_a [[buffer(10)]],
+     const constant int64_t* batch_strides_a [[buffer(11)]],
+     const constant int& batch_ndim_b [[buffer(12)]],
+     const constant int* batch_shape_b [[buffer(13)]],
+     const constant int64_t* batch_strides_b [[buffer(14)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]]) {
+   using gemm_kernel = GEMMKernel<
+       T,
+       T,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       true,
+       true,
+       AccumType>;
+
+   using loader_a_t = typename gemm_kernel::loader_a_t;
+   using loader_b_t = typename gemm_kernel::loader_b_t;
+   using mma_t = typename gemm_kernel::mma_t;
+
+   if (params->tiles_n <= static_cast<int>(tid.x) ||
+       params->tiles_m <= static_cast<int>(tid.y)) {
+     return;
+   }
+
+   // Move A and B to the locations pointed by lhs_indices and rhs_indices.
+   uint32_t indx_A, indx_B;
+   if (has_batch) {
+     ulong2 indices_offsets = elem_to_loc_broadcast(
+         tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
+     indx_A = lhs_indices[indices_offsets.x];
+     indx_B = rhs_indices[indices_offsets.y];
+   } else {
+     indx_A = lhs_indices[params->batch_stride_a * tid.z];
+     indx_B = rhs_indices[params->batch_stride_b * tid.z];
+   }
+   A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
+   B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
+   C += params->batch_stride_d * tid.z;
+
+   // Prepare threadgroup memory
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   // Just make sure everybody's finished with the indexing math above.
+   threadgroup_barrier(mem_flags::mem_none);
+
+   // Find block in A, B, C
+   const int c_row = tid.y * BM;
+   const int c_col = tid.x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   C += c_row_long * params->ldd + c_col_long;
+
+   // Prepare threadgroup mma operation
+   thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup loading operations
+   thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+   thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup bounds
+   const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
+   const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
+
+   // Prepare iterations
+   int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+   // Do unaligned K iterations first
+   if (!align_K) {
+     const int k_last = params->gemm_k_iterations_aligned * BK;
+     const int k_remain = params->K - k_last;
+     const size_t k_jump_a =
+         transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
+     const size_t k_jump_b =
+         transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
+
+     // Move loader source ahead to end
+     loader_a.src += k_jump_a;
+     loader_b.src += k_jump_b;
+
+     // Load tile
+     const short2 tile_dims_A =
+         transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+     const short2 tile_dims_B =
+         transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+     loader_a.load_safe(tile_dims_A);
+     loader_b.load_safe(tile_dims_B);
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     // Do matmul
+     mma_op.mma(As, Bs);
+
+     // Reset source back to start
+     loader_a.src -= k_jump_a;
+     loader_b.src -= k_jump_b;
+   }
+
+   // Matrix level aligned never check
+   if (align_M && align_N) {
+     for (int k = 0; k < gemm_k_iterations; k++) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Load elements into threadgroup
+       loader_a.load_unsafe();
+       loader_b.load_unsafe();
+
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Multiply and accumulate threadgroup elements
+       mma_op.mma(As, Bs);
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+     }
+
+     // Store results to device memory
+     mma_op.store_result(C, params->ldd);
+   } else {
+     const short lbk = 0;
+
+     // Tile aligned don't check
+     if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           lbk,
+           LoopAlignment<true, true, true>{});
+       mma_op.store_result(C, params->ldd);
+     }
+
+     // Tile partially aligned check rows
+     else if (align_N || tgp_bn == BN) {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           lbk,
+           LoopAlignment<false, true, true>{});
+       mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+
+     // Tile partially aligned check cols
+     else if (align_M || tgp_bm == BM) {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           lbk,
+           LoopAlignment<true, false, true>{});
+       mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+
+     // Nothing aligned so check both rows and cols
+     else {
+       gemm_kernel::gemm_loop(
+           As,
+           Bs,
+           gemm_k_iterations,
+           loader_a,
+           loader_b,
+           mma_op,
+           tgp_bm,
+           tgp_bn,
+           lbk,
+           LoopAlignment<false, false, true>{});
+       mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+   }
+ }
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h
@@ -0,0 +1,132 @@
+ // Copyright © 2024 Apple Inc.
+
+ using namespace mlx::steel;
+
+ constant bool align_M [[function_constant(200)]];
+ constant bool align_N [[function_constant(201)]];
+ constant bool align_K [[function_constant(202)]];
+
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+ gather_mm_rhs_nax(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     const device uint32_t* rhs_indices [[buffer(2)]],
+     device T* C [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]]) {
+   constexpr short UM = 16;
+   constexpr short UN = 32;
+   constexpr short UK = 16;
+   constexpr short SM = BM / WM;
+   constexpr short SN = BN / WN;
+   constexpr short SK = 32;
+   constexpr short TM = SM / UM;
+   constexpr short TN = SN / UN;
+
+   if (params->tiles_n <= static_cast<int>(tid.x) ||
+       params->tiles_m <= static_cast<int>(tid.y)) {
+     return;
+   }
+
+   // Find the block in A, B, C
+   const int c_row = tid.y * BM;
+   const int c_col = tid.x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   C += c_row_long * params->ldd + c_col_long;
+   rhs_indices += c_row;
+
+   const short tm = SM * (simd_group_id / WN);
+   const short tn = SN * (simd_group_id % WN);
+
+   const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
+   const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
+
+   const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
+   const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
+
+   A += transpose_a ? tm : (tm * params->lda);
+   B += transpose_b ? (tn * params->ldb) : tn;
+   C += tm * params->ldd + tn;
+   rhs_indices += tm;
+
+   // Do as many matmuls as necessary
+   uint32_t index;
+   short offset;
+   uint32_t index_next = rhs_indices[0];
+   short offset_next = 0;
+   int n = 0;
+   while (n < sgp_sm) {
+     n++;
+     offset = offset_next;
+     index = index_next;
+     offset_next = sgp_sm;
+     for (; n < sgp_sm; n++) {
+       if (rhs_indices[n] != index) {
+         offset_next = n;
+         index_next = rhs_indices[n];
+         break;
+       }
+     }
+     threadgroup_barrier(mem_flags::mem_none);
+
+     using DSubTile = NAXSubTile<AccumType, UM, UN>;
+     NAXTile<AccumType, TM, TN, DSubTile> Ctile;
+
+     dispatch_bool(align_K, [&](auto kAlignedK) {
+       dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
+         dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
+           auto do_gemm = gemm_loop<
+               T,
+               SM,
+               SN,
+               SK,
+               BK,
+               transpose_a,
+               transpose_b,
+               kAlignedM.value,
+               kAlignedN.value,
+               kAlignedK.value,
+               UM,
+               UN,
+               UK,
+               AccumType>;
+           Ctile = do_gemm(
+               A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);
+
+           if constexpr (kAlignedN.value) {
+             if (offset_next - offset == SM) {
+               Ctile.store(C, int(params->ldd));
+             } else {
+               Ctile.store_slice(
+                   C,
+                   int(params->ldd),
+                   short2(0, offset),
+                   short2(SN, offset_next));
+             }
+           } else {
+             Ctile.store_slice(
+                 C,
+                 int(params->ldd),
+                 short2(0, offset),
+                 short2(sgp_sn, offset_next));
+           }
+         });
+       });
+     });
+   }
+ }