mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/fft/readwrite.h
@@ -0,0 +1,624 @@
+ // Copyright © 2024 Apple Inc.
+
+ #include <metal_common>
+
+ #include "mlx/backend/metal/kernels/fft/radix.h"
+
+ /* FFT helpers for reading and writing from/to device memory.
+
+ For many sizes, GPU FFTs are memory bandwidth bound so
+ read/write performance is important.
+
+ Where possible, we read 128 bits sequentially in each thread,
+ coalesced with accesses from adjacent threads for optimal performance.
+
+ We implement specialized reading/writing for:
+   - FFT
+   - RFFT
+   - IRFFT
+
+ Each with support for:
+   - Contiguous reads
+   - Padded reads
+   - Strided reads
+ */
+
+ #define MAX_RADIX 13
+
+ using namespace metal;
+
+ template <
+     typename in_T,
+     typename out_T,
+     int step = 0,
+     bool four_step_real = false>
+ struct ReadWriter {
+   const device in_T* in;
+   threadgroup float2* buf;
+   device out_T* out;
+   int n;
+   int batch_size;
+   int elems_per_thread;
+   uint3 elem;
+   uint3 grid;
+   int threads_per_tg;
+   bool inv;
+
+   // Used for strided access
+   int strided_device_idx = 0;
+   int strided_shared_idx = 0;
+
+   METAL_FUNC ReadWriter(
+       const device in_T* in_,
+       threadgroup float2* buf_,
+       device out_T* out_,
+       const short n_,
+       const int batch_size_,
+       const short elems_per_thread_,
+       const uint3 elem_,
+       const uint3 grid_,
+       const bool inv_)
+       : in(in_),
+         buf(buf_),
+         out(out_),
+         n(n_),
+         batch_size(batch_size_),
+         elems_per_thread(elems_per_thread_),
+         elem(elem_),
+         grid(grid_),
+         inv(inv_) {
+     // Account for padding on last threadgroup
+     threads_per_tg = elem.x == grid.x - 1
+         ? (batch_size - (grid.x - 1) * grid.y) * grid.z
+         : grid.y * grid.z;
+   }
+
+   // ifft(x) = 1/n * conj(fft(conj(x)))
+   METAL_FUNC float2 post_in(float2 elem) const {
+     return inv ? float2(elem.x, -elem.y) : elem;
+   }
+
+   // Handle float case for generic RFFT alg
+   METAL_FUNC float2 post_in(float elem) const {
+     return float2(elem, 0);
+   }
+
+   METAL_FUNC float2 pre_out(float2 elem) const {
+     return inv ? float2(elem.x / n, -elem.y / n) : elem;
+   }
+
+   METAL_FUNC float2 pre_out(float2 elem, int length) const {
+     return inv ? float2(elem.x / length, -elem.y / length) : elem;
+   }
+
+   METAL_FUNC bool out_of_bounds() const {
+     // Account for possible extra threadgroups
+     int grid_index = elem.x * grid.y + elem.y;
+     return grid_index >= batch_size;
+   }
+
+   METAL_FUNC void load() const {
+     size_t batch_idx = size_t(elem.x * grid.y) * n;
+     short tg_idx = elem.y * grid.z + elem.z;
+     short max_index = grid.y * n - 2;
+
+     // 2 complex64s = 128 bits
+     constexpr int read_width = 2;
+     for (short e = 0; e < (elems_per_thread / read_width); e++) {
+       short index = read_width * tg_idx + read_width * threads_per_tg * e;
+       index = metal::min(index, max_index);
+       // vectorized reads
+       buf[index] = post_in(in[batch_idx + index]);
+       buf[index + 1] = post_in(in[batch_idx + index + 1]);
+     }
+     max_index += 1;
+     if (elems_per_thread % 2 != 0) {
+       short index = tg_idx +
+           read_width * threads_per_tg * (elems_per_thread / read_width);
+       index = metal::min(index, max_index);
+       buf[index] = post_in(in[batch_idx + index]);
+     }
+   }
+
+   METAL_FUNC void write() const {
+     size_t batch_idx = size_t(elem.x * grid.y) * n;
+     short tg_idx = elem.y * grid.z + elem.z;
+     short max_index = grid.y * n - 2;
+
+     constexpr int read_width = 2;
+     for (short e = 0; e < (elems_per_thread / read_width); e++) {
+       short index = read_width * tg_idx + read_width * threads_per_tg * e;
+       index = metal::min(index, max_index);
+       // vectorized writes
+       out[batch_idx + index] = pre_out(buf[index]);
+       out[batch_idx + index + 1] = pre_out(buf[index + 1]);
+     }
+     max_index += 1;
+     if (elems_per_thread % 2 != 0) {
+       short index = tg_idx +
+           read_width * threads_per_tg * (elems_per_thread / read_width);
+       index = metal::min(index, max_index);
+       out[batch_idx + index] = pre_out(buf[index]);
+     }
+   }
+
+   // Padded IO for Bluestein's algorithm
+   METAL_FUNC void load_padded(int length, const device float2* w_k) const {
+     size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
+     int fft_idx = elem.z;
+     int m = grid.z;
+
+     threadgroup float2* seq_buf = buf + elem.y * n;
+     for (int e = 0; e < elems_per_thread; e++) {
+       int index = metal::min(fft_idx + e * m, n - 1);
+       if (index < length) {
+         float2 elem = post_in(in[batch_idx + index]);
+         seq_buf[index] = complex_mul(elem, w_k[index]);
+       } else {
+         seq_buf[index] = 0.0;
+       }
+     }
+   }
+
+   METAL_FUNC void write_padded(int length, const device float2* w_k) const {
+     size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
+     int fft_idx = elem.z;
+     int m = grid.z;
+     float2 inv_factor = {1.0f / n, -1.0f / n};
+
+     threadgroup float2* seq_buf = buf + elem.y * n;
+     for (int e = 0; e < elems_per_thread; e++) {
+       int index = metal::min(fft_idx + e * m, n - 1);
+       if (index < length) {
+         float2 elem = seq_buf[index + length - 1] * inv_factor;
+         out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
+       }
+     }
+   }
+
+   // Strided IO for four step FFT
+   METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
+     // Use the batch threadgroup dimension to coalesce memory accesses:
+     // e.g. stride = 12
+     // device       | shared mem
+     // 0  1  2  3   | 0 12 -  -
+     // -  -  -  -   | 1 13 -  -
+     // -  -  -  -   | 2 14 -  -
+     // 12 13 14 15  | 3 15 -  -
+     int coalesce_width = grid.y;
+     int tg_idx = elem.y * grid.z + elem.z;
+     int outer_batch_size = stride / coalesce_width;
+
+     int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
+         overall_n * (elem.x / outer_batch_size);
+     strided_device_idx = strided_batch_idx +
+         tg_idx / coalesce_width * elems_per_thread * stride +
+         tg_idx % coalesce_width;
+     strided_shared_idx = (tg_idx % coalesce_width) * n +
+         tg_idx / coalesce_width * elems_per_thread;
+   }
+
+   // Four Step FFT First Step
+   METAL_FUNC void load_strided(int stride, int overall_n) {
+     compute_strided_indices(stride, overall_n);
+     for (int e = 0; e < elems_per_thread; e++) {
+       buf[strided_shared_idx + e] =
+           post_in(in[strided_device_idx + e * stride]);
+     }
+   }
+
+   METAL_FUNC void write_strided(int stride, int overall_n) {
+     for (int e = 0; e < elems_per_thread; e++) {
+       float2 output = buf[strided_shared_idx + e];
+       int combined_idx = (strided_device_idx + e * stride) % overall_n;
+       int ij = (combined_idx / stride) * (combined_idx % stride);
+       // Apply four step twiddles at end of first step
+       float2 twiddle = get_twiddle(ij, overall_n);
+       out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
+     }
+   }
+ };
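
The post_in/pre_out pair above realizes the inverse transform through the identity noted in the struct, ifft(x) = 1/n * conj(fft(conj(x))): conjugate on load, run the forward kernel, then conjugate and scale on store. A minimal NumPy sketch of that identity (illustrative only; np.fft.fft stands in for the Metal radix kernels):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal(8) + 1j * rng.standard_normal(8)
    n = len(x)

    # Inverse FFT using only a forward FFT:
    # conjugate the input (post_in), run the forward transform,
    # then conjugate and divide by n (pre_out).
    ifft_via_fft = np.conj(np.fft.fft(np.conj(x))) / n

    assert np.allclose(ifft_via_fft, np.fft.ifft(x))
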
+
+ // Four Step FFT Second Step
+ template <>
+ METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
+     int stride,
+     int overall_n) {
+   // Silence compiler warnings
+   (void)stride;
+   (void)overall_n;
+   // Don't invert between steps
+   bool default_inv = inv;
+   inv = false;
+   load();
+   inv = default_inv;
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
+     int stride,
+     int overall_n) {
+   compute_strided_indices(stride, overall_n);
+   for (int e = 0; e < elems_per_thread; e++) {
+     float2 output = buf[strided_shared_idx + e];
+     out[strided_device_idx + e * stride] = pre_out(output, overall_n);
+   }
+ }
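
The strided load/write pair plus get_twiddle implement the four-step decomposition of a length N = n1 * n2 FFT: length-n1 FFTs over a strided view, a twiddle multiply using the index ij = (combined_idx / stride) * (combined_idx % stride) computed above, then contiguous FFTs in the second step. A NumPy sketch of the same factorization, with illustrative sizes and names that are not taken from the package:

    import numpy as np

    n1, n2 = 8, 16
    N = n1 * n2
    rng = np.random.default_rng(1)
    x = rng.standard_normal(N) + 1j * rng.standard_normal(N)

    a = x.reshape(n1, n2)              # a[j1, j2] = x[j1 * n2 + j2]
    step1 = np.fft.fft(a, axis=0)      # first step: FFTs over the strided axis
    k1 = np.arange(n1)[:, None]
    j2 = np.arange(n2)[None, :]
    step1 = step1 * np.exp(-2j * np.pi * k1 * j2 / N)   # four-step twiddles
    step2 = np.fft.fft(step1, axis=1)  # second step: contiguous FFTs
    four_step = step2.T.reshape(N)     # X[k2 * n1 + k1] = step2[k1, k2]

    assert np.allclose(four_step, np.fft.fft(x))
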
+
+ // For RFFT, we interleave batches of two real sequences into one complex one:
+ //
+ // z_k = x_k + j.y_k
+ // X_k = (Z_k + Z_(N-k)*) / 2
+ // Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
+ //
+ // This roughly doubles the throughput over the regular FFT.
+ template <>
+ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
+   int grid_index = elem.x * grid.y + elem.y;
+   // We pack two sequences into one for RFFTs
+   return grid_index * 2 >= batch_size;
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float, float2>::load() const {
+   size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n;
+
+   // No out of bounds accesses on odd batch sizes
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_in =
+       batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   for (int e = 0; e < elems_per_thread; e++) {
+     int index = metal::min(fft_idx + e * m, n - 1);
+     seq_buf[index].x = in[batch_idx + index];
+     seq_buf[index].y = in[batch_idx + index + next_in];
+   }
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float, float2>::write() const {
+   short n_over_2 = (n / 2) + 1;
+
+   size_t batch_idx =
+       size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n;
+
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_out =
+       batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
+
+   float2 conj = {1, -1};
+   float2 minus_j = {0, -1};
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
+     int index = metal::min(fft_idx + e * m, n_over_2 - 1);
+     // x_0 = z_0.real
+     // y_0 = z_0.imag
+     if (index == 0) {
+       out[batch_idx + index] = {seq_buf[index].x, 0};
+       out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
+     } else {
+       float2 x_k = seq_buf[index];
+       float2 x_n_minus_k = seq_buf[n - index] * conj;
+       out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
+       out[batch_idx + index + next_out] =
+           complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
+     }
+   }
+ }
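
The write() specialization above unpacks the two-for-one trick from the comment: two real sequences are transformed as the real and imaginary parts of a single complex FFT, then separated with X_k = (Z_k + Z_(N-k)*)/2 and Y_k = -j(Z_k - Z_(N-k)*)/2. A NumPy check of that identity (illustrative, not part of the package):

    import numpy as np

    rng = np.random.default_rng(2)
    n = 32
    x = rng.standard_normal(n)
    y = rng.standard_normal(n)

    # One complex FFT of the packed signal z = x + j*y.
    Z = np.fft.fft(x + 1j * y)
    Z_rev = np.conj(np.roll(Z[::-1], 1))      # Z_(N-k)*, with k = 0 mapped to Z_0*

    X = (Z + Z_rev) / 2
    Y = -1j * (Z - Z_rev) / 2

    assert np.allclose(X[: n // 2 + 1], np.fft.rfft(x))
    assert np.allclose(Y[: n // 2 + 1], np.fft.rfft(y))
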
+
+ template <>
+ METAL_FUNC void ReadWriter<float, float2>::load_padded(
+     int length,
+     const device float2* w_k) const {
+   size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n;
+
+   // No out of bounds accesses on odd batch sizes
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_in =
+       batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   for (int e = 0; e < elems_per_thread; e++) {
+     int index = metal::min(fft_idx + e * m, n - 1);
+     if (index < length) {
+       float2 elem =
+           float2(in[batch_idx + index], in[batch_idx + index + next_in]);
+       seq_buf[index] = complex_mul(elem, w_k[index]);
+     } else {
+       seq_buf[index] = 0;
+     }
+   }
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float, float2>::write_padded(
+     int length,
+     const device float2* w_k) const {
+   int length_over_2 = (length / 2) + 1;
+   size_t batch_idx =
+       size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
+
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
+       ? 0
+       : length_over_2;
+
+   float2 conj = {1, -1};
+   float2 inv_factor = {1.0f / n, -1.0f / n};
+   float2 minus_j = {0, -1};
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
+     int index = metal::min(fft_idx + e * m, length_over_2 - 1);
+     // x_0 = z_0.real
+     // y_0 = z_0.imag
+     if (index == 0) {
+       float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
+       out[batch_idx + index] = float2(elem.x, 0);
+       out[batch_idx + index + next_out] = float2(elem.y, 0);
+     } else {
+       float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
+       float2 x_n_minus_k = complex_mul(
+           w_k[length - index], seq_buf[length - index] * inv_factor);
+       x_n_minus_k *= conj;
+       // w_k should happen before this extraction
+       out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
+       out[batch_idx + index + next_out] =
+           complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
+     }
+   }
+ }
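
The padded load/write paths (here and in the base struct) exist for Bluestein's algorithm: the signal is multiplied by a chirp w_k, zero-padded to a convenient FFT size, convolved with the conjugate chirp, and multiplied by w_k again on output. A compact NumPy sketch of that algorithm, using its own variable names (the kernel's w_k buffer and padded size n are computed elsewhere in the package):

    import numpy as np

    def bluestein_fft(x):
        # FFT of arbitrary length n via a chirp and a power-of-two convolution.
        n = len(x)
        k = np.arange(n)
        w = np.exp(-1j * np.pi * k**2 / n)       # chirp, the role w_k plays above
        m = 1 << (2 * n - 1).bit_length()        # padded power-of-two length
        a = np.zeros(m, dtype=complex)
        a[:n] = x * w                            # multiply by the chirp and zero-pad
        b = np.zeros(m, dtype=complex)
        b[:n] = np.conj(w)
        b[m - n + 1:] = np.conj(w[1:])[::-1]     # conjugate chirp at negative lags
        conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
        return w * conv[:n]                      # multiply by the chirp on the way out

    x = np.random.default_rng(3).standard_normal(12) + 0j
    assert np.allclose(bluestein_fft(x), np.fft.fft(x))
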
+
+ // For IRFFT, we do the opposite
+ //
+ // Z_k = X_k + j.Y_k
+ // x_k = Re(Z_k)
+ // y_k = Imag(Z_k)
+ template <>
+ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
+   int grid_index = elem.x * grid.y + elem.y;
+   // We pack two sequences into one for IRFFTs
+   return grid_index * 2 >= batch_size;
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float2, float>::load() const {
+   short n_over_2 = (n / 2) + 1;
+   size_t batch_idx =
+       size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n;
+
+   // No out of bounds accesses on odd batch sizes
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_in =
+       batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   float2 conj = {1, -1};
+   float2 plus_j = {0, 1};
+
+   for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
+     int index = metal::min(fft_idx + t * m, n_over_2 - 1);
+     float2 x = in[batch_idx + index];
+     float2 y = in[batch_idx + index + next_in];
+     // NumPy forces first input to be real
+     bool first_val = index == 0;
+     // NumPy forces last input on even irffts to be real
+     bool last_val = n % 2 == 0 && index == n_over_2 - 1;
+     if (first_val || last_val) {
+       x = float2(x.x, 0);
+       y = float2(y.x, 0);
+     }
+     seq_buf[index] = x + complex_mul(y, plus_j);
+     seq_buf[index].y = -seq_buf[index].y;
+     if (index > 0 && !last_val) {
+       seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
+       seq_buf[n - index].y = -seq_buf[n - index].y;
+     }
+   }
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float2, float>::write() const {
+   int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n;
+
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_out =
+       batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   for (int e = 0; e < elems_per_thread; e++) {
+     int index = metal::min(fft_idx + e * m, n - 1);
+     out[batch_idx + index] = seq_buf[index].x / n;
+     out[batch_idx + index + next_out] = seq_buf[index].y / -n;
+   }
+ }
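
The IRFFT load()/write() specializations invert the packing: two half-spectra are combined as Z_k = X_k + j*Y_k, a single inverse FFT is run, and the two real outputs fall out of the real and imaginary parts. A NumPy check of that identity for even n (illustrative only):

    import numpy as np

    n = 16                               # even length, matching the kernel comments
    rng = np.random.default_rng(4)
    x = rng.standard_normal(n)
    y = rng.standard_normal(n)

    # Half-spectra, n // 2 + 1 bins each, as the IRFFT receives them.
    X = np.fft.rfft(x)
    Y = np.fft.rfft(y)

    # Rebuild the full spectra from conjugate symmetry and pack them.
    Xf = np.concatenate([X, np.conj(X[1:-1])[::-1]])
    Yf = np.concatenate([Y, np.conj(Y[1:-1])[::-1]])
    Z = Xf + 1j * Yf

    z = np.fft.ifft(Z)                   # one inverse FFT recovers both sequences
    assert np.allclose(z.real, x)
    assert np.allclose(z.imag, y)
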
+
+ template <>
+ METAL_FUNC void ReadWriter<float2, float>::load_padded(
+     int length,
+     const device float2* w_k) const {
+   int n_over_2 = (n / 2) + 1;
+   int length_over_2 = (length / 2) + 1;
+
+   size_t batch_idx =
+       size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n;
+
+   // No out of bounds accesses on odd batch sizes
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
+       ? 0
+       : length_over_2;
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   float2 conj = {1, -1};
+   float2 plus_j = {0, 1};
+
+   for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
+     int index = metal::min(fft_idx + t * m, n_over_2 - 1);
+     float2 x = in[batch_idx + index];
+     float2 y = in[batch_idx + index + next_in];
+     if (index < length_over_2) {
+       bool last_val = length % 2 == 0 && index == length_over_2 - 1;
+       if (last_val) {
+         x = float2(x.x, 0);
+         y = float2(y.x, 0);
+       }
+       float2 elem1 = x + complex_mul(y, plus_j);
+       seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
+       if (index > 0 && !last_val) {
+         float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
+         seq_buf[length - index] =
+             complex_mul(elem2 * conj, w_k[length - index]);
+       }
+     } else {
+       short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
+       seq_buf[pad_index] = 0;
+       seq_buf[pad_index + 1] = 0;
+     }
+   }
+ }
+
+ template <>
+ METAL_FUNC void ReadWriter<float2, float>::write_padded(
+     int length,
+     const device float2* w_k) const {
+   size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
+   threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
+
+   int grid_index = elem.x * grid.y + elem.y;
+   short next_out =
+       batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
+
+   short m = grid.z;
+   short fft_idx = elem.z;
+
+   float2 inv_factor = {1.0f / n, -1.0f / n};
+   for (int e = 0; e < elems_per_thread; e++) {
+     int index = fft_idx + e * m;
+     if (index < length) {
+       float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
+       out[batch_idx + index] = output.x / length;
+       out[batch_idx + index + next_out] = output.y / -length;
+     }
+   }
+ }
+
+ // Four Step RFFT
+ template <>
+ METAL_FUNC void
+ ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
+     int stride,
+     int overall_n) {
+   // Silence compiler warnings
+   (void)stride;
+   (void)overall_n;
+   // Don't invert between steps
+   bool default_inv = inv;
+   inv = false;
+   load();
+   inv = default_inv;
+ }
+
+ template <>
+ METAL_FUNC void
+ ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
+     int stride,
+     int overall_n) {
+   int overall_n_over_2 = overall_n / 2 + 1;
+   int coalesce_width = grid.y;
+   int tg_idx = elem.y * grid.z + elem.z;
+   int outer_batch_size = stride / coalesce_width;
+
+   int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
+       overall_n_over_2 * (elem.x / outer_batch_size);
+   strided_device_idx = strided_batch_idx +
+       tg_idx / coalesce_width * elems_per_thread / 2 * stride +
+       tg_idx % coalesce_width;
+   strided_shared_idx = (tg_idx % coalesce_width) * n +
+       tg_idx / coalesce_width * elems_per_thread / 2;
+   for (int e = 0; e < elems_per_thread / 2; e++) {
+     float2 output = buf[strided_shared_idx + e];
+     out[strided_device_idx + e * stride] = output;
+   }
+
+   // Add on n/2 + 1 element
+   if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
+     out[strided_batch_idx + overall_n / 2] = buf[n / 2];
+   }
+ }
+
+ // Four Step IRFFT
+ template <>
+ METAL_FUNC void
+ ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
+     int stride,
+     int overall_n) {
+   int overall_n_over_2 = overall_n / 2 + 1;
+   auto conj = float2(1, -1);
+
+   compute_strided_indices(stride, overall_n);
+   // Translate indices in terms of N - k
+   for (int e = 0; e < elems_per_thread; e++) {
+     int device_idx = strided_device_idx + e * stride;
+     int overall_batch = device_idx / overall_n;
+     int overall_index = device_idx % overall_n;
+     if (overall_index < overall_n_over_2) {
+       device_idx -= overall_batch * (overall_n - overall_n_over_2);
+       buf[strided_shared_idx + e] = in[device_idx] * conj;
+     } else {
+       int conj_idx = overall_n - overall_index;
+       device_idx = overall_batch * overall_n_over_2 + conj_idx;
+       buf[strided_shared_idx + e] = in[device_idx];
+     }
+   }
+ }
+
+ template <>
+ METAL_FUNC void
+ ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
+     int stride,
+     int overall_n) {
+   // Silence compiler warnings
+   (void)stride;
+   (void)overall_n;
+   bool default_inv = inv;
+   inv = false;
+   load();
+   inv = default_inv;
+ }
+
+ template <>
+ METAL_FUNC void
+ ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
+     int stride,
+     int overall_n) {
+   compute_strided_indices(stride, overall_n);
+
+   for (int e = 0; e < elems_per_thread; e++) {
+     out[strided_device_idx + e * stride] =
+         pre_out(buf[strided_shared_idx + e], overall_n).x;
+   }
+ }