mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h
@@ -0,0 +1,719 @@
+ // Copyright © 2024 Apple Inc.
+
+ #include "mlx/backend/metal/kernels/steel/defines.h"
+ using namespace metal;
+ using namespace mlx::steel;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // GEMM kernels
+ ///////////////////////////////////////////////////////////////////////////////
+
+ struct _NoMask {
+   char x;
+
+   constexpr METAL_FUNC operator bool() {
+     return true;
+   }
+   constexpr METAL_FUNC operator bool() const threadgroup {
+     return true;
+   }
+   constexpr METAL_FUNC operator bool() const device {
+     return true;
+   }
+   constexpr METAL_FUNC operator bool() const constant {
+     return true;
+   }
+ };
+
+ template <typename OutT, typename InT = OutT>
+ struct ScaleOp {
+   OutT scale;
+
+   METAL_FUNC OutT apply(InT x) const {
+     return static_cast<OutT>(x) * scale;
+   }
+ };
+
+ typedef struct _NoMask nomask_t;
+
+ template <
+     typename T,
+     typename out_mask_t,
+     typename op_mask_t,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     bool MN_aligned,
+     bool K_aligned>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+ block_masked_gemm(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     device T* D [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     const constant int* batch_shape [[buffer(6)]],
+     const constant int64_t* batch_strides [[buffer(7)]],
+     const device out_mask_t* out_mask [[buffer(10)]],
+     const device op_mask_t* lhs_mask [[buffer(11)]],
+     const device op_mask_t* rhs_mask [[buffer(12)]],
+     const constant int* mask_strides [[buffer(13)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]]) {
+   // Appease the compiler
+   (void)lid;
+
+   static_assert(
+       BM == BN,
+       "block_masked_gemm must have the same block M and block N size");
+   static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
+
+   constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+   constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+   constexpr bool has_mul_operand_mask =
+       has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
+   constexpr bool has_mul_output_mask =
+       has_output_mask && !metal::is_same_v<out_mask_t, bool>;
+
+   constexpr short k_mask_factor = short(BM / BK);
+
+   using gemm_kernel = GEMMKernel<
+       T,
+       T,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       MN_aligned,
+       K_aligned>;
+
+   const int tid_y = ((tid.y) << params->swizzle_log) +
+       ((tid.x) & ((1 << params->swizzle_log) - 1));
+   const int tid_x = (tid.x) >> params->swizzle_log;
+
+   if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+     return;
+   }
+
+   const constant auto* mask_batch_strides =
+       batch_strides + 2 * params->batch_ndim;
+
+   if (params->batch_ndim > 1) {
+     if (has_output_mask) {
+       out_mask += elem_to_loc(
+           tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
+
+       mask_batch_strides += params->batch_ndim;
+     }
+
+     if (has_operand_mask) {
+       const constant auto* mask_strides_lhs = mask_batch_strides;
+       const constant auto* mask_strides_rhs =
+           mask_strides_lhs + params->batch_ndim;
+
+       ulong2 batch_offsets = elem_to_loc_broadcast(
+           tid.z,
+           batch_shape,
+           mask_strides_lhs,
+           mask_strides_rhs,
+           params->batch_ndim);
+
+       lhs_mask += batch_offsets.x;
+       rhs_mask += batch_offsets.y;
+     }
+   } else {
+     if (has_output_mask) {
+       out_mask += tid.z * mask_batch_strides[0];
+       mask_batch_strides += params->batch_ndim;
+     }
+
+     if (has_operand_mask) {
+       lhs_mask += tid.z * mask_batch_strides[0];
+       rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
+     }
+   }
+
+   // Adjust for batch
+   if (params->batch_ndim > 1) {
+     const constant auto* A_bstrides = batch_strides;
+     const constant auto* B_bstrides = batch_strides + params->batch_ndim;
+
+     ulong2 batch_offsets = elem_to_loc_broadcast(
+         tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+
+     A += batch_offsets.x;
+     B += batch_offsets.y;
+
+   } else {
+     A += params->batch_stride_a * tid.z;
+     B += params->batch_stride_b * tid.z;
+   }
+
+   D += params->batch_stride_d * tid.z;
+
+   // Find block in A, B, C
+   const int c_row = tid_y * BM;
+   const int c_col = tid_x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   D += c_row_long * params->ldd + c_col_long;
+
+   const constant int* out_mask_strides = mask_strides;
+   const constant int* lhs_mask_strides =
+       mask_strides + (has_output_mask ? 2 : 0);
+   const constant int* rhs_mask_strides =
+       lhs_mask_strides + (has_operand_mask ? 2 : 0);
+
+   const int out_mask_offset = !has_output_mask
+       ? 0
+       : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
+   int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
+   int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
+   const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
+   const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
+   short k_factor_cnt = k_mask_factor;
+
+   ScaleOp<float> out_mask_op;
+   ScaleOp<T> lhs_mask_op;
+   ScaleOp<T> rhs_mask_op;
+
+   if (has_output_mask) {
+     auto mask_out = out_mask[out_mask_offset];
+
+     if (has_mul_output_mask) {
+       out_mask_op.scale = float(mask_out);
+     }
+
+     // Write zeros and return
+     if (!mask_out) {
+       constexpr short tgp_size = WM * WN * 32;
+       constexpr short vec_size = 4;
+
+       // Tile threads in threadgroup
+       constexpr short TN = BN / vec_size;
+       constexpr short TM = tgp_size / TN;
+
+       const short thread_idx = simd_group_id * 32 + simd_lane_id;
+       const short bi = thread_idx / TN;
+       const short bj = vec_size * (thread_idx % TN);
+
+       D += bi * params->ldd + bj;
+
+       short tgp_bm = min(BM, params->M - c_row);
+       short tgp_bn = min(BN, params->N - c_col);
+
+       if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+         for (short ti = 0; ti < BM; ti += TM) {
+           STEEL_PRAGMA_UNROLL
+           for (short j = 0; j < vec_size; j++) {
+             D[ti * params->ldd + j] = T(0.);
+           }
+         }
+       } else {
+         short jmax = tgp_bn - bj;
+         jmax = jmax < vec_size ? jmax : vec_size;
+         for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
+           for (short j = 0; j < jmax; j++) {
+             D[ti * params->ldd + j] = T(0.);
+           }
+         }
+       }
+
+       return;
+     }
+   }
+
+   threadgroup_barrier(mem_flags::mem_none);
+
+   // Prepare threadgroup mma operation
+   thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
+
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   // Prepare threadgroup loading operations
+   thread typename gemm_kernel::loader_a_t loader_a(
+       A, params->lda, As, simd_group_id, simd_lane_id);
+   thread typename gemm_kernel::loader_b_t loader_b(
+       B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+   // Prepare threadgroup bounds
+   const short tgp_bm =
+       MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
+   const short tgp_bn =
+       MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
+
+   int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+   ///////////////////////////////////////////////////////////////////////////////
+   // Do unaligned K iterations first
+   if (!K_aligned) {
+     const int k_last = params->gemm_k_iterations_aligned * BK;
+     const int mask_idx_last = k_last / BM;
+
+     if (!has_operand_mask ||
+         (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
+          bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
+       if (has_mul_operand_mask) {
+         lhs_mask_op.scale =
+             lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
+         rhs_mask_op.scale =
+             rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
+       }
+
+       // Move loader source ahead to end
+       const int k_remain = params->K - k_last;
+       const size_t k_jump_a =
+           transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
+       const size_t k_jump_b =
+           transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
+
+       loader_a.src += k_jump_a;
+       loader_b.src += k_jump_b;
+
+       // Load tile
+       const short2 tile_dims_A =
+           transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+       const short2 tile_dims_B =
+           transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+       loader_a.load_safe(tile_dims_A);
+       loader_b.load_safe(tile_dims_B);
+
+       if (has_mul_operand_mask) {
+         loader_a.apply_inplace_op(lhs_mask_op);
+         loader_b.apply_inplace_op(rhs_mask_op);
+       }
+
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       // Do matmul
+       mma_op.mma(As, Bs);
+
+       // Reset source back to start
+       loader_a.src -= k_jump_a;
+       loader_b.src -= k_jump_b;
+     }
+   }
+
+   ///////////////////////////////////////////////////////////////////////////////
+   // MNK aligned loop
+   if (MN_aligned) {
+     for (; gemm_k_iterations > 0; gemm_k_iterations--) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       if (!has_operand_mask ||
+           (bool(lhs_mask[lhs_mask_offset]) &&
+            bool(rhs_mask[rhs_mask_offset]))) {
+         if (has_mul_operand_mask) {
+           lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
+           rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
+         }
+
+         // Load elements into threadgroup
+         loader_a.load_unsafe();
+         loader_b.load_unsafe();
+
+         if (has_mul_operand_mask) {
+           loader_a.apply_inplace_op(lhs_mask_op);
+           loader_b.apply_inplace_op(rhs_mask_op);
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+       }
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+
+       k_factor_cnt--;
+       lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
+       rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
+       k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
+     }
+
+     if (has_mul_output_mask) {
+       mma_op.apply_epilogue(out_mask_op);
+     }
+
+     // Store results to device memory
+     mma_op.store_result(D, params->ldd);
+     return;
+
+   }
+   ///////////////////////////////////////////////////////////////////////////////
+   // MN unaligned loop
+   else {
+     const bool M_aligned = (tgp_bm == BM);
+     const bool N_aligned = (tgp_bn == BN);
+
+     const short2 tile_dims_A =
+         transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+     const short2 tile_dims_B =
+         transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
+
+     for (; gemm_k_iterations > 0; gemm_k_iterations--) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+       if (!has_operand_mask ||
+           (bool(lhs_mask[lhs_mask_offset]) &&
+            bool(rhs_mask[rhs_mask_offset]))) {
+         if (has_mul_operand_mask) {
+           lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
+           rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
+         }
+
+         // Load elements into threadgroup
+         if (M_aligned) {
+           loader_a.load_unsafe();
+         } else {
+           loader_a.load_safe(tile_dims_A);
+         }
+
+         if (N_aligned) {
+           loader_b.load_unsafe();
+         } else {
+           loader_b.load_safe(tile_dims_B);
+         }
+
+         if (has_mul_operand_mask) {
+           loader_a.apply_inplace_op(lhs_mask_op);
+           loader_b.apply_inplace_op(rhs_mask_op);
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+       }
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+
+       k_factor_cnt--;
+       lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
+       rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
+       k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
+     }
+
+     if (has_mul_output_mask) {
+       mma_op.apply_epilogue(out_mask_op);
+     }
+
+     if (M_aligned && N_aligned) {
+       mma_op.store_result(D, params->ldd);
+     } else {
+       mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+   }
+ }
+
+ template <
+     typename T,
+     int BM,
+     int BN,
+     int BK,
+     int WM,
+     int WN,
+     bool transpose_a,
+     bool transpose_b,
+     bool MN_aligned,
+     bool K_aligned,
+     bool has_operand_mask = false>
+ [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
+ block_masked_gemm(
+     const device T* A [[buffer(0)]],
+     const device T* B [[buffer(1)]],
+     device T* D [[buffer(3)]],
+     const constant GEMMParams* params [[buffer(4)]],
+     const constant int* batch_shape [[buffer(6)]],
+     const constant int64_t* batch_strides [[buffer(7)]],
+     const device bool* out_mask [[buffer(10)]],
+     const device bool* lhs_mask [[buffer(11)]],
+     const device bool* rhs_mask [[buffer(12)]],
+     const constant int* mask_strides [[buffer(13)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]]) {
+   // Appease the compiler
+   (void)lid;
+
+   using gemm_kernel = GEMMKernel<
+       T,
+       T,
+       BM,
+       BN,
+       BK,
+       WM,
+       WN,
+       transpose_a,
+       transpose_b,
+       MN_aligned,
+       K_aligned>;
+
+   const int tid_y = ((tid.y) << params->swizzle_log) +
+       ((tid.x) & ((1 << params->swizzle_log) - 1));
+   const int tid_x = (tid.x) >> params->swizzle_log;
+
+   if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+     return;
+   }
+
+   if (params->batch_ndim > 1) {
+     const constant auto* mask_batch_strides =
+         batch_strides + 2 * params->batch_ndim;
+     out_mask +=
+         elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
+
+     if (has_operand_mask) {
+       const constant auto* mask_strides_lhs =
+           mask_batch_strides + params->batch_ndim;
+       const constant auto* mask_strides_rhs =
+           mask_strides_lhs + params->batch_ndim;
+
+       ulong2 batch_offsets = elem_to_loc_broadcast(
+           tid.z,
+           batch_shape,
+           mask_strides_lhs,
+           mask_strides_rhs,
+           params->batch_ndim);
+
+       lhs_mask += batch_offsets.x;
+       rhs_mask += batch_offsets.y;
+     }
+   } else {
+     out_mask += tid.z * batch_strides[2 * params->batch_ndim];
+     if (has_operand_mask) {
+       lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
+       rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
+     }
+   }
+
+   // Adjust for batch
+   if (params->batch_ndim > 1) {
+     const constant auto* A_bstrides = batch_strides;
+     const constant auto* B_bstrides = batch_strides + params->batch_ndim;
+
+     ulong2 batch_offsets = elem_to_loc_broadcast(
+         tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+
+     A += batch_offsets.x;
+     B += batch_offsets.y;
+
+   } else {
+     A += params->batch_stride_a * tid.z;
+     B += params->batch_stride_b * tid.z;
+   }
+
+   D += params->batch_stride_d * tid.z;
+
+   // Find block in A, B, C
+   const int c_row = tid_y * BM;
+   const int c_col = tid_x * BN;
+   const size_t c_row_long = size_t(c_row);
+   const size_t c_col_long = size_t(c_col);
+
+   A += transpose_a ? c_row_long : c_row_long * params->lda;
+   B += transpose_b ? c_col_long * params->ldb : c_col_long;
+   D += c_row_long * params->ldd + c_col_long;
+
+   bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
+
+   // Write zeros and return
+   if (!mask_out) {
+     constexpr short tgp_size = WM * WN * 32;
+     constexpr short vec_size = 4;
+
+     // Tile threads in threadgroup
+     constexpr short TN = BN / vec_size;
+     constexpr short TM = tgp_size / TN;
+
+     const short thread_idx = simd_group_id * 32 + simd_lane_id;
+     const short bi = thread_idx / TN;
+     const short bj = vec_size * (thread_idx % TN);
+
+     D += bi * params->ldd + bj;
+
+     short tgp_bm = min(BM, params->M - c_row);
+     short tgp_bn = min(BN, params->N - c_col);
+
+     if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+       for (short ti = 0; ti < BM; ti += TM) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < vec_size; j++) {
+           D[ti * params->ldd + j] = T(0.);
+         }
+       }
+     } else {
+       short jmax = tgp_bn - bj;
+       jmax = jmax < vec_size ? jmax : vec_size;
+       for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
+         for (short j = 0; j < jmax; j++) {
+           D[ti * params->ldd + j] = T(0.);
+         }
+       }
+     }
+
+     return;
+   }
+
+   threadgroup_barrier(mem_flags::mem_none);
+
+   // Prepare threadgroup mma operation
+   thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
+
+   int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+   threadgroup T As[gemm_kernel::tgp_mem_size_a];
+   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+   // Prepare threadgroup loading operations
+   thread typename gemm_kernel::loader_a_t loader_a(
+       A, params->lda, As, simd_group_id, simd_lane_id);
+   thread typename gemm_kernel::loader_b_t loader_b(
+       B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+   ///////////////////////////////////////////////////////////////////////////////
+   // MNK aligned loop
+   if (MN_aligned) {
+     for (int k = 0; k < gemm_k_iterations; k++) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       if (!has_operand_mask ||
+           (lhs_mask
+                [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
+            rhs_mask
+                [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+         // Load elements into threadgroup
+         loader_a.load_unsafe();
+         loader_b.load_unsafe();
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+       }
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+     }
+
+     threadgroup_barrier(mem_flags::mem_none);
+
+     // Loop tail
+     if (!K_aligned) {
+       if (!has_operand_mask ||
+           (lhs_mask
+                [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
+            rhs_mask
+                [(params->K / BM) * mask_strides[5] +
+                 tid_x * mask_strides[4]])) {
+         int lbk = params->K - params->gemm_k_iterations_aligned * BK;
+         short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
+         short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
+
+         loader_a.load_safe(tile_dims_A);
+         loader_b.load_safe(tile_dims_B);
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         mma_op.mma(As, Bs);
+       }
+     }
+
+     // Store results to device memory
+     mma_op.store_result(D, params->ldd);
+     return;
+
+   }
+   ///////////////////////////////////////////////////////////////////////////////
+   // MN unaligned loop
+   else { // Loop over K - unaligned case
+     short tgp_bm = min(BM, params->M - c_row);
+     short tgp_bn = min(BN, params->N - c_col);
+     short lbk = params->K - params->gemm_k_iterations_aligned * BK;
+
+     bool M_aligned = (tgp_bm == BM);
+     bool N_aligned = (tgp_bn == BN);
+
+     short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+     short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
+
+     for (int k = 0; k < gemm_k_iterations; k++) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+       if (!has_operand_mask ||
+           (lhs_mask
+                [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
+            rhs_mask
+                [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+         // Load elements into threadgroup
+         if (M_aligned) {
+           loader_a.load_unsafe();
+         } else {
+           loader_a.load_safe(tile_dims_A);
+         }
+
+         if (N_aligned) {
+           loader_b.load_unsafe();
+         } else {
+           loader_b.load_safe(tile_dims_B);
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         // Multiply and accumulate threadgroup elements
+         mma_op.mma(As, Bs);
+       }
+
+       // Prepare for next iteration
+       loader_a.next();
+       loader_b.next();
+     }
+
+     if (!K_aligned) {
+       threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       if (!has_operand_mask ||
+           (lhs_mask
+                [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
+            rhs_mask
+                [(params->K / BM) * mask_strides[5] +
+                 tid_x * mask_strides[4]])) {
+         short2 tile_dims_A_last =
+             transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
+         short2 tile_dims_B_last =
+             transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
+
+         loader_a.load_safe(tile_dims_A_last);
+         loader_b.load_safe(tile_dims_B_last);
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         mma_op.mma(As, Bs);
+       }
+     }
+
+     if (M_aligned && N_aligned) {
+       mma_op.store_result(D, params->ldd);
+     } else {
+       mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
+     }
+   }
+ }