mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/sdpa_vector.h
@@ -0,0 +1,415 @@
+// Copyright © 2024 Apple Inc.
+
+#include <metal_simdgroup>
+
+using namespace metal;
+
+constant bool has_mask [[function_constant(20)]];
+constant bool query_transposed [[function_constant(21)]];
+constant bool do_causal [[function_constant(22)]];
+constant bool bool_mask [[function_constant(23)]];
+constant bool float_mask [[function_constant(24)]];
+constant bool has_sinks [[function_constant(25)]];
+
+template <typename T, int D, int V = D>
+[[kernel]] void sdpa_vector(
+    const device T* queries [[buffer(0)]],
+    const device T* keys [[buffer(1)]],
+    const device T* values [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    const constant int& gqa_factor [[buffer(4)]],
+    const constant int& N [[buffer(5)]],
+    const constant size_t& k_head_stride [[buffer(6)]],
+    const constant size_t& k_seq_stride [[buffer(7)]],
+    const constant size_t& v_head_stride [[buffer(8)]],
+    const constant size_t& v_seq_stride [[buffer(9)]],
+    const constant float& scale [[buffer(10)]],
+    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(12), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(13), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(14), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(15), function_constant(has_mask)]],
+    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(17), function_constant(has_sinks)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threadgroups_per_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 32;
+  constexpr int qk_per_thread = D / BD;
+  constexpr int v_per_thread = V / BD;
+  int inner_k_stride = BN * int(k_seq_stride);
+  int inner_v_stride = BN * int(v_seq_stride);
+
+  typedef float U;
+
+  thread U q[qk_per_thread];
+  thread U k[qk_per_thread];
+  thread U o[v_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int q_batch_head_idx = tid.x;
+  const int q_seq_idx = tid.y;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
+  const int q_offset =
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
+  queries += q_offset * D + simd_lid * qk_per_thread;
+  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
+      simd_lid * qk_per_thread;
+  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
+      simd_lid * v_per_thread;
+  if (bool_mask) {
+    bmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
+  }
+
+  out += o_offset * V + simd_gid * v_per_thread;
+
+  // Read the query and 0 the output accumulator
+  for (int i = 0; i < qk_per_thread; i++) {
+    q[i] = static_cast<U>(scale) * queries[i];
+  }
+  for (int i = 0; i < v_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = Limits<U>::finite_min;
+  U sum_exp_score = 0;
+  if (has_sinks && simd_gid == 0) {
+    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    sum_exp_score = 1;
+  }
+
+  // For each key
+  for (int i = simd_gid; i < N; i += BN) {
+    bool use_key = true;
+    if (do_causal) {
+      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
+    } else if (bool_mask) {
+      use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
+    }
+    if (use_key) {
+      // Read the key
+      for (int j = 0; j < qk_per_thread; j++) {
+        k[j] = keys[j];
+      }
+
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < qk_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);
+      if (float_mask) {
+        score += static_cast<U>(fmask[0]);
+      }
+
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
+
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
+
+      // Update the output accumulator
+      for (int j = 0; j < v_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
+    }
+
+    // Move the pointers to the next kv
+    keys += inner_k_stride;
+    values += inner_v_stride;
+    if (bool_mask) {
+      bmask += BN * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * mask_kv_seq_stride;
+    }
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (simd_lid == 0) {
+    max_scores[simd_gid] = max_score;
+    sum_exp_scores[simd_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = max_scores[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < v_per_thread; i++) {
+    outputs[simd_lid * BD + simd_gid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
+    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < v_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}
+
+template <typename T, int D, int V = D>
+[[kernel]] void sdpa_vector_2pass_1(
+    const device T* queries [[buffer(0)]],
+    const device T* keys [[buffer(1)]],
+    const device T* values [[buffer(2)]],
+    device float* out [[buffer(3)]],
+    device float* sums [[buffer(4)]],
+    device float* maxs [[buffer(5)]],
+    const constant int& gqa_factor [[buffer(6)]],
+    const constant int& N [[buffer(7)]],
+    const constant size_t& k_head_stride [[buffer(8)]],
+    const constant size_t& k_seq_stride [[buffer(9)]],
+    const constant size_t& v_head_stride [[buffer(10)]],
+    const constant size_t& v_seq_stride [[buffer(11)]],
+    const constant float& scale [[buffer(12)]],
+    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(14), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(15), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(16), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(17), function_constant(has_mask)]],
+    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(19), function_constant(has_sinks)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threadgroups_per_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 8;
+  constexpr int BD = 32;
+  constexpr int qk_per_thread = D / BD;
+  constexpr int v_per_thread = V / BD;
+  int inner_k_stride = BN * int(k_seq_stride);
+  int inner_v_stride = BN * int(v_seq_stride);
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U q[qk_per_thread];
+  thread U k[qk_per_thread];
+  thread U o[v_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int block_idx = tid.z;
+  const int q_batch_head_idx = tid.x;
+  const int q_seq_idx = tid.y;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
+  const int q_offset =
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;
+
+  queries += q_offset * D + simd_lid * qk_per_thread;
+  keys += kv_head_idx * k_head_stride +
+      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
+  values += kv_head_idx * v_head_stride +
+      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
+  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
+  if (bool_mask) {
+    bmask += q_batch_head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += q_batch_head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  sums += o_offset * blocks + block_idx;
+  maxs += o_offset * blocks + block_idx;
+
+  // Read the query and 0 the output accumulator
+  for (int i = 0; i < qk_per_thread; i++) {
+    q[i] = static_cast<U>(scale) * queries[i];
+  }
+  for (int i = 0; i < v_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = Limits<U>::finite_min;
+  U sum_exp_score = 0;
+  if (has_sinks && block_idx == 0 && simd_gid == 0) {
+    int q_head_idx = q_batch_head_idx % num_q_heads;
+    max_score = static_cast<U>(sinks[q_head_idx]);
+    sum_exp_score = 1;
+  }
+
+  // For each key
+  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+    bool use_key = true;
+    if (do_causal) {
+      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
+    } else if (bool_mask) {
+      use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
+    }
+    if (use_key) {
+      // Read the key
+      for (int i = 0; i < qk_per_thread; i++) {
+        k[i] = keys[i];
+      }
+
+      // Compute the i-th score
+      U score = 0;
+      for (int i = 0; i < qk_per_thread; i++) {
+        score += q[i] * k[i];
+      }
+      score = simd_sum(score);
+
+      if (float_mask) {
+        score += fmask[0];
+      }
+
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
+
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
+
+      // Update the output accumulator
+      for (int i = 0; i < v_per_thread; i++) {
+        o[i] = o[i] * factor + exp_score * values[i];
+      }
+    }
+
+    // Move the pointers to the next kv
+    keys += blocks * inner_k_stride;
+    values += blocks * inner_v_stride;
+    if (bool_mask) {
+      bmask += BN * blocks * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * blocks * mask_kv_seq_stride;
+    }
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (simd_lid == 0) {
+    max_scores[simd_gid] = max_score;
+    sum_exp_scores[simd_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+  sum_exp_score = simd_sum(sum_exp_score * factor);
+
+  // Write the sum and new max
+  if (simd_gid == 0) {
+    sums[0] = sum_exp_score;
+    maxs[0] = new_max;
+  }
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < v_per_thread; i++) {
+    outputs[simd_lid * BN + simd_gid] =
+        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // And write the output
+    if (simd_gid == 0) {
+      U output = outputs[simd_lid * BN];
+      for (int j = 1; j < BN; j++) {
+        output += outputs[simd_lid * BN + j];
+      }
+      out[i] = static_cast<T>(output);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_2(
+    const device float* partials [[buffer(0)]],
+    const device float* sums [[buffer(1)]],
+    const device float* maxs [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threadgroups_per_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U o[elem_per_thread];
+  threadgroup U outputs[BN * BD];
+
+  // Adjust positions
+  const int head_idx = tid.x;
+  const int q_seq_idx = tid.y;
+  const int q_offset = head_idx * tpg.y + q_seq_idx;
+  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+  sums += q_offset * blocks;
+  maxs += q_offset * blocks;
+  out += q_offset * D + simd_gid * elem_per_thread;
+
+  // First everybody reads the max and sum_exp
+  U max_score = maxs[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  U sum_exp_score = simd_sum(sums[simd_lid] * factor);
+
+  // Now read the block into registers and then use shared memory to transpose
+  // it
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = partials[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BD + simd_gid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
+    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}
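Both sdpa_vector and sdpa_vector_2pass_1 above accumulate attention with a streaming ("online") softmax: for each new score, the running max, the running sum of exponentials, and the partial output are rescaled by exp(old_max - new_max), so a single pass over the keys suffices. Below is a minimal host-side C++ sketch of that recurrence for one query, written against plain floats; attend_one_query is a hypothetical illustration, not an API shipped in this wheel.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Online-softmax attention for one query over N key/value rows.
// Mirrors the accumulator updates in sdpa_vector; q is assumed to be
// pre-scaled, since the kernel folds `scale` into q[i] at load time.
std::vector<float> attend_one_query(
    const std::vector<float>& q,
    const std::vector<std::vector<float>>& keys,
    const std::vector<std::vector<float>>& values) {
  const std::size_t dv = values.empty() ? 0 : values[0].size();
  float max_score = -std::numeric_limits<float>::infinity();
  float sum_exp_score = 0.0f;
  std::vector<float> o(dv, 0.0f);
  for (std::size_t i = 0; i < keys.size(); ++i) {
    // Dot product; the kernel computes this with a simd_sum across lanes.
    float score = 0.0f;
    for (std::size_t j = 0; j < q.size(); ++j) {
      score += q[j] * keys[i][j];
    }
    // Rescale the running state so everything stays relative to the max.
    float new_max = std::max(max_score, score);
    float factor = std::exp(max_score - new_max);
    float exp_score = std::exp(score - new_max);
    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
    for (std::size_t j = 0; j < dv; ++j) {
      o[j] = o[j] * factor + exp_score * values[i][j];
    }
  }
  // Final normalization (the kernel guards the sum_exp_score == 0 case).
  for (std::size_t j = 0; j < dv; ++j) {
    o[j] = sum_exp_score == 0.0f ? o[j] : o[j] / sum_exp_score;
  }
  return o;
}

The same recurrence explains the two-pass split: sdpa_vector_2pass_1 runs it independently over 32 key blocks and writes each block's partial output together with its sums/maxs, and sdpa_vector_2pass_2 merges the 32 partials with one more exp(block_max - global_max) rescale. The has_sinks path simply seeds the recurrence with max_score = sinks[head] and sum_exp_score = 1 instead of the identity state.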
mlx/include/mlx/backend/metal/kernels/softmax.h
@@ -0,0 +1,190 @@
+// Copyright © 2023-2024 Apple Inc.
+
+template <typename T>
+inline T softmax_exp(T x) {
+  // Softmax doesn't need a high-precision exponential because x is in
+  // (-oo, 0] and the result is subsequently divided by sum(exp(x_i)).
+  return fast::exp(x);
+}
+
+template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
+[[kernel]] void softmax_single_row(
+    const device T* in,
+    device T* out,
+    constant int& axis_size,
+    uint gid [[threadgroup_position_in_grid]],
+    uint _lid [[thread_position_in_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  int lid = _lid;
+
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup AccT local_max[SIMD_SIZE];
+  threadgroup AccT local_normalizer[SIMD_SIZE];
+
+  AccT ld[N_READS];
+
+  in += gid * size_t(axis_size) + lid * N_READS;
+  if (lid * N_READS + N_READS <= axis_size) {
+    for (int i = 0; i < N_READS; i++) {
+      ld[i] = AccT(in[i]);
+    }
+  } else {
+    for (int i = 0; i < N_READS; i++) {
+      ld[i] =
+          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
+    }
+  }
+  if (simd_group_id == 0) {
+    local_max[simd_lane_id] = Limits<AccT>::min;
+    local_normalizer[simd_lane_id] = 0;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Get the max
+  AccT maxval = Limits<AccT>::finite_min;
+  for (int i = 0; i < N_READS; i++) {
+    maxval = (maxval < ld[i]) ? ld[i] : maxval;
+  }
+  maxval = simd_max(maxval);
+  if (simd_lane_id == 0) {
+    local_max[simd_group_id] = maxval;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  if (simd_group_id == 0) {
+    maxval = simd_max(local_max[simd_lane_id]);
+    if (simd_lane_id == 0) {
+      local_max[0] = maxval;
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  maxval = local_max[0];
+
+  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
+  AccT normalizer = 0;
+  for (int i = 0; i < N_READS; i++) {
+    AccT exp_x = softmax_exp(ld[i] - maxval);
+    ld[i] = exp_x;
+    normalizer += exp_x;
+  }
+  normalizer = simd_sum(normalizer);
+  if (simd_lane_id == 0) {
+    local_normalizer[simd_group_id] = normalizer;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  if (simd_group_id == 0) {
+    normalizer = simd_sum(local_normalizer[simd_lane_id]);
+    if (simd_lane_id == 0) {
+      local_normalizer[0] = normalizer;
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  normalizer = 1 / local_normalizer[0];
+
+  // Normalize and write to the output
+  out += gid * size_t(axis_size) + lid * N_READS;
+  if (lid * N_READS + N_READS <= axis_size) {
+    for (int i = 0; i < N_READS; i++) {
+      out[i] = T(ld[i] * normalizer);
+    }
+  } else {
+    for (int i = 0; i < N_READS; i++) {
+      if ((lid * N_READS + i) < axis_size) {
+        out[i] = T(ld[i] * normalizer);
+      }
+    }
+  }
+}
+
+template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
+[[kernel]] void softmax_looped(
+    const device T* in,
+    device T* out,
+    constant int& axis_size,
+    uint gid [[threadgroup_position_in_grid]],
+    uint lid [[thread_position_in_threadgroup]],
+    uint lsize [[threads_per_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  in += gid * size_t(axis_size);
+
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup AccT local_max[SIMD_SIZE];
+  threadgroup AccT local_normalizer[SIMD_SIZE];
+
+  // Get the max and the normalizer in one go
+  AccT prevmax;
+  AccT maxval = Limits<AccT>::finite_min;
+  AccT normalizer = 0;
+  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
+       r++) {
+    int offset = r * lsize * N_READS + lid * N_READS;
+    AccT vals[N_READS];
+    if (offset + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        vals[i] = AccT(in[offset + i]);
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        vals[i] =
+            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
+      }
+    }
+    prevmax = maxval;
+    for (int i = 0; i < N_READS; i++) {
+      maxval = (maxval < vals[i]) ? vals[i] : maxval;
+    }
+    normalizer *= softmax_exp(prevmax - maxval);
+    for (int i = 0; i < N_READS; i++) {
+      normalizer += softmax_exp(vals[i] - maxval);
+    }
+  }
+  // Now we have a partial normalizer over N_READS * ceildiv(axis_size,
+  // N_READS * lsize) parts. We need to combine them:
+  // 1. Find the max across simd groups.
+  // 2. Rescale the partial normalizers to account for a possible change
+  //    in max.
+  // 3. Sum all normalizers.
+  prevmax = maxval;
+  maxval = simd_max(maxval);
+  normalizer *= softmax_exp(prevmax - maxval);
+  normalizer = simd_sum(normalizer);
+
+  // Now the normalizer and max value are correct for each simdgroup. We
+  // write them to shared memory and combine them.
+  prevmax = maxval;
+  if (simd_lane_id == 0) {
+    local_max[simd_group_id] = maxval;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  maxval = simd_max(local_max[simd_lane_id]);
+  normalizer *= softmax_exp(prevmax - maxval);
+  if (simd_lane_id == 0) {
+    local_normalizer[simd_group_id] = normalizer;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  normalizer = simd_sum(local_normalizer[simd_lane_id]);
+  normalizer = 1 / normalizer;
+
+  // Finally, given the normalizer and max value we can directly write the
+  // softmax output.
+  out += gid * size_t(axis_size);
+  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
+       r++) {
+    int offset = r * lsize * N_READS + lid * N_READS;
+    if (offset + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if (offset + i < axis_size) {
+          out[offset + i] =
+              T(softmax_exp(in[offset + i] - maxval) * normalizer);
+        }
+      }
+    }
+  }
+}
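softmax_looped computes the row max and the normalizer in a single sweep: whenever a new chunk raises the max, the running normalizer is first rescaled by exp(prevmax - maxval), and only then are the chunk's exponentials added; the same rescale is reapplied when partials are combined across simdgroups. A small C++ sketch of the scalar form of that update follows; online_softmax_stats is a hypothetical name used for illustration only.

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

// Single-sweep computation of max(x) and sum(exp(x - max(x))).
// This is the chunk-level recurrence of softmax_looped collapsed to one
// element per step. Note the kernel seeds maxval with a finite minimum
// (Limits<AccT>::finite_min) rather than -infinity, which keeps the
// exp(prevmax - maxval) rescale well defined even for padded lanes.
std::pair<float, float> online_softmax_stats(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    normalizer *= std::exp(prevmax - maxval); // re-base the old partial sum
    normalizer += std::exp(v - maxval);
  }
  return {maxval, normalizer};
}

With maxval and normalizer in hand, each output is exp(x[i] - maxval) / normalizer, which is exactly what the final loop of softmax_looped writes once the per-simdgroup partials have been merged through local_max and local_normalizer. softmax_single_row is the special case where the whole row fits in one set of N_READS loads per thread.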