mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/fft.h
@@ -0,0 +1,486 @@
+ // Copyright © 2024 Apple Inc.
+
+ // Metal FFT using Stockham's algorithm
+ //
+ // References:
+ // - VkFFT (https://github.com/DTolm/VkFFT)
+ // - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
+
+ #include <metal_common>
+
+ #include "mlx/backend/metal/kernels/fft/radix.h"
+ #include "mlx/backend/metal/kernels/fft/readwrite.h"
+ #include "mlx/backend/metal/kernels/steel/defines.h"
+
+ using namespace metal;
+
+ #define MAX_RADIX 13
+ // Reached when elems_per_thread_ = 6, max_radix = 13
+ // and some threads have to do 3 radix 6s requiring 18 float2s.
+ #define MAX_OUTPUT_SIZE 18
+
+ // Specialize for a particular value of N at runtime
+ STEEL_CONST bool inv_ [[function_constant(0)]];
+ STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
+ STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
+ // rader_m = n / rader_n
+ STEEL_CONST int rader_m_ [[function_constant(3)]];
+ // Stockham steps
+ STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
+ STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
+ STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
+ STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
+ STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
+ STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
+ STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
+ STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
+ STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
+ // Rader steps
+ STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
+ STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
+ STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
+ STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
+ STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
+ STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
+ STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
+ STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
+ STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
+
+ // See "radix.h" for radix codelets
+ typedef void (*RadixFunc)(thread float2*, thread float2*);
+
+ // Perform a single radix n butterfly with appropriate twiddles
+ template <int radix, RadixFunc radix_func>
+ METAL_FUNC void radix_butterfly(
+     int i,
+     int p,
+     thread float2* x,
+     thread short* indices,
+     thread float2* y) {
+   // i: the index in the overall DFT that we're processing.
+   // p: the size of the DFTs we're merging at this step.
+   // m: how many threads are working on this DFT.
+   int k, j;
+
+   // Use faster bitwise operations when working with powers of two
+   constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
+   if (radix_p_2 && is_power_of_2_) {
+     constexpr short power = __builtin_ctz(radix);
+     k = i & (p - 1);
+     j = ((i - k) << power) + k;
+   } else {
+     k = i % p;
+     j = (i / p) * radix * p + k;
+   }
+
+   // Apply twiddles
+   if (p > 1) {
+     float2 twiddle_1 = get_twiddle(k, radix * p);
+     float2 twiddle = twiddle_1;
+     x[1] = complex_mul(x[1], twiddle);
+
+     STEEL_PRAGMA_UNROLL
+     for (int t = 2; t < radix; t++) {
+       twiddle = complex_mul(twiddle, twiddle_1);
+       x[t] = complex_mul(x[t], twiddle);
+     }
+   }
+
+   radix_func(x, y);
+
+   STEEL_PRAGMA_UNROLL
+   for (int t = 0; t < radix; t++) {
+     indices[t] = j + t * p;
+   }
+ }
+
+ // Perform all the radix steps required for a
+ // particular radix size n.
+ template <int radix, RadixFunc radix_func>
+ METAL_FUNC void radix_n_steps(
+     int i,
+     thread int* p,
+     int m,
+     int n,
+     int num_steps,
+     thread float2* inputs,
+     thread short* indices,
+     thread float2* values,
+     threadgroup float2* buf) {
+   int m_r = n / radix;
+   // When combining different sized radices, we have to do
+   // multiple butterflies in a single thread.
+   // E.g. n = 28 = 4 * 7
+   // 4 threads, 7 elems_per_thread
+   // All threads do 1 radix7 butterfly.
+   // 3 threads do 2 radix4 butterflies.
+   // 1 thread does 1 radix4 butterfly.
+   int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
+
+   int index = 0;
+   int r_index = 0;
+   for (int s = 0; s < num_steps; s++) {
+     for (int t = 0; t < max_radices_per_thread; t++) {
+       index = i + t * m;
+       if (index < m_r) {
+         for (int r = 0; r < radix; r++) {
+           inputs[r] = buf[index + r * m_r];
+         }
+         radix_butterfly<radix, radix_func>(
+             index, *p, inputs, indices + t * radix, values + t * radix);
+       }
+     }
+
+     // Wait until all threads have read their inputs into thread local mem
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     for (int t = 0; t < max_radices_per_thread; t++) {
+       index = i + t * m;
+       if (index < m_r) {
+         for (int r = 0; r < radix; r++) {
+           r_index = t * radix + r;
+           buf[indices[r_index]] = values[r_index];
+         }
+       }
+     }
+
+     // Wait until all threads have written back to threadgroup mem
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+     *p *= radix;
+   }
+ }
+
+ #define RADIX_STEP(radix, radix_func, num_steps) \
+   radix_n_steps<radix, radix_func>(              \
+       fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
+
+ template <bool rader = false>
+ METAL_FUNC void
+ perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
+   float2 inputs[MAX_RADIX];
+   short indices[MAX_OUTPUT_SIZE];
+   float2 values[MAX_OUTPUT_SIZE];
+
+   RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
+   RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
+   RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
+   RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
+   RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
+   RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
+   RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
+   RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
+   RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
+ }
+
+ // Each FFT is computed entirely in shared GPU memory.
+ //
+ // N is decomposed into radix-n DFTs:
+ // e.g. 128 = 2 * 4 * 4 * 4
+ template <int tg_mem_size, typename in_T, typename out_T>
+ [[kernel]] void fft(
+     const device in_T* in [[buffer(0)]],
+     device out_T* out [[buffer(1)]],
+     constant const int& n,
+     constant const int& batch_size,
+     uint3 elem [[thread_position_in_grid]],
+     uint3 grid [[threads_per_grid]]) {
+   threadgroup float2 shared_in[tg_mem_size];
+
+   thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
+       in,
+       &shared_in[0],
+       out,
+       n,
+       batch_size,
+       elems_per_thread_,
+       elem,
+       grid,
+       inv_);
+
+   if (read_writer.out_of_bounds()) {
+     return;
+   };
+   read_writer.load();
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   int p = 1;
+   int fft_idx = elem.z; // Thread index in DFT
+   int m = grid.z; // Threads per DFT
+   int tg_idx = elem.y * n; // Index of this DFT in threadgroup
+   threadgroup float2* buf = &shared_in[tg_idx];
+
+   perform_fft(fft_idx, &p, m, n, buf);
+
+   read_writer.write();
+ }
+
+ template <int tg_mem_size, typename in_T, typename out_T>
+ [[kernel]] void rader_fft(
+     const device in_T* in [[buffer(0)]],
+     device out_T* out [[buffer(1)]],
+     const device float2* raders_b_q [[buffer(2)]],
+     const device short* raders_g_q [[buffer(3)]],
+     const device short* raders_g_minus_q [[buffer(4)]],
+     constant const int& n,
+     constant const int& batch_size,
+     constant const int& rader_n,
+     uint3 elem [[thread_position_in_grid]],
+     uint3 grid [[threads_per_grid]]) {
+   // Use Rader's algorithm to compute fast FFTs
+   // when a prime factor `p` of `n` is greater than 13 but
+   // `p - 1` is Stockham decomposable into prime factors <= 13.
+   //
+   // E.g. n = 102
+   //        = 2 * 3 * 17
+   //        = 2 * 3 * RADER(16)
+   //        = 2 * 3 * RADER(4 * 4)
+   //
+   // In numpy:
+   //   x_perm = x[g_q]
+   //   y = np.fft.fft(x_perm) * b_q
+   //   z = np.fft.ifft(y) + x[0]
+   //   out = z[g_minus_q]
+   //   out[0] = x[1:].sum()
+   //
+   // Where g_q and g_minus_q are permutations formed by
+   // the multiplicative group modulo N (generated by a
+   // primitive root of N) and b_q is a constant.
+   // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
+   //
+   // Rader's uses fewer operations than Bluestein's and so
+   // is more accurate. It's also faster in most cases.
+   threadgroup float2 shared_in[tg_mem_size];
+
+   thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
+       in,
+       &shared_in[0],
+       out,
+       n,
+       batch_size,
+       elems_per_thread_,
+       elem,
+       grid,
+       inv_);
+
+   if (read_writer.out_of_bounds()) {
+     return;
+   };
+   read_writer.load();
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   // The number of threads we're using for each DFT
+   int m = grid.z;
+
+   int fft_idx = elem.z;
+   int tg_idx = elem.y * n;
+   threadgroup float2* buf = &shared_in[tg_idx];
+
+   // rader_m = n / rader_n;
+   int rader_m = rader_m_;
+
+   // We have to load two x_0s for each thread since sometimes
+   // elems_per_thread_ crosses a boundary.
+   // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
+   // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
+   // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
+   short x_0_index =
+       metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
+   float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
+
+   // Do the Rader permutation in shared memory
+   float2 temp[MAX_RADIX];
+   int max_index = n - rader_m - 1;
+   for (int e = 0; e < elems_per_thread_; e++) {
+     short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
+     short g_q = raders_g_q[index / rader_m];
+     temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
+   }
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   for (int e = 0; e < elems_per_thread_; e++) {
+     short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
+     buf[index + rader_m] = temp[e];
+   }
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   // Rader FFT on x[rader_m:]
+   int p = 1;
+   perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
+
+   // x_1 + ... + x_n is computed for us in the first FFT step so
+   // we save it in the first rader_m indices of the array for later.
+   int x_sum_index = metal::min(fft_idx, rader_m - 1);
+   buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
+
+   float2 inv = {1.0f, -1.0f};
+   for (int e = 0; e < elems_per_thread_; e++) {
+     short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
+     short interleaved_index =
+         index / rader_m + (index % rader_m) * (rader_n - 1);
+     temp[e] = complex_mul(
+         buf[rader_m + interleaved_index],
+         raders_b_q[interleaved_index % (rader_n - 1)]);
+   }
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   for (int e = 0; e < elems_per_thread_; e++) {
+     short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
+     buf[rader_m + index] = temp[e] * inv;
+   }
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   // Rader IFFT on x[rader_m:]
+   p = 1;
+   perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
+
+   float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
+
+   for (int e = 0; e < elems_per_thread_; e++) {
+     short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
+     short diff_index = index / (rader_n - 1) - x_0_index;
+     temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
+   }
+
+   // Use the sum of elements that was computed in the first FFT
+   float2 x_sum = buf[x_0_index] + x_0[0];
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   for (int e = 0; e < elems_per_thread_; e++) {
+     short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
+     short g_q_index = index % (rader_n - 1);
+     short g_q = raders_g_minus_q[g_q_index];
+     short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
+     buf[out_index] = temp[e];
+   }
+
+   buf[x_0_index * rader_n] = x_sum;
+
+   threadgroup_barrier(mem_flags::mem_threadgroup);
+
+   p = rader_n;
+   perform_fft(fft_idx, &p, m, n, buf);
+
+   read_writer.write();
+ }
+
372
+
373
+ template <int tg_mem_size, typename in_T, typename out_T>
374
+ [[kernel]] void bluestein_fft(
375
+ const device in_T* in [[buffer(0)]],
376
+ device out_T* out [[buffer(1)]],
377
+ const device float2* w_q [[buffer(2)]],
378
+ const device float2* w_k [[buffer(3)]],
379
+ constant const int& length,
380
+ constant const int& n,
381
+ constant const int& batch_size,
382
+ uint3 elem [[thread_position_in_grid]],
383
+ uint3 grid [[threads_per_grid]]) {
384
+ // Computes arbitrary length FFTs with Bluestein's algorithm
385
+ //
386
+ // In numpy:
387
+ // bluestein_n = next_power_of_2(2*n - 1)
388
+ // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
389
+ //
390
+ // Where w_k and w_q are precomputed on CPU in high precision as:
391
+ // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
392
+ // w_q = np.fft.fft(1/w_k[-n:])
393
+ threadgroup float2 shared_in[tg_mem_size];
394
+
395
+ thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
396
+ in,
397
+ &shared_in[0],
398
+ out,
399
+ n,
400
+ batch_size,
401
+ elems_per_thread_,
402
+ elem,
403
+ grid,
404
+ inv_);
405
+
406
+ if (read_writer.out_of_bounds()) {
407
+ return;
408
+ };
409
+ read_writer.load_padded(length, w_k);
410
+
411
+ threadgroup_barrier(mem_flags::mem_threadgroup);
412
+
413
+ int p = 1;
414
+ int fft_idx = elem.z; // Thread index in DFT
415
+ int m = grid.z; // Threads per DFT
416
+ int tg_idx = elem.y * n; // Index of this DFT in threadgroup
417
+ threadgroup float2* buf = &shared_in[tg_idx];
418
+
419
+ // fft
420
+ perform_fft(fft_idx, &p, m, n, buf);
421
+
422
+ float2 inv = float2(1.0f, -1.0f);
423
+ for (int t = 0; t < elems_per_thread_; t++) {
424
+ int index = fft_idx + t * m;
425
+ buf[index] = complex_mul(buf[index], w_q[index]) * inv;
426
+ }
427
+
428
+ threadgroup_barrier(mem_flags::mem_threadgroup);
429
+
430
+ // ifft
431
+ p = 1;
432
+ perform_fft(fft_idx, &p, m, n, buf);
433
+
434
+ read_writer.write_padded(length, w_k);
435
+ }
436
+
437
+ template <
438
+ int tg_mem_size,
439
+ typename in_T,
440
+ typename out_T,
441
+ int step,
442
+ bool real = false>
443
+ [[kernel]] void four_step_fft(
444
+ const device in_T* in [[buffer(0)]],
445
+ device out_T* out [[buffer(1)]],
446
+ constant const int& n1,
447
+ constant const int& n2,
448
+ constant const int& batch_size,
449
+ uint3 elem [[thread_position_in_grid]],
450
+ uint3 grid [[threads_per_grid]]) {
451
+ // Fast four step FFT implementation for powers of 2.
452
+ int overall_n = n1 * n2;
453
+ int n = step == 0 ? n1 : n2;
454
+ int stride = step == 0 ? n2 : n1;
455
+
456
+ // The number of the threads we're using for each DFT
457
+ int m = grid.z;
458
+ int fft_idx = elem.z;
459
+
460
+ threadgroup float2 shared_in[tg_mem_size];
461
+ threadgroup float2* buf = &shared_in[elem.y * n];
462
+
463
+ using read_writer_t = ReadWriter<in_T, out_T, step, real>;
464
+ read_writer_t read_writer = read_writer_t(
465
+ in,
466
+ &shared_in[0],
467
+ out,
468
+ n,
469
+ batch_size,
470
+ elems_per_thread_,
471
+ elem,
472
+ grid,
473
+ inv_);
474
+
475
+ if (read_writer.out_of_bounds()) {
476
+ return;
477
+ };
478
+ read_writer.load_strided(stride, overall_n);
479
+
480
+ threadgroup_barrier(mem_flags::mem_threadgroup);
481
+
482
+ int p = 1;
483
+ perform_fft(fft_idx, &p, m, n, buf);
484
+
485
+ read_writer.write_strided(stride, overall_n);
486
+ }
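
Note: the numpy recipes quoted in the kernel comments above can be written out as a small, runnable reference. The sketch below is ours, not code shipped in this wheel: `_primitive_root`, `rader_dft`, and `bluestein_dft` are hypothetical helper names, and the kernels additionally handle batching, padding, and the Stockham sub-FFTs. Both functions agree with `np.fft.fft` on small sizes.

import numpy as np

def _primitive_root(p):
    # Smallest generator of the multiplicative group mod an odd prime p.
    for g in range(2, p):
        if len({pow(g, k, p) for k in range(p - 1)}) == p - 1:
            return g

def rader_dft(x):
    # DFT of odd prime length p via Rader's algorithm:
    # X[g^-q mod p] = x[0] + cyclic_conv(x[g^q mod p], w^(g^-q mod p)).
    x = np.asarray(x, dtype=complex)
    p = len(x)
    g = _primitive_root(p)
    g_q = np.array([pow(g, q, p) for q in range(p - 1)])
    g_minus_q = np.array([pow(g, p - 1 - q, p) for q in range(p - 1)])
    b_q = np.fft.fft(np.exp(-2j * np.pi * g_minus_q / p))  # precomputed constant
    conv = np.fft.ifft(np.fft.fft(x[g_q]) * b_q)
    out = np.empty(p, dtype=complex)
    out[0] = x.sum()
    out[g_minus_q] = x[0] + conv
    return out

def bluestein_dft(x):
    # out = w_k * ifft(fft(w_k * x, N) * w_q), N = next_power_of_2(2n - 1):
    # the chirp w_k turns the DFT into a circular convolution of length N.
    x = np.asarray(x, dtype=complex)
    n = len(x)
    N = 1 << (2 * n - 2).bit_length()
    k = np.arange(n)
    w_k = np.exp(-1j * np.pi * k * k / n)  # chirp at non-negative indices
    a = np.zeros(N, dtype=complex)
    a[:n] = w_k * x
    b = np.zeros(N, dtype=complex)
    b[:n] = w_k.conj()
    b[N - n + 1:] = w_k.conj()[1:][::-1]   # chirp at negative indices, wrapped
    w_q = np.fft.fft(b)                    # precomputed constant
    return w_k * np.fft.ifft(np.fft.fft(a) * w_q)[:n]

# assert np.allclose(rader_dft(np.arange(17.0)), np.fft.fft(np.arange(17.0)))
# assert np.allclose(bluestein_dft(np.arange(12.0)), np.fft.fft(np.arange(12.0)))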
mlx/include/mlx/backend/metal/kernels/fp4.h
@@ -0,0 +1,59 @@
+ #pragma once
+
+ constexpr constant static float FP4_LUT[16] = {
+     +0.0f,
+     +0.5f,
+     +1.0f,
+     +1.5f,
+     +2.0f,
+     +3.0f,
+     +4.0f,
+     +6.0f,
+     -0.0f,
+     -0.5f,
+     -1.0f,
+     -1.5f,
+     -2.0f,
+     -3.0f,
+     -4.0f,
+     -6.0f};
+
+ struct fp4_e2m1 {
+   fp4_e2m1(float x) {
+     if (metal::isnan(x)) {
+       bits = 0x7;
+       return;
+     }
+
+     const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
+     x = metal::abs(x);
+
+     if (x > 5.0f) {
+       bits = 0x7;
+     } else if (x >= 3.5f) {
+       bits = 0x6;
+     } else if (x > 2.5f) {
+       bits = 0x5;
+     } else if (x >= 1.75f) {
+       bits = 0x4;
+     } else if (x > 1.25f) {
+       bits = 0x3;
+     } else if (x >= 0.75f) {
+       bits = 0x2;
+     } else if (x > 0.25f) {
+       bits = 0x1;
+     } else {
+       bits = 0x0;
+     }
+     bits |= sign_bit;
+   }
+
+   operator float() {
+     half converted = as_type<half>(ushort((bits & 7) << 9));
+     converted *= 16384.0;
+     converted = bits & 8 ? -converted : converted;
+     return converted;
+   }
+
+   uint8_t bits;
+ };
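
A note on the rounding above: the alternating `>` / `>=` comparisons implement round-to-nearest with ties going to the code whose mantissa bit is even, saturating at ±6, and the `float` conversion reinterprets the three magnitude bits as a (scaled) half. A hedged host-side Python sketch of the same mapping; the function names are ours, not the wheel's API:

import numpy as np

# The 16 e2m1 values, indexed by the 4-bit code (sign in bit 3); mirrors FP4_LUT.
FP4_E2M1_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def fp4_e2m1_encode(x):
    # Mirrors the constructor's threshold chain: round to nearest, ties to
    # even mantissa, saturate at +/-6; NaN -> 0x7 (e2m1 has no NaN encoding).
    if np.isnan(x):
        return 0x7
    sign = 0x8 if np.signbit(x) else 0x0
    x = abs(x)
    if x > 5.0:
        mag = 0x7  # 6.0
    elif x >= 3.5:
        mag = 0x6  # 4.0
    elif x > 2.5:
        mag = 0x5  # 3.0
    elif x >= 1.75:
        mag = 0x4  # 2.0
    elif x > 1.25:
        mag = 0x3  # 1.5
    elif x >= 0.75:
        mag = 0x2  # 1.0
    elif x > 0.25:
        mag = 0x1  # 0.5
    else:
        mag = 0x0  # 0.0
    return sign | mag

def fp4_e2m1_decode(bits):
    return float(FP4_E2M1_LUT[bits & 0xF])

# Round trip: fp4_e2m1_decode(fp4_e2m1_encode(v)) == v for every v in FP4_E2M1_LUT.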
mlx/include/mlx/backend/metal/kernels/fp8.h
@@ -0,0 +1,82 @@
+ #pragma once
+
+ struct fp8_e4m3 {
+   template <typename T>
+   fp8_e4m3(T f) {
+     // From PyTorch
+     // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
+     uint32_t fp8_max = 543 << 21;
+     uint32_t denorm_mask = 141 << 23;
+     uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
+     uint32_t sign = f_bits & 0x80000000;
+     f_bits ^= sign;
+     if (f_bits >= fp8_max) {
+       // Default behavior saturates to min/max
+       bits = 0x7E;
+     } else {
+       if (f_bits < (121 << 23)) {
+         f_bits = as_type<uint32_t>(
+             as_type<float>(f_bits) + as_type<float>(denorm_mask));
+         bits = static_cast<uint8_t>(f_bits - denorm_mask);
+       } else {
+         // resulting mantissa is odd
+         uint8_t mant_odd = (f_bits >> 20) & 1;
+         f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
+         f_bits += mant_odd;
+         bits = static_cast<uint8_t>(f_bits >> 20);
+       }
+     }
+     bits |= static_cast<uint8_t>(sign >> 24);
+   }
+
+   operator float() {
+     // From PyTorch:
+     // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
+     uint32_t w = static_cast<uint32_t>(bits) << 24;
+     uint32_t sign = w & 0x80000000;
+     uint32_t nonsign = w & 0x7FFFFFFF;
+
+     uint32_t renorm_shift = metal::clz(nonsign);
+     renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
+
+     int32_t inf_nan_mask =
+         (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
+     int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
+     uint32_t result = sign |
+         ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
+           inf_nan_mask) &
+          ~zero_mask);
+     return as_type<float>(result);
+   }
+
+   uint8_t bits;
+ };
+
+ struct fp8_e8m0 {
+   fp8_e8m0(float x) {
+     if (!metal::isfinite(x)) {
+       bits = 0xFF;
+       return;
+     }
+     if (x < 0.0f) {
+       bits = 0x00;
+       return;
+     }
+     float le = metal::log2(x);
+     int n = int(metal::round(le));
+
+     n = n < -127 ? -127 : n;
+     n = n > 127 ? 127 : n;
+     bits = static_cast<uint8_t>(n + 127);
+   }
+
+   operator bfloat16_t() {
+     uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
+     return as_type<bfloat16_t>(out);
+   }
+   operator float() {
+     return static_cast<float>(this->operator bfloat16_t());
+   }
+
+   uint8_t bits;
+ };
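
Field-wise, `fp8_e4m3` above is the "fn" variant (bias 7, no infinities, all-ones mantissa at the top exponent reserved for NaN, saturating on overflow), while `fp8_e8m0` is an exponent-only scale format, which is why its `bfloat16_t` conversion is just a shift into the exponent field, with a subnormal special case for `bits == 0` (i.e. 2^-127). A hedged Python sketch of these semantics; helper names are ours, not the wheel's API:

import math

def fp8_e4m3_decode(bits):
    # Field-wise inverse of the bit-twiddling float conversion above (e4m3fn).
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 3) & 0xF
    mant = bits & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan                           # only NaN, no infinities
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6    # subnormals
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

def fp8_e8m0_encode(x):
    # Mirrors the fp8_e8m0 constructor: nearest integer exponent in [-127, 127].
    if not math.isfinite(x):
        return 0xFF
    if x <= 0.0:
        return 0x00    # clamps to the smallest scale, 2**-127
    n = round(math.log2(x))
    return max(-127, min(127, n)) + 127

def fp8_e8m0_decode(bits):
    if bits == 0xFF:
        return math.inf    # 0xFF << 7 lands on the bfloat16 +inf bit pattern
    return 2.0 ** (bits - 127)   # bits == 0 is the bfloat16 subnormal 2**-127

# E.g. fp8_e4m3_decode(0x7E) == 448.0, the saturation value used above.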