mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -0,0 +1,476 @@
+ // Copyright © 2024-25 Apple Inc.
+
+ #include "mlx/backend/metal/kernels/steel/attn/attn.h"
+
+ using namespace mlx::steel;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // GEMM kernels
+ ///////////////////////////////////////////////////////////////////////////////
+
+ constant bool align_Q [[function_constant(200)]];
+ constant bool align_K [[function_constant(201)]];
+
+ constant bool has_mask [[function_constant(300)]];
+ constant bool do_causal [[function_constant(301)]];
+ constant bool has_sinks [[function_constant(302)]];
+
+ template <typename T>
+ struct TransformScale {
+   T scale;
+   METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
+
+   METAL_FUNC T apply(T x) const {
+     return scale * x;
+   }
+ };
+
+ struct MaxOp {
+   template <typename T>
+   METAL_FUNC static constexpr T apply(T x, T y) {
+     return metal::max(x, y);
+   }
+ };
+
+ struct SumOp {
+   template <typename T>
+   METAL_FUNC static constexpr T apply(T x, T y) {
+     return x + y;
+   }
+ };
+
+ struct MulOp {
+   template <typename T>
+   METAL_FUNC static constexpr T apply(T x, T y) {
+     return x * y;
+   }
+ };
+
+ struct SubOp {
+   template <typename T>
+   METAL_FUNC static constexpr T apply(T x, T y) {
+     return x - y;
+   }
+ };
+
+ struct ExpSubOp {
+   template <typename T>
+   METAL_FUNC static constexpr T apply(T x, T y) {
+     return fast::exp2(x - y);
+   }
+ };
+
+ struct DivOp {
+   template <typename T>
+   METAL_FUNC static constexpr T apply(T x, T y) {
+     return x / y;
+   }
+ };
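+
+ // Functor ops used by the streaming softmax below: MaxOp/SumOp for row
+ // reductions, ExpSubOp for exp2(x - rowmax), MulOp/DivOp for rescaling and
+ // the final normalization. Note they work in base 2; log2(e) is folded
+ // into the softmax scale further down.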
+
+ // clang-format off
+ template <
+     typename T,
+     int BQ,
+     int BK,
+     int BD,
+     int WM,
+     int WN,
+     typename MaskType = float,
+     typename AccumType = float>
+ [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
+     const device T* Q [[buffer(0)]],
+     const device T* K [[buffer(1)]],
+     const device T* V [[buffer(2)]],
+     device T* O [[buffer(3)]],
+     const constant AttnParams* params [[buffer(4)]],
+     const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
+     const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
+     const device T* sinks [[buffer(7), function_constant(has_sinks)]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
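+   // Block shapes: BQ/BK tile the query/key sequence, BD is the head
+   // dimension, and WM * WN simdgroups of 32 threads cooperate per
+   // threadgroup. Each threadgroup produces one BQ-row block of O for a
+   // single (batch, head) pair, indexed by tid.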
+
+   // Pacifying compiler
+   (void)lid;
+
+   // Move to correct block
+   ulong3 tidl{tid.x, tid.y, tid.z};
+
+   Q += tidl.z * params->Q_strides[0] + // Batch
+       tidl.y * params->Q_strides[1] + // Head
+       tidl.x * BQ * params->Q_strides[2]; // Sequence
+
+   ulong kv_head_idx = int(tid.y) / params->gqa_factor;
+   K += tidl.z * params->K_strides[0] + // Batch
+       kv_head_idx * params->K_strides[1]; // Head
+
+   V += tidl.z * params->V_strides[0] + // Batch
+       kv_head_idx * params->V_strides[1]; // Head
+
+   O += tidl.z * params->O_strides[0] + // Batch
+       tidl.y * params->O_strides[1] + // Head
+       tidl.x * BQ * params->O_strides[2]; // Sequence
+
+   if (has_mask) {
+     mask += tidl.z * mask_params->M_strides[0] + // Batch
+         tidl.y * mask_params->M_strides[1]; // Head
+   }
+
+   // Prepare threadgroup memory
+   constexpr short padQ = 16 / sizeof(T);
+   constexpr short padK = 16 / sizeof(T);
+   constexpr short padV = 16 / sizeof(T);
+
+   constexpr short LDQ_tgp = BD + padQ;
+   constexpr short LDK_tgp = BK + padK;
+   constexpr short LDV_tgp = BD + padV;
+
+   constexpr short tgp_mem_0 = (BK + padK) * (BD);
+   constexpr short tgp_mem_1 = BK * (BD + padV);
+   constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
+
+   threadgroup T Q_smem[BQ * (BD + padQ)];
+   threadgroup T KV_smem[tgp_mem_s];
+
+   threadgroup T* Qs = Q_smem;
+   threadgroup T* Ks = KV_smem;
+   threadgroup T* Vs = KV_smem;
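+   // K and V alias the same threadgroup buffer (KV_smem): each K block is
+   // fully consumed by the Q @ K.T matmul before the matching V block is
+   // loaded, and threadgroup barriers separate the two uses.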
+
+   // Prepare block loaders
+   using QBlockLoader = BlockLoaderT<
+       /* typename T = */ T,
+       /* short BROWS = */ BQ,
+       /* short BCOLS = */ BD,
+       /* short kDstStrRow = */ LDQ_tgp,
+       /* short kDstStrCol = */ 1,
+       /* short reduction_dim = */ 1,
+       /* short tgp_size = */ WM * WN * 32>;
+
+   // K is loaded transposed
+   using KBlockLoader = BlockLoaderT<
+       /* typename T = */ T,
+       /* short BROWS = */ BK,
+       /* short BCOLS = */ BD,
+       /* short kDstStrRow = */ 1,
+       /* short kDstStrCol = */ LDK_tgp,
+       /* short reduction_dim = */ 0,
+       /* short tgp_size = */ WM * WN * 32>;
+
+   using VBlockLoader = BlockLoaderT<
+       /* typename T = */ T,
+       /* short BROWS = */ BK,
+       /* short BCOLS = */ BD,
+       /* short kDstStrRow = */ LDV_tgp,
+       /* short kDstStrCol = */ 1,
+       /* short reduction_dim = */ 0,
+       /* short tgp_size = */ WM * WN * 32>;
+
+   QBlockLoader loader_q(
+       Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
+   KBlockLoader loader_k(
+       K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
+   VBlockLoader loader_v(
+       V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
+
+   TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
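+   // The scale is pre-multiplied by log2(e), so the logits land in base-2
+   // space and every exponential below can use the cheaper exp2 instead of
+   // exp (see ExpSubOp and the exp2 calls in the softmax update).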
+
+   // Prepare MMA tiles
+   constexpr short kFragSize = 8; // MMAFrag size
+   using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+   constexpr int kNWarps = WM * WN;
+   static_assert(
+       BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
+       "Each simdgroup must host at least 1 simdgroup matrix along Q sequence.");
+
+   // Q seq frags per warp
+   constexpr int TQ = BQ / (kNWarps * kFragSize);
+   // KV sequence frags (all warps load the same frags)
+   constexpr int TK = BK / kFragSize;
+   // HeadDim frags (all warps load the same frags)
+   constexpr int TD = BD / kFragSize;
+
+   static_assert(TQ == 1, "Check TQ");
+
+   MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
+   MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
+   MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
+   MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
+   MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
+
+   Otile.clear();
+
+   // Prepare mma tile offsets
+   const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+   const short sm = simd_coord.y;
+   const short sn = simd_coord.x;
+   const short tm = kFragSize * TQ * simd_group_id;
+
+   const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
+   const short Ks_offset = sm * LDK_tgp + sn;
+   const short Vs_offset = sm * LDV_tgp + sn;
+
+   constexpr short Qs_tile_stride = kFragSize;
+   constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   // Load Q blocks and apply scale
+   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
+     loader_q.load_safe(short2(BD, params->qL_rem));
+   } else {
+     loader_q.load_unsafe();
+   }
+   loader_q.apply_inplace_op(ts);
+
+   // Init row reduction variables
+   constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
+
+   AccumType max_score[kRowsPT];
+   AccumType sum_score[kRowsPT] = {0};
+
+   // Init to lowest finite value (acts as -Inf)
+   STEEL_PRAGMA_UNROLL
+   for (short i = 0; i < kRowsPT; ++i) {
+     max_score[i] = Limits<AccumType>::finite_min;
+   }
+
+   if (has_sinks) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kRowsPT; ++i) {
+       max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
+       sum_score[i] = 1;
+     }
+   }
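+   // Attention sinks seed the running statistics as if one virtual logit
+   // equal to the sink value had already been seen: the max starts at the
+   // sink (in base-2 units) with a running sum of 1, folding the sink into
+   // the softmax denominator.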
+
+   int kb_lim = params->NK;
+
+   if (do_causal) {
+     int q_max = (tid.x + 1) * BQ + params->qL_off;
+     kb_lim = (q_max + BK - 1) / BK;
+     kb_lim = min(params->NK, kb_lim);
+   }
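+   // Under causal masking, key blocks that start past the last query row of
+   // this block can never be attended to, so the KV loop is truncated at
+   // kb_lim rather than masked.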
+
+   // Loop over KV seq length
+   for (int kb = 0; kb < kb_lim; kb++) {
+     // Load K block and apply scale
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     if (!align_K && kb == (params->NK_aligned)) {
+       loader_k.load_safe(short2(BD, params->kL_rem));
+     } else {
+       loader_k.load_unsafe();
+     }
+
+     // Do S = Q @ K.T
+     Stile.clear();
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     STEEL_PRAGMA_UNROLL
+     for (short dd = 0; dd < TD; dd++) {
+       simdgroup_barrier(mem_flags::mem_none);
+
+       Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
+           &Qs[Qs_offset + dd * Qs_tile_stride]);
+       Ktile.template load<T, 1, 1, LDK_tgp, 1>(
+           &Ks[Ks_offset + dd * Ks_tile_stride]);
+
+       simdgroup_barrier(mem_flags::mem_none);
+
+       tile_matmad(Stile, Qtile, Ktile, Stile);
+     }
+
+     // Mask out entries past the key sequence length
+     if (!align_K && kb == (params->NK_aligned)) {
+       using stile_t = decltype(Stile);
+       using selem_t = typename stile_t::elem_type;
+       constexpr auto neg_inf = Limits<selem_t>::finite_min;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < stile_t::kTileRows; i++) {
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < stile_t::kTileCols; j++) {
+           short col_pos = sn + (j * stile_t::kFragCols);
+           STEEL_PRAGMA_UNROLL
+           for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
+             if ((col_pos + jj) >= params->kL_rem) {
+               Stile.frag_at(i, j)[jj] = neg_inf;
+             }
+           }
+         }
+       }
+     }
+
+     // Mask out if causal
+     if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
+       using stile_t = decltype(Stile);
+       using selem_t = typename stile_t::elem_type;
+       constexpr auto neg_inf = Limits<selem_t>::finite_min;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < stile_t::kTileRows; i++) {
+         const int row_pos =
+             tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < stile_t::kTileCols; j++) {
+           const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
+           STEEL_PRAGMA_UNROLL
+           for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
+             if (row_pos < (col_pos + jj)) {
+               Stile.frag_at(i, j)[jj] = neg_inf;
+             }
+           }
+         }
+       }
+     }
+
+     // Other masking as needed
+     if (has_mask) {
+       using stile_t = decltype(Stile);
+       using selem_t = typename stile_t::elem_type;
+       constexpr auto neg_inf = Limits<selem_t>::finite_min;
+
+       constexpr bool is_bool = is_same_v<MaskType, bool>;
+       using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
+
+       using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
+       using frag_t = typename MMAFrag_mask_t::frag_type;
+
+       STEEL_PRAGMA_UNROLL
+       for (short i = 0; i < stile_t::kTileRows; i++) {
+         const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
+         STEEL_PRAGMA_UNROLL
+         for (short j = 0; j < stile_t::kTileCols; j++) {
+           const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
+
+           frag_t mfrag;
+
+           MMAFrag_mask_t::load_safe(
+               mfrag,
+               mask,
+               int64_t(mask_params->M_strides[2]),
+               Int<1>{},
+               params->qL,
+               params->kL,
+               row_pos,
+               col_pos);
+
+           STEEL_PRAGMA_UNROLL
+           for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
+             if constexpr (is_bool) {
+               Stile.frag_at(i, j)[jj] =
+                   mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
+             } else {
+               Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
+             }
+           }
+         }
+       }
+     }
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     // Load V blocks
+     if (!align_K && kb == (params->NK_aligned)) {
+       loader_v.load_safe(short2(BD, params->kL_rem));
+     } else {
+       loader_v.load_unsafe();
+     }
+
+     // Do softmax
+
+     // Temp variables
+     AccumType new_max[kRowsPT];
+     AccumType factor[kRowsPT];
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kRowsPT; ++i) {
+       new_max[i] = max_score[i];
+     }
+
+     // Row max
+     Stile.template row_reduce<MaxOp>(new_max);
+
+     // exp(Si - rowmax(Si))
+     Stile.template row_bin_op<ExpSubOp>(new_max);
+
+     // Factor exp(rowmax(Si) - rowmax(Si-1))
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kRowsPT; ++i) {
+       factor[i] = fast::exp2(max_score[i] - new_max[i]);
+     }
+
+     // Save max for next iteration
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kRowsPT; ++i) {
+       max_score[i] = new_max[i];
+     }
+
+     // Row Sum
+     AccumType sum_score_tmp[kRowsPT] = {0};
+     Stile.template row_reduce<SumOp>(sum_score_tmp);
+
+     // Update norm
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < kRowsPT; ++i) {
+       sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
+     }
+
+     // Update O
+     Otile.template row_bin_op<MulOp>(factor);
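+     // Rescaling Otile by factor = exp2(old_max - new_max) keeps the
+     // accumulated numerator consistent with the updated running max before
+     // the S @ V products below are added; sum_score is maintained the same
+     // way, and the final output is Otile / sum_score.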
+
+     // Load V into registers
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     STEEL_PRAGMA_UNROLL
+     for (short iq = 0; iq < TQ; iq++) {
+       STEEL_PRAGMA_UNROLL
+       for (short id = 0; id < TD; id++) {
+         STEEL_PRAGMA_UNROLL
+         for (short ik = 0; ik < TK; ik++) {
+           if constexpr (BD == 128) {
+             simdgroup_barrier(mem_flags::mem_none);
+           }
+
+           const short kk = ik * kFragSize;
+           const short dd = id * kFragSize;
+
+           Vtile.template load<T, 1, 1, LDV_tgp, 1>(
+               &Vs[Vs_offset + kk * LDV_tgp + dd]);
+
+           if constexpr (BD == 128) {
+             simdgroup_barrier(mem_flags::mem_none);
+           }
+
+           MMAFrag_acc_t::mma(
+               Otile.frag_at(iq, id),
+               Stile.frag_at(iq, ik),
+               Vtile.frag_at(0, 0),
+               Otile.frag_at(iq, id));
+         }
+       }
+     }
+
+     // Prepare for next iteration
+     loader_k.next();
+     loader_v.next();
+   }
+
+   // Normalize output
+   Otile.template row_bin_op<DivOp>(sum_score);
+   threadgroup_barrier(mem_flags::mem_none);
+
+   // Store results
+   O += (tm + sm) * params->O_strides[2] + sn;
+
+   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
+     auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
+
+     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+       return;
+
+     Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
+   } else {
+     Otile.template store<T, 1, 1>(O, params->O_strides[2]);
+   }
+ }