mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/hadamard.h
@@ -0,0 +1,182 @@
+ // Copyright © 2024 Apple Inc.
+ #include <metal_common>
+ #include <metal_compute>
+
+ #include "mlx/backend/metal/kernels/steel/defines.h"
+
+ using namespace metal;
+
+ // Thread local Hadamard transform for 2^R
+ template <short R>
+ METAL_FUNC void radix_func(thread float* x) {
+   constexpr short logR = __builtin_ctz(R);
+   short h = 1;
+   STEEL_PRAGMA_UNROLL
+   for (short s = 0; s < logR; s++) {
+     STEEL_PRAGMA_UNROLL
+     for (short i = 0; i < R / 2; i++) {
+       short k = i & (h - 1);
+       short j = ((i - k) << 1) + k;
+       float a = x[j];
+       float b = x[j + h];
+       x[j] = a + b;
+       x[j + h] = a - b;
+     }
+     h <<= 1;
+   }
+ }
+
+ template <typename T, int N, int max_radix, int read_width, int stride = 1>
+ [[kernel]] void hadamard_n(
+     const device T* in [[buffer(0)]],
+     device T* out [[buffer(1)]],
+     constant const float& scale,
+     uint3 elem [[thread_position_in_grid]],
+     uint3 grid [[threads_per_grid]]) {
+   // Compute a Hadamard transform of size N = 2^k
+   //
+   // Equivalent to:
+   // from scipy.linalg import hadamard
+   // y = hadamard(len(x)) @ x
+
+   constexpr short num_threads = N / max_radix;
+   constexpr short logN = __builtin_ctz(N);
+   constexpr short logR = __builtin_ctz(max_radix);
+   constexpr short num_steps = logN / logR;
+   constexpr short logFinal = logN % logR;
+   constexpr short final_radix = 1 << (logFinal);
+
+   int batch_idx = elem.y * N * stride + elem.z;
+   short i = elem.x;
+
+   threadgroup T buf[N];
+
+   // Read values from device
+   if (stride == 1) {
+     STEEL_PRAGMA_UNROLL
+     for (short j = 0; j < max_radix / read_width; j++) {
+       short index = j * read_width * num_threads + i * read_width;
+       STEEL_PRAGMA_UNROLL
+       for (short r = 0; r < read_width; r++) {
+         buf[index + r] = in[batch_idx + index + r];
+       }
+     }
+   } else {
+     STEEL_PRAGMA_UNROLL
+     for (short j = 0; j < max_radix; j++) {
+       buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
+     }
+   }
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   float x[max_radix];
+   short h = 1;
+
+   STEEL_PRAGMA_UNROLL
+   for (short s = 0; s < num_steps; s++) {
+     short k = i & (h - 1);
+     short j = ((i - k) << logR) + k;
+
+     STEEL_PRAGMA_UNROLL
+     for (short r = 0; r < max_radix; r++) {
+       x[r] = buf[j + h * r];
+     }
+
+     radix_func<max_radix>(x);
+
+     STEEL_PRAGMA_UNROLL
+     for (short r = 0; r < max_radix; r++) {
+       buf[j + h * r] = T(x[r]);
+     }
+
+     h <<= logR;
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+   }
+
+   // Do the final radix
+   // e.g. max_radix = 16
+   // N = 1024 = 16 * 16 * 4
+   if (final_radix > 1) {
+     // Each thread does multiple butterflies
+     STEEL_PRAGMA_UNROLL
+     for (int t = 0; t < max_radix / final_radix; t++) {
+       short index = i + t * num_threads;
+       short k = index & (h - 1);
+       short j = ((index - k) << logFinal) + k;
+       STEEL_PRAGMA_UNROLL
+       for (short r = 0; r < final_radix; r++) {
+         x[r] = buf[j + h * r];
+       }
+
+       radix_func<final_radix>(x);
+
+       STEEL_PRAGMA_UNROLL
+       for (short r = 0; r < final_radix; r++) {
+         buf[j + h * r] = T(x[r]);
+       }
+     }
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+   }
+
+   // Write values to device
+   if (stride == 1) {
+     STEEL_PRAGMA_UNROLL
+     for (short j = 0; j < max_radix / read_width; j++) {
+       short index = j * read_width * num_threads + i * read_width;
+       STEEL_PRAGMA_UNROLL
+       for (short r = 0; r < read_width; r++) {
+         out[batch_idx + index + r] = T(buf[index + r] * scale);
+       }
+     }
+   } else {
+     STEEL_PRAGMA_UNROLL
+     for (short j = 0; j < max_radix; j++) {
+       out[batch_idx + (j * num_threads + i) * stride] =
+           buf[j * num_threads + i];
+     }
+   }
+ }
+
+ template <typename T, int N, int M, int read_width>
+ [[kernel]] void hadamard_m(
+     const device T* in [[buffer(0)]],
+     device T* out [[buffer(1)]],
+     constant const float& scale,
+     uint3 elem [[thread_position_in_grid]],
+     uint3 grid [[threads_per_grid]]) {
+   // Compute a Hadamard transform of size M
+   // using a naive O(M^2) codelet.
+   //
+   // This kernel is the second stage in the computation
+   // of a Hadamard transform of size M*N where N = 2^k.
+
+   int index = elem.x * grid.y + elem.y;
+   short i = index % (N / read_width);
+   int batch_idx = index / (N / read_width) * M * N;
+
+   float x[read_width][M];
+   STEEL_PRAGMA_UNROLL
+   for (short c = 0; c < M; c++) {
+     STEEL_PRAGMA_UNROLL
+     for (short r = 0; r < read_width; r++) {
+       x[r][c] = in[batch_idx + c * N + i * read_width + r];
+     }
+   }
+
+   STEEL_PRAGMA_UNROLL
+   for (short r = 0; r < read_width; r++) {
+     // This function is JIT compiled for M
+     // using the Hadamard matrix strings in `metal/hadamard.cpp`
+     hadamard_radix_m(x[r]);
+   }
+
+   // Write back to device
+   STEEL_PRAGMA_UNROLL
+   for (short c = 0; c < M; c++) {
+     STEEL_PRAGMA_UNROLL
+     for (short r = 0; r < read_width; r++) {
+       out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
+     }
+   }
+ }
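For reference, the butterfly in radix_func/hadamard_n computes the same result as the small host-side sketch below (illustrative only, not part of the wheel; fwht is a name introduced here and assumes the input length is a power of two):

```python
# Minimal O(N log N) Walsh-Hadamard transform mirroring the kernel's butterfly.
import numpy as np

def fwht(x):
    """Return hadamard(len(x)) @ x via in-place butterflies (len(x) = 2^k)."""
    x = np.asarray(x, dtype=np.float64).copy()
    h = 1
    while h < len(x):
        for j in range(0, len(x), 2 * h):
            a = x[j:j + h].copy()
            b = x[j + h:j + 2 * h].copy()
            x[j:j + h] = a + b          # x[j]   = a + b
            x[j + h:j + 2 * h] = a - b  # x[j+h] = a - b
        h *= 2
    return x

# Cross-check against the equivalence quoted in the kernel comment:
#   from scipy.linalg import hadamard
#   assert np.allclose(fwht(v), hadamard(len(v)) @ v)
```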
mlx/include/mlx/backend/metal/kernels/indexing/gather.h
@@ -0,0 +1,51 @@
+ // Copyright © 2024 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/metal/kernels/indexing/indexing.h"
+
+ template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
+ METAL_FUNC void gather_impl(
+     const device T* src [[buffer(0)]],
+     device T* out [[buffer(1)]],
+     const constant int* src_shape [[buffer(2)]],
+     const constant int64_t* src_strides [[buffer(3)]],
+     const constant size_t& src_ndim [[buffer(4)]],
+     const constant int* slice_sizes [[buffer(5)]],
+     const constant int* axes [[buffer(6)]],
+     const thread Indices<IdxT, NIDX>& indices,
+     uint3 index [[thread_position_in_grid]],
+     uint3 grid_dim [[threads_per_grid]]) {
+   LocT src_idx = 0;
+   for (int i = 0; i < NIDX; ++i) {
+     LocT idx_loc;
+     if (IDX_NDIM == 0) {
+       idx_loc = 0;
+     } else if (IDX_NDIM == 1) {
+       idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
+     } else {
+       idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
+       idx_loc += indices.row_contiguous[i]
+           ? index.y
+           : elem_to_loc<LocT>(
+                 index.y,
+                 &indices.shapes[indices.ndim * i + 1],
+                 &indices.strides[indices.ndim * i + 1],
+                 indices.ndim - 1);
+     }
+     auto ax = axes[i];
+     auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
+     src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
+   }
+
+   auto src_offset =
+       elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);
+
+   LocT out_idx = index.z;
+   if (IDX_NDIM == 1) {
+     out_idx += static_cast<LocT>(grid_dim.z) * index.x;
+   } else if (IDX_NDIM >= 2) {
+     out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
+   }
+   out[out_idx] = src[src_offset + src_idx];
+ }
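As a rough host-side analogue (illustrative only, not the kernel's API): for the simple case of a single index array applied to axis 0 with a full trailing slice, gather_impl behaves like NumPy fancy indexing of rows, with negative indices wrapped.

```python
# Sketch of the one-index, axis-0 special case of gather semantics.
import numpy as np

src = np.arange(20).reshape(5, 4)
idx = np.array([[0, 3], [4, -1]])      # arbitrary-shaped index array
out = src[idx % src.shape[0]]          # shape (2, 2, 4): idx shape + slice
print(out.shape)
```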
mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h
@@ -0,0 +1,44 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
+ [[kernel]] void gather_axis(
+     const device T* src [[buffer(0)]],
+     const device IdxT* indices [[buffer(1)]],
+     device T* out [[buffer(2)]],
+     const constant int* shape [[buffer(3)]],
+     const constant int64_t* src_strides [[buffer(4)]],
+     const constant int64_t* idx_strides [[buffer(5)]],
+     const constant size_t& ndim [[buffer(6)]],
+     const constant int& axis [[buffer(7)]],
+     const constant int& axis_size [[buffer(8)]],
+     const constant size_t& src_ax_stride [[buffer(9)]],
+     const constant size_t& idx_ax_stride [[buffer(10)]],
+     uint3 index [[thread_position_in_grid]],
+     uint3 grid_dim [[threads_per_grid]]) {
+   LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
+   LocT out_idx = elem_idx * grid_dim.y + index.x;
+
+   LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
+   if (IdxC) {
+     idx_loc += out_idx;
+   } else {
+     idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
+   }
+
+   auto idx_val = indices[idx_loc];
+   if (is_signed_v<IdxT>) {
+     idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
+   }
+
+   LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
+   if (SrcC) {
+     src_idx += elem_idx * axis_size + index.x;
+   } else {
+     src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
+   }
+
+   out_idx += index.y * static_cast<LocT>(grid_dim.x);
+   out[out_idx] = src[src_idx];
+ }
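A rough host-side analogue of gather_axis (illustrative, not the kernel's API): pick values of src along one axis at the positions given by indices, with negative indices wrapped, i.e. the behavior of NumPy's take_along_axis.

```python
# Gather along axis 1, wrapping negative indices into [0, axis_size).
import numpy as np

src = np.arange(12).reshape(3, 4)
idx = np.array([[0, -1], [2, 1], [3, 0]])            # -1 wraps to 3
ref = np.take_along_axis(src, idx % src.shape[1], axis=1)
print(ref)  # [[ 0  3] [ 6  5] [11  8]]
```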
mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h
@@ -0,0 +1,24 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/metal/kernels/indexing/indexing.h"
+
+ template <typename T, typename IdxT, typename LocT, int N>
+ [[kernel]] void gather_front(
+     const device T* src,
+     const device IdxT* indices,
+     device T* out,
+     const constant int64_t& stride,
+     const constant int& size,
+     uint2 index [[thread_position_in_grid]],
+     uint2 grid_dim [[threads_per_grid]]) {
+   auto idx = offset_neg_idx(indices[index.y], size);
+   LocT src_idx = static_cast<LocT>(stride) * idx;
+   LocT out_idx = static_cast<LocT>(stride) * index.y;
+
+   int s_idx = N * index.x;
+   for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
+     out[out_idx + s_idx] = src[src_idx + s_idx];
+   }
+ }
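Host-side analogue of gather_front (illustrative): copy whole leading-axis slices of src at the given indices, wrapping negative indices.

```python
# Take rows (leading axis) at the requested indices.
import numpy as np

src = np.arange(12).reshape(4, 3)
indices = np.array([2, -1, 0])
print(src[indices % src.shape[0]])  # rows 2, 3, 0
```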
mlx/include/mlx/backend/metal/kernels/indexing/indexing.h
@@ -0,0 +1,23 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #pragma once
+
+ #include <metal_stdlib>
+
+ template <typename IdxT, int NIDX>
+ struct Indices {
+   const array<const device IdxT*, NIDX> buffers;
+   const constant int* shapes;
+   const constant int64_t* strides;
+   const constant bool* row_contiguous;
+   const int ndim;
+ };
+
+ template <typename IdxT>
+ METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
+   if (is_unsigned_v<IdxT>) {
+     return idx;
+   } else {
+     return (idx < 0) ? idx + size : idx;
+   }
+ }
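offset_neg_idx is the usual NumPy-style wrap of negative indices into [0, size); a one-line host-side analogue (illustrative, name introduced here):

```python
# Wrap a possibly negative index along a single axis.
def offset_neg_idx(idx: int, size: int) -> int:
    return idx + size if idx < 0 else idx

assert offset_neg_idx(-1, 8) == 7
assert offset_neg_idx(3, 8) == 3
```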
mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h
@@ -0,0 +1,38 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ template <typename T, bool src_contiguous>
+ [[kernel]] void masked_assign_impl(
+     const device bool* mask [[buffer(0)]],
+     const device uint* scatter_offsets [[buffer(1)]],
+     const device T* src [[buffer(2)]],
+     device T* out [[buffer(3)]],
+     const constant int* src_shapes [[buffer(4)]],
+     const constant int64_t* src_strides [[buffer(5)]],
+     const constant int& src_ndim [[buffer(6)]],
+     const constant int64_t& src_batch_size [[buffer(7)]],
+     const constant int64_t& mask_batch_size [[buffer(8)]],
+     uint idx [[thread_position_in_grid]]) {
+   const bool mask_value = mask[idx];
+   if (!mask_value) {
+     return;
+   }
+
+   const uint src_index = scatter_offsets[idx];
+   if (src_index >= src_batch_size) {
+     return;
+   }
+
+   const uint batch_idx = idx / mask_batch_size;
+
+   if (src_contiguous) {
+     out[idx] = src[batch_idx * src_batch_size + src_index];
+   } else {
+     out[idx] = src[elem_to_loc<uint>(
+         batch_idx * src_batch_size + src_index,
+         src_shapes,
+         src_strides,
+         src_ndim)];
+   }
+ }
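An illustrative single-batch analogue of masked_assign_impl (a sketch only; here scatter_offsets is assumed to hold, per output element, its 0-based position among the true mask entries, recomputed below with a cumulative sum):

```python
# Boolean-mask assignment: each true position takes the next src element.
import numpy as np

out = np.zeros(6, dtype=np.int32)
mask = np.array([True, False, True, True, False, False])
src = np.array([10, 20, 30])            # one value per true mask entry
offsets = np.cumsum(mask) - 1            # position among the true entries
out[mask] = src[offsets[mask]]
print(out)  # [10  0 20 30  0  0]
```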
mlx/include/mlx/backend/metal/kernels/indexing/scatter.h
@@ -0,0 +1,59 @@
+ // Copyright © 2024 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/metal/kernels/indexing/indexing.h"
+
+ template <
+     typename T,
+     typename IdxT,
+     typename Op,
+     int NIDX,
+     bool UPD_ROW_CONTIG,
+     int NWORK,
+     typename LocT>
+ METAL_FUNC void scatter_impl(
+     const device T* updates,
+     device mlx_atomic<T>* out,
+     const constant int* upd_shape,
+     const constant int64_t* upd_strides,
+     const constant size_t& upd_ndim,
+     const constant size_t& upd_size,
+     const constant int* out_shape,
+     const constant int64_t* out_strides,
+     const constant size_t& out_ndim,
+     const constant int* axes,
+     const constant size_t& idx_size,
+     const thread Indices<IdxT, NIDX>& indices,
+     uint2 gid [[thread_position_in_grid]]) {
+   Op op;
+
+   auto ind_idx = gid.y * NWORK;
+   LocT out_offset = 0;
+   if (upd_size > 1) {
+     out_offset = elem_to_loc<LocT>(
+         gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
+   }
+
+   for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
+     LocT out_idx = out_offset;
+     for (int i = 0; i < NIDX; ++i) {
+       auto idx_loc = indices.row_contiguous[i]
+           ? ind_idx
+           : elem_to_loc<LocT>(
+                 ind_idx,
+                 &indices.shapes[indices.ndim * i],
+                 &indices.strides[indices.ndim * i],
+                 indices.ndim);
+       auto ax = axes[i];
+       auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
+       out_idx +=
+           static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
+     }
+     auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
+     if constexpr (!UPD_ROW_CONTIG) {
+       upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
+     }
+     op.atomic_update(out, updates[upd_idx], out_idx);
+   }
+ }
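For intuition, with an add-style Op the scatter above accumulates updates at duplicate indices, which is why the kernel routes writes through atomic updates. A hedged host-side analogue of that one case:

```python
# Unbuffered scatter-add: duplicate indices accumulate, like atomic adds.
import numpy as np

out = np.zeros(5, dtype=np.float32)
indices = np.array([0, 2, 2, 4])
updates = np.ones(4, dtype=np.float32)
np.add.at(out, indices, updates)
print(out)  # [1. 0. 2. 0. 1.]
```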
mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h
@@ -0,0 +1,52 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ template <
+     typename T,
+     typename IdxT,
+     typename LocT,
+     typename Op,
+     bool UpdC,
+     bool IdxC>
+ [[kernel]] void scatter_axis(
+     const device T* upd [[buffer(0)]],
+     const device IdxT* indices [[buffer(1)]],
+     device mlx_atomic<T>* out [[buffer(2)]],
+     const constant int* shape [[buffer(3)]],
+     const constant int64_t* upd_strides [[buffer(4)]],
+     const constant int64_t* idx_strides [[buffer(5)]],
+     const constant size_t& ndim [[buffer(6)]],
+     const constant int& axis [[buffer(7)]],
+     const constant int& out_axis_size [[buffer(8)]],
+     const constant size_t& upd_ax_stride [[buffer(9)]],
+     const constant size_t& idx_ax_stride [[buffer(10)]],
+     uint3 index [[thread_position_in_grid]],
+     uint3 grid_dim [[threads_per_grid]]) {
+   Op op;
+
+   LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
+
+   LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
+   if (IdxC) {
+     idx_loc += elem_idx * grid_dim.y + index.x;
+   } else {
+     idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
+   }
+
+   auto idx_val = indices[idx_loc];
+   if (is_signed_v<IdxT>) {
+     idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
+   }
+
+   LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
+   if (UpdC) {
+     upd_idx += elem_idx * grid_dim.y + index.x;
+   } else {
+     upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
+   }
+
+   LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
+       idx_val * grid_dim.x + index.x;
+   op.atomic_update(out, upd[upd_idx], out_idx);
+ }
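A rough host-side analogue of scatter_axis with an assign-style Op (illustrative only): write updates along one axis at the given index positions, as NumPy's put_along_axis does.

```python
# Scatter along axis 1 at per-row index positions.
import numpy as np

out = np.zeros((2, 4), dtype=np.int64)
idx = np.array([[1], [3]])
upd = np.array([[7], [9]])
np.put_along_axis(out, idx, upd, axis=1)
print(out)  # [[0 7 0 0] [0 0 0 9]]
```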
mlx/include/mlx/backend/metal/kernels/logsumexp.h
@@ -0,0 +1,140 @@
+ // Copyright © 2025 Apple Inc.
+
+ template <typename T, typename AccT = float, int N_READS = 4>
+ [[kernel]] void logsumexp(
+     const device T* in,
+     device T* out,
+     constant int& axis_size,
+     uint gid [[threadgroup_position_in_grid]],
+     uint _lid [[thread_position_in_threadgroup]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+   int lid = _lid;
+
+   constexpr int SIMD_SIZE = 32;
+
+   threadgroup AccT local_max[SIMD_SIZE];
+   threadgroup AccT local_normalizer[SIMD_SIZE];
+
+   AccT ld[N_READS];
+
+   in += gid * size_t(axis_size) + lid * N_READS;
+   if (lid * N_READS + N_READS <= axis_size) {
+     for (int i = 0; i < N_READS; i++) {
+       ld[i] = AccT(in[i]);
+     }
+   } else {
+     for (int i = 0; i < N_READS; i++) {
+       ld[i] =
+           ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
+     }
+   }
+   if (simd_group_id == 0) {
+     local_max[simd_lane_id] = Limits<AccT>::min;
+     local_normalizer[simd_lane_id] = 0;
+   }
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   // Get the max
+   AccT maxval = Limits<AccT>::finite_min;
+   for (int i = 0; i < N_READS; i++) {
+     maxval = (maxval < ld[i]) ? ld[i] : maxval;
+   }
+   maxval = simd_max(maxval);
+   if (simd_lane_id == 0) {
+     local_max[simd_group_id] = maxval;
+   }
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+   if (simd_group_id == 0) {
+     maxval = simd_max(local_max[simd_lane_id]);
+     if (simd_lane_id == 0) {
+       local_max[0] = maxval;
+     }
+   }
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+   maxval = local_max[0];
+
+   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
+   AccT normalizer = 0;
+   for (int i = 0; i < N_READS; i++) {
+     normalizer += fast::exp(ld[i] - maxval);
+   }
+   normalizer = simd_sum(normalizer);
+   if (simd_lane_id == 0) {
+     local_normalizer[simd_group_id] = normalizer;
+   }
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+   if (simd_group_id == 0) {
+     normalizer = simd_sum(local_normalizer[simd_lane_id]);
+     if (simd_lane_id == 0) {
+       out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
+     }
+   }
+ }
+
+ template <typename T, typename AccT = float, int N_READS = 4>
+ [[kernel]] void logsumexp_looped(
+     const device T* in,
+     device T* out,
+     constant int& axis_size,
+     uint gid [[threadgroup_position_in_grid]],
+     uint lid [[thread_position_in_threadgroup]],
+     uint lsize [[threads_per_threadgroup]],
+     uint simd_lane_id [[thread_index_in_simdgroup]],
+     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+   in += gid * size_t(axis_size);
+
+   constexpr int SIMD_SIZE = 32;
+
+   threadgroup AccT local_max[SIMD_SIZE];
+   threadgroup AccT local_normalizer[SIMD_SIZE];
+
+   // Get the max and the normalizer in one go
+   AccT prevmax;
+   AccT maxval = Limits<AccT>::finite_min;
+   AccT normalizer = 0;
+   for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
+        r++) {
+     int offset = r * lsize * N_READS + lid * N_READS;
+     AccT vals[N_READS];
+     if (offset + N_READS <= axis_size) {
+       for (int i = 0; i < N_READS; i++) {
+         vals[i] = AccT(in[offset + i]);
+       }
+     } else {
+       for (int i = 0; i < N_READS; i++) {
+         vals[i] =
+             (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
+       }
+     }
+     prevmax = maxval;
+     for (int i = 0; i < N_READS; i++) {
+       maxval = (maxval < vals[i]) ? vals[i] : maxval;
+     }
+     normalizer *= fast::exp(prevmax - maxval);
+     for (int i = 0; i < N_READS; i++) {
+       normalizer += fast::exp(vals[i] - maxval);
+     }
+   }
+   prevmax = maxval;
+   maxval = simd_max(maxval);
+   normalizer *= fast::exp(prevmax - maxval);
+   normalizer = simd_sum(normalizer);
+
+   prevmax = maxval;
+   if (simd_lane_id == 0) {
+     local_max[simd_group_id] = maxval;
+   }
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+   maxval = simd_max(local_max[simd_lane_id]);
+   normalizer *= fast::exp(prevmax - maxval);
+   if (simd_lane_id == 0) {
+     local_normalizer[simd_group_id] = normalizer;
+   }
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+   normalizer = simd_sum(local_normalizer[simd_lane_id]);
+
+   if (lid == 0) {
+     out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
+   }
+ }
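The looped kernel keeps a running maximum and rescales the partial sum by exp(prevmax - maxval) whenever the maximum grows, so a single pass suffices. A minimal host-side sketch of that rescaling trick (illustrative; streaming_logsumexp is a name introduced here):

```python
# One-pass logsumexp with running-max rescaling, mirroring logsumexp_looped.
import math

def streaming_logsumexp(xs):
    maxval = float("-inf")
    normalizer = 0.0
    for x in xs:
        prevmax = maxval
        maxval = max(maxval, x)
        # Rescale the partial sum to the new max, then add the new term.
        normalizer = normalizer * math.exp(prevmax - maxval) + math.exp(x - maxval)
    return math.log(normalizer) + maxval

print(streaming_logsumexp([1.0, 2.0, 3.0]))  # ~3.4076, matches log(e + e**2 + e**3)
```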