mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,481 @@
1
+ // Copyright © 2024-25 Apple Inc.
2
+
3
+ #include "mlx/backend/metal/kernels/steel/attn/nax.h"
4
+ #include "mlx/backend/metal/kernels/steel/attn/params.h"
5
+ #include "mlx/backend/metal/kernels/steel/attn/transforms.h"
6
+ #include "mlx/backend/metal/kernels/steel/utils.h"
7
+
8
+ using namespace mlx::steel;
9
+
10
+ ///////////////////////////////////////////////////////////////////////////////
11
+ // GEMM kernels
12
+ ///////////////////////////////////////////////////////////////////////////////
13
+
14
// Function constants: compile-time specialization knobs bound when the
// compute pipeline is created, so untaken branches are eliminated.
//
// align_Q / align_K — when false, the last query / key block may be ragged
// (sequence length not a multiple of the block size) and the kernel takes
// the bounds-checked load/store paths and edge masking.
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];

// has_mask  — an explicit mask (buffers 5/6) is supplied and applied to S.
// do_causal — apply lower-triangular (causal) masking.
// has_sinks — per-head attention-sink logits (buffer 7) seed the softmax.
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
20
+
21
+ template <typename T>
22
+ struct TransformScale {
23
+ T scale;
24
+ METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
25
+
26
+ METAL_FUNC T apply(T x) const {
27
+ return scale * x;
28
+ }
29
+ };
30
+
31
+ struct MaxOp {
32
+ template <typename T>
33
+ METAL_FUNC static constexpr T apply(T x, T y) {
34
+ return metal::max(x, y);
35
+ }
36
+ };
37
+
38
+ struct SumOp {
39
+ template <typename T>
40
+ METAL_FUNC static constexpr T apply(T x, T y) {
41
+ return x + y;
42
+ }
43
+ };
44
+
45
+ struct MulOp {
46
+ template <typename T>
47
+ METAL_FUNC static constexpr T apply(T x, T y) {
48
+ return x * y;
49
+ }
50
+ };
51
+
52
+ struct SubOp {
53
+ template <typename T>
54
+ METAL_FUNC static constexpr T apply(T x, T y) {
55
+ return x - y;
56
+ }
57
+ };
58
+
59
+ struct ExpSubOp {
60
+ template <typename T>
61
+ METAL_FUNC static constexpr T apply(T x, T y) {
62
+ return fast::exp2(x - y);
63
+ }
64
+ };
65
+
66
+ struct DivOp {
67
+ template <typename T>
68
+ METAL_FUNC static constexpr T apply(T x, T y) {
69
+ return x / y;
70
+ }
71
+ };
72
+
73
// clang-format off
// Fused scaled-dot-product attention forward kernel (NAX tile path).
//
// One threadgroup processes a BQ-row block of queries for one (batch, head):
//   tid.x -> query-sequence block, tid.y -> head, tid.z -> batch.
// For each BK-column block of keys/values it computes S = Q @ K^T, applies
// edge / causal / explicit masking, folds the block into a running online
// softmax, and accumulates O += P @ V. The softmax is computed in base 2:
// the scale is pre-multiplied by log2(e) and exponentials use fast::exp2.
//
// Template parameters:
//   T         - element type of Q/K/V/O
//   BQ/BK/BD  - block sizes along query seq, key seq, and head dimension
//   WM/WN     - simdgroup layout per threadgroup (WM * WN simdgroups)
//   MaskType  - mask element type (bool selects select-masking; otherwise
//               the mask is added to the scores, scaled by log2(e))
//   AccumType - accumulation type for S/P/O (float by default)
template <
    typename T,
    int BQ,
    int BK,
    int BD,
    int WM,
    int WN,
    typename MaskType = float,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on

  // Pacifying compiler (parameters required by the dispatch ABI but unused).
  (void)lid;
  (void)simd_lane_id;

  // Move to correct block: offset Q/K/V/O base pointers to this
  // (batch, head, query-block). K/V use the GQA-shared head index.
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // Several query heads share one KV head when gqa_factor > 1.
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }

  // Pre-fold log2(e) into the attention scale so the softmax can use
  // exp2 instead of exp (1.44269504089 = log2(e)).
  const metal::uniform<float> scale2 =
      make_uniform(params->scale) * make_uniform(1.44269504089f);

  // Prepare MMA tiles: O accumulates in UQ x UD sub-tiles.
  constexpr short UQ = 16;
  constexpr short UD = 32;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0,
      "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");

  // Q seq frags per warp
  constexpr int TQ = BQ / (kNWarps * UQ);
  // HeadDim frags (all warps load the same frags)
  constexpr int TD = BD / UD;

  // The code below is written for exactly one Q fragment per simdgroup.
  static_assert(TQ == 1, "Check TQ");

  using OSubTile = NAXSubTile<AccumType, UQ, UD>;
  NAXTile<AccumType, TQ, TD, OSubTile> Otile;

  Otile.clear();

  // Prepare mma tile offsets: (sm, sn) is this thread's position within a
  // fragment; tm is this simdgroup's row offset within the Q block.
  const short2 simd_coord = OSubTile::NAXFrag_t::get_coord();
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = UQ * TQ * simd_group_id;

  Q += (tm + sm) * int(params->Q_strides[2]) + sn;
  K += sm * int(params->K_strides[2]) + sn;
  V += sm * int(params->V_strides[2]) + sn;

  // Init row reduction variables: per-thread running softmax state, one
  // entry per O row owned by this thread.
  constexpr short kRowsPT = decltype(Otile)::kRowsPerThread;

  metal::vec<AccumType, kRowsPT> max_score;
  metal::vec<AccumType, kRowsPT> sum_score{0};

  // Init to -Inf (finite_min is used as the "never seen" sentinel).
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits<AccumType>::finite_min;
  }

  // Attention sinks: seed the running max with the per-head sink logit
  // (in base-2 units) and the running sum with 1, so the sink behaves as
  // one extra softmax term — NOTE(review): derived from the seeding math;
  // confirm against the sinks implementation.
  if (has_sinks) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
      sum_score[i] = 1;
    }
  }

  int kb_lim = params->NK;

  // Causal: skip KV blocks entirely above the diagonal for this Q block.
  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);
  }

  // Ragged-edge bookkeeping: the last Q block may have fewer than BQ rows.
  const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
  // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
  const bool is_last_q = is_last_bq;

  // Remaining valid rows relative to this thread's first row.
  const short lim_rows_q = params->qL_rem - (tm + sm);
  const short lim_rows_k = params->kL_rem - sm;

  // Loop over KV seq length
  for (int kb = 0; kb < kb_lim; kb++) {
    // Only the final KV block can be ragged.
    const int is_last_k = (kb == (params->NK_aligned));

    // Do S = Q @ K.T — accumulated over head-dim fragments; K fragments
    // are fed transposed (true_type) into the MMA.
    constexpr short UDs = 16;
    constexpr short UKs = 32;

    constexpr short TDs = BD / UDs;
    constexpr short TKs = BK / UKs;

    using SSubTile = NAXSubTile<AccumType, UQ, UKs>;
    using QSubTile = NAXSubTile<T, UQ, UDs>;
    using KSubTile = NAXSubTile<T, UKs, UDs>;

    NAXTile<AccumType, TQ, TKs, SSubTile> Stile;

    Stile.clear();

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short ik = 0; ik < TKs; ik++) {
        STEEL_PRAGMA_UNROLL
        for (short id = 0; id < TDs; id++) {
          NAXTile<T, 1, 1, QSubTile> Qtile;
          NAXTile<T, 1, 1, KSubTile> Ktile;

          const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs;
          const int K_load_off =
              ik * UKs * int(params->K_strides[2]) + id * UDs;

          // Bounds-checked load only on the ragged last Q block.
          if (!align_Q && is_last_q) {
            // Qtile.load_rows(
            //     Q + Q_load_off,
            //     int(params->Q_strides[2]),
            //     lim_rows_q - iq * UQ);
            Qtile.load_safe(
                Q + Q_load_off,
                int(params->Q_strides[2]),
                short2(BD, lim_rows_q - iq * UQ));
          } else {
            Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
          }

          // Bounds-checked load only on the ragged last K block.
          if (!align_K && is_last_k) {
            // Ktile.load_rows(
            //     K + K_load_off,
            //     int(params->K_strides[2]),
            //     lim_rows_k - ik * UKs);
            Ktile.load_safe(
                K + K_load_off,
                int(params->K_strides[2]),
                short2(BD, lim_rows_k - ik * UKs));
          } else {
            Ktile.load(K + K_load_off, int(params->K_strides[2]));
          }

          // S(iq, ik) += Q(iq, id) @ K(ik, id)^T
          subtile_matmad_nax(
              Stile.subtile_at(iq, ik),
              Qtile.subtile_at(0, 0),
              metal::false_type{},
              Ktile.subtile_at(0, 0),
              metal::true_type{});
        }
      }
    }

    // Scale S by scale * log2(e).
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Stile.elems()[ii] *= float(scale2);
    }

    // Scale and Retile S: repartition the scores into UQ x UK (16-wide)
    // sub-tiles, the shape used for masking, softmax, and the P @ V MMA.
    constexpr short UK = 16;
    constexpr short TK = BK / UK;
    using PSubTile = NAXSubTile<AccumType, UQ, UK>;

    NAXTile<AccumType, TQ, TK, PSubTile> Ptile;

    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Ptile.elems()[ii] = Stile.elems()[ii];
    }

    // Mask out length sequence: columns past kL_rem in the ragged last K
    // block get finite_min so they contribute ~0 after exp2.
    if (!align_K && is_last_k) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short col_pos = sn + ik * UK;

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
              const auto loc = ii * PSubTile::kFragThrCols + jj;
              fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Mask out if causal — only KV blocks that can straddle the diagonal
    // need the per-element row >= col check; earlier blocks are fully kept.
    if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      const int base_row = tid.x * BQ + params->qL_off + tm;
      const int base_col = kb * BK;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short row_pos = base_row + iq * UQ;
          const short col_pos = base_col + ik * UK;

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
              const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm;
              const auto c = col_pos + jj + sn;
              const auto loc = ii * PSubTile::kFragThrCols + jj;
              fg[loc] = (r < c) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Other masking as needed: boolean masks select (keep / neg_inf);
    // float masks are additive, rescaled by log2(e) to match base-2 scores.
    if (has_mask) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      const int base_row = tid.x * BQ + tm;
      const int base_col = kb * BK;

      constexpr bool is_bool = is_same_v<MaskType, bool>;
      using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;
      using MSubTile = NAXSubTile<melem_t, UQ, UK>;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short row_pos = base_row + iq * UQ + sm;
          const short col_pos = base_col + ik * UK + sn;

          // Bounds-checked gather of this fragment's mask values.
          MSubTile mfrag;
          mfrag.load_safe(
              mask,
              int64_t(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf;
            } else {
              fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]);
            }
          }
        }
      }
    }

    // Do softmax (online / streaming update in base 2)

    // Temp variables
    metal::vec<AccumType, kRowsPT> new_max;
    metal::vec<AccumType, kRowsPT> factor;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max
    Ptile.template row_reduce<MaxOp>(new_max);

    // exp(Si - rowmax(Si))
    Ptile.template row_bin_op<ExpSubOp>(new_max);

    // Factor exp(rowmax(Si) - rowmax(Si-1)) — rescales the previous
    // partial sums/outputs onto the new running max.
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
      max_score[i] = new_max[i];
    }

    // Row Sum: rescale the old sum, then add this block's contribution.
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i];
    }

    Ptile.template row_reduce<SumOp>(sum_score);

    // Update O: rescale the accumulated output by the same factor.
    Otile.template row_bin_op<MulOp>(factor);

    simdgroup_barrier(mem_flags::mem_none);

    // Do O = P @ V
    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id++) {
        // Mid-loop barrier for BD == 128 — presumably a scheduling /
        // register-pressure aid; TODO(review) confirm intent.
        if constexpr (BD == 128) {
          if (id == 2) {
            threadgroup_barrier(mem_flags::mem_none);
          }
        }

        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          using VSubTile = NAXSubTile<T, UK, UD>;
          NAXTile<T, 1, 1, VSubTile> Vtile;

          const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD;

          // Bounds-checked load only on the ragged last K block.
          if (!align_K && is_last_k) {
            // Vtile.load_rows(
            //     V + V_load_off,
            //     int(params->V_strides[2]),
            //     lim_rows_k - ik * UK);
            Vtile.load_safe(
                V + V_load_off,
                int(params->V_strides[2]),
                short2(BD, lim_rows_k - ik * UK));
          } else {
            Vtile.load(V + V_load_off, int(params->V_strides[2]));
          }

          // O(iq, id) += P(iq, ik) @ V(ik, id)
          subtile_matmad_nax(
              Otile.subtile_at(iq, id),
              Ptile.subtile_at(iq, ik),
              metal::bool_constant<false>{},
              Vtile.subtile_at(0, 0),
              metal::bool_constant<false>{});
        }
      }
    }

    // Prepare for next iteration: advance K/V to the next BK block.
    K += BK * int(params->K_strides[2]);
    V += BK * int(params->V_strides[2]);
  }

  // Normalize output by the final softmax denominator.

  threadgroup_barrier(mem_flags::mem_none);

  metal::vec<AccumType, kRowsPT> rcp;
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    rcp[i] = (1.f / sum_score[i]);
  }

  Otile.template row_bin_op<MulOp>(rcp);

  // Store results
  O += (tm + sm) * int(params->O_strides[2]) + sn;

  if (!align_Q && is_last_q) {
    // Threads whose rows fall entirely past the sequence end write nothing.
    if (lim_rows_q <= 0)
      return;

    // Otile.store_rows(O, params->O_strides[2], lim_rows_q);
    Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q));
  } else {
    Otile.store(O, int(params->O_strides[2]));
  }
}