mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,328 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ /* Radix kernels
4
+
5
+ We provide optimized, single-threaded radix codelets
6
+ for n=2,3,4,5,6,7,8,10,11,12,13.
7
+
8
+ For n=2,3,4,5,6 we hand-write the codelets.
9
+ For n=8,10,12 we combine smaller codelets.
10
+ For n=7,11,13 we use Rader's algorithm, which decomposes
11
+ them into (n-1)=6,10,12 codelets. */
12
+
13
+ #pragma once
14
+
15
+ #include <metal_common>
16
+ #include <metal_math>
17
+ #include <metal_stdlib>
18
+
19
// Complex product a * b, where each float2 holds (real, imag).
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
  const float re = a.x * b.x - a.y * b.y;
  const float im = a.x * b.y + a.y * b.x;
  return float2(re, im);
}
22
+
23
// Conjugate of the complex product, conj(a * b), where each float2 holds
// (real, imag). The imaginary part is negated relative to complex_mul.
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
  const float re = a.x * b.x - a.y * b.y;
  const float neg_im = -a.x * b.y - a.y * b.x;
  return float2(re, neg_im);
}
27
+
28
// FFT twiddle factor exp(-2*pi*i*k / p), returned as (cos(angle), sin(angle))
// with angle = -2*pi*k/p. Uses the metal::fast trigonometric variants.
METAL_FUNC float2 get_twiddle(int k, int p) {
  const float angle = -2.0f * k * M_PI_F / p;
  return float2(metal::fast::cos(angle), metal::fast::sin(angle));
}
35
+
36
// Radix-2 DFT butterfly: y[0] = x[0] + x[1], y[1] = x[0] - x[1].
// y[0] is written before x[0] and x[1] are read again for y[1]; keep this
// statement order, since results would differ if y aliases x.
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
  y[0] = x[0] + x[1];
  y[1] = x[0] - x[1];
}
40
+
41
// Radix-3 DFT codelet (hand-written). Reads x[0..2], writes y[0..2], using
// the forward convention exp(-2*pi*i*k*n/3).
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
  // -sin(pi/3), the imaginary part of exp(-2*pi*i/3).
  const float neg_sin_pi_3 = -0.8660254037844387;

  const float2 pair_sum = x[1] + x[2];
  const float2 pair_diff = x[1] - x[2];

  // DC term.
  y[0] = x[0] + pair_sum;

  // Real-axis part shared by y[1] and y[2]: x[0] + cos(2*pi/3) * pair_sum.
  // x[0] is re-read after y[0] is written; keep this order in case y
  // aliases x.
  const float2 center = x[0] - 0.5 * pair_sum;

  // i * (neg_sin_pi_3 * pair_diff); multiplying (re, im) by i gives (-im, re).
  const float2 scaled = neg_sin_pi_3 * pair_diff;
  const float2 rotated = float2(-scaled.y, scaled.x);

  y[1] = center + rotated;
  y[2] = center - rotated;
}
55
+
56
// Radix-4 DFT codelet built from two stages of radix-2 butterflies.
// Reads x[0..3] (all reads happen before any write), writes y[0..3].
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
  // First stage: butterflies on the (0, 2) and (1, 3) input pairs.
  const float2 even_sum = x[0] + x[2];
  const float2 even_diff = x[0] - x[2];
  const float2 odd_sum = x[1] + x[3];
  const float2 odd_diff = x[1] - x[3];

  // Multiply odd_diff by -i: (re, im) -> (im, -re).
  const float2 odd_diff_rot = float2(odd_diff.y, -odd_diff.x);

  // Second stage, grouped by output pairs (0, 2) and (1, 3).
  y[0] = even_sum + odd_sum;
  y[2] = even_sum - odd_sum;
  y[1] = even_diff + odd_diff_rot;
  y[3] = even_diff - odd_diff_rot;
}
68
+
69
// Radix-5 DFT codelet (hand-written). Reads x[0..4] (all reads of x[1..4]
// happen before any write to y), writes y[0..4], using the forward
// convention exp(-2*pi*i*k*n/5).
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
  // Scalar constants splatted to both lanes of a float2.
  float2 root_5_4 = 0.5590169943749475;  // sqrt(5) / 4
  float2 sin_2pi_5 = 0.9510565162951535; // sin(2*pi/5)
  float2 sin_1pi_5 = 0.5877852522924731; // sin(pi/5)

  // Symmetric sums (a_1, a_2) and antisymmetric differences (a_3, a_4) of the
  // input pairs (1, 4) and (2, 3).
  float2 a_1 = x[1] + x[4];
  float2 a_2 = x[2] + x[3];
  float2 a_3 = x[1] - x[4];
  float2 a_4 = x[2] - x[3];

  float2 a_5 = a_1 + a_2;
  float2 a_6 = root_5_4 * (a_1 - a_2);
  float2 a_7 = x[0] - a_5 / 4;
  // a_8 = x[0] + cos(2*pi/5) * a_1 + cos(4*pi/5) * a_2  (real-axis part of y[1], y[4])
  // a_9 = x[0] + cos(4*pi/5) * a_1 + cos(2*pi/5) * a_2  (real-axis part of y[2], y[3])
  float2 a_8 = a_7 + a_6;
  float2 a_9 = a_7 - a_6;
  float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
  float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
  // Multiply by -i: (re, im) -> (im, -re).
  float2 a_10_j = {a_10.y, -a_10.x};
  float2 a_11_j = {a_11.y, -a_11.x};

  y[0] = x[0] + a_5;
  y[1] = a_8 + a_10_j;
  y[2] = a_9 + a_11_j;
  y[3] = a_9 - a_11_j;
  y[4] = a_8 - a_10_j;
}
95
+
96
// Radix-6 DFT codelet (hand-written). Computes two radix-3 butterflies, over
// (x[0], x[2], x[4]) and (x[3], x[5], x[1]), then combines them with sums and
// differences in the output order below. All reads of x happen before any
// write to y.
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
  float sin_pi_3 = 0.8660254037844387;
  // Radix-3 over (x[0], x[2], x[4]).
  float2 a_1 = x[2] + x[4];
  float2 a_2 = x[0] - a_1 / 2;
  float2 a_3 = sin_pi_3 * (x[2] - x[4]);
  // Radix-3 over (x[3], x[5], x[1]).
  float2 a_4 = x[5] + x[1];
  float2 a_5 = x[3] - a_4 / 2;
  float2 a_6 = sin_pi_3 * (x[5] - x[1]);
  float2 a_7 = x[0] + a_1;

  // Multiply by -i: (re, im) -> (im, -re).
  float2 a_3_i = {a_3.y, -a_3.x};
  float2 a_6_i = {a_6.y, -a_6.x};
  // a_7, a_8, a_9: the three outputs of the first radix-3.
  float2 a_8 = a_2 + a_3_i;
  float2 a_9 = a_2 - a_3_i;
  // a_10, a_11, a_12: the three outputs of the second radix-3.
  float2 a_10 = x[3] + a_4;
  float2 a_11 = a_5 + a_6_i;
  float2 a_12 = a_5 - a_6_i;

  // Final radix-2 combination of the two radix-3 results.
  y[0] = a_7 + a_10;
  y[1] = a_8 - a_11;
  y[2] = a_9 + a_12;
  y[3] = a_7 - a_10;
  y[4] = a_8 + a_11;
  y[5] = a_9 - a_12;
}
121
+
122
// Radix-7 DFT via Rader's algorithm: the 7-point DFT of x[0..6] is reduced
// to a length-6 cyclic convolution, evaluated with two radix-6 codelets.
// Writes y[0..6]. NOTE: x[1..6] is overwritten (used as scratch by the second
// radix6 call); x[0] is only read.
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
  // Rader's algorithm
  // Elementwise multiply by (1/6, -1/6) = conjugate and divide by 6, which
  // completes the inverse FFT done via the conjugation trick below.
  float2 inv = {1 / 6.0, -1 / 6.0};

  // fft
  // Inputs gathered in generator order: x[3^q mod 7] for q = 0..5.
  float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
  radix6(in1, y + 1);

  // y[1] is now the sum of x[1..6], so this is output bin 0.
  y[0] = y[1] + x[0];

  // b_q
  // Multiply by the precomputed Rader kernel coefficients and conjugate. The
  // coefficients are presumably the 6-point DFT of the permuted twiddles
  // (all except the first have magnitude ~sqrt(7)); they are not derived here.
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
  y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
  y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
  y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
  y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));

  // ifft
  // Forward FFT of conjugated data; conjugating and scaling by 1/6 (inv)
  // afterwards yields the inverse FFT.
  radix6(y + 1, x + 1);

  // Scatter to inverse-generator order: output index 5^q mod 7, plus the
  // x[0] term that every non-DC bin includes.
  y[1] = x[1] * inv + x[0];
  y[5] = x[2] * inv + x[0];
  y[4] = x[3] * inv + x[0];
  y[6] = x[4] * inv + x[0];
  y[2] = x[5] * inv + x[0];
  y[3] = x[6] * inv + x[0];
}
150
+
151
// Radix-8 DFT codelet: two radix-4 codelets over the even- and odd-indexed
// inputs, combined by a radix-2 stage with twiddles exp(-2*pi*i*t/8).
// NOTE: x is used as scratch; radix4 writes its results into x[0..7].
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
  const float inv_sqrt2 = 0.7071067811865476;
  const float2 tw1 = {inv_sqrt2, -inv_sqrt2};  // exp(-i*pi/4)
  const float2 tw3 = {-inv_sqrt2, -inv_sqrt2}; // exp(-3i*pi/4)

  // Decimation in time: even-indexed inputs first, then odd-indexed.
  float2 split[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
  radix4(split, x);
  radix4(split + 4, x + 4);

  // Combine stage. Each y[t] is written before x[t] is re-read for y[t + 4];
  // do not reorder, since results would differ if y aliases x.
  y[0] = x[0] + x[4];
  y[4] = x[0] - x[4];

  const float2 odd1 = complex_mul(x[5], tw1);
  y[1] = x[1] + odd1;
  y[5] = x[1] - odd1;

  // Twiddle exp(-i*pi/2) = -i: (re, im) -> (im, -re).
  const float2 odd2 = float2(x[6].y, -x[6].x);
  y[2] = x[2] + odd2;
  y[6] = x[2] - odd2;

  const float2 odd3 = complex_mul(x[7], tw3);
  y[3] = x[3] + odd3;
  y[7] = x[3] - odd3;
}
171
+
172
// Radix-10 DFT codelet: two radix-5 codelets over the two halves of a
// permuted input, combined by a radix-2 stage with twiddles
// w[t - 1] = exp(-i*pi*t/5), t = 1..4.
//
// raders_perm == false: plain even/odd split (x[0], x[2], ..., x[8], then
//   x[1], x[3], ..., x[9]).
// raders_perm == true: the input order radix11 needs for Rader's algorithm,
//   i.e. x[(2^q mod 11) - 1] for q = 0..9, pre-split into even q then odd q.
//
// NOTE: x is used as scratch; radix5 writes its results into x[0..9].
template <bool raders_perm>
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
  float2 w[4];
  w[0] = {0.8090169943749475, -0.5877852522924731}; // exp(-i*pi/5)
  w[1] = {0.30901699437494745, -0.9510565162951535}; // exp(-2i*pi/5)
  w[2] = {-w[1].x, w[1].y}; // exp(-3i*pi/5)
  w[3] = {-w[0].x, w[0].y}; // exp(-4i*pi/5)

  if (raders_perm) {
    float2 temp[10] = {
        x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
    radix5(temp, x);
    radix5(temp + 5, x + 5);
  } else {
    float2 temp[10] = {
        x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
    radix5(temp, x);
    radix5(temp + 5, x + 5);
  }

  // Radix-2 combine. Each y[t] is written before x[t] is re-read for
  // y[t + 5]; statement order matters if y aliases x.
  y[0] = x[0] + x[5];
  y[5] = x[0] - x[5];
  for (int t = 1; t < 5; t++) {
    float2 a = complex_mul(x[t + 5], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 5] = x[t] - a;
  }
}
200
+
201
// Radix-11 DFT via Rader's algorithm: the 11-point DFT of x[0..10] is
// reduced to a length-10 cyclic convolution, evaluated with two radix-10
// codelets. Writes y[0..10]. NOTE: x[1..10] is overwritten (used as scratch
// by both radix10 calls); x[0] is only read.
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
  // Rader's algorithm
  // Elementwise multiply by (1/10, -1/10) = conjugate and divide by 10, which
  // completes the inverse FFT done via the conjugation trick below.
  float2 inv = {1 / 10.0, -1 / 10.0};

  // fft
  // radix10<true> gathers x[1..10] in generator order (powers of 2 mod 11).
  radix10<true>(x + 1, y + 1);

  // y[1] is now the sum of x[1..10], so this is output bin 0.
  y[0] = y[1] + x[0];

  // b_q
  // Multiply by the precomputed Rader kernel coefficients and conjugate. The
  // coefficients are presumably the 10-point DFT of the permuted twiddles
  // (all except the first have magnitude ~sqrt(11)); they are not derived here.
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
  y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
  y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
  y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
  y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
  y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
  y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
  y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
  y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));

  // ifft
  // Forward FFT of conjugated data; conjugating and scaling by 1/10 (inv)
  // afterwards yields the inverse FFT.
  radix10<false>(y + 1, x + 1);

  // Scatter to inverse-generator order: output index 6^q mod 11, plus the
  // x[0] term that every non-DC bin includes.
  y[1] = x[1] * inv + x[0];
  y[6] = x[2] * inv + x[0];
  y[3] = x[3] * inv + x[0];
  y[7] = x[4] * inv + x[0];
  y[9] = x[5] * inv + x[0];
  y[10] = x[6] * inv + x[0];
  y[5] = x[7] * inv + x[0];
  y[8] = x[8] * inv + x[0];
  y[4] = x[9] * inv + x[0];
  y[2] = x[10] * inv + x[0];
}
236
+
237
// Radix-12 DFT codelet: two radix-6 codelets over the two halves of a
// permuted input, combined by a radix-2 stage with twiddles
// w[t - 1] = exp(-i*pi*t/6), t = 1..5.
//
// raders_perm == false: plain even/odd split (x[0], x[2], ..., x[10], then
//   x[1], x[3], ..., x[11]).
// raders_perm == true: the input order radix13 needs for Rader's algorithm,
//   i.e. x[(2^q mod 13) - 1] for q = 0..11, pre-split into even q then odd q.
//
// NOTE: x is used as scratch; radix6 writes its results into x[0..11].
template <bool raders_perm>
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
  // Exactly the five twiddles used by the combine loop (t = 1..5).
  float2 w[5];
  float sin_pi_3 = 0.8660254037844387;
  w[0] = {sin_pi_3, -0.5}; // exp(-i*pi/6)
  w[1] = {0.5, -sin_pi_3}; // exp(-2i*pi/6)
  w[2] = {0, -1}; // exp(-3i*pi/6)
  w[3] = {-0.5, -sin_pi_3}; // exp(-4i*pi/6)
  w[4] = {-sin_pi_3, -0.5}; // exp(-5i*pi/6)

  if (raders_perm) {
    float2 temp[12] = {
        x[0],
        x[3],
        x[2],
        x[11],
        x[8],
        x[9],
        x[1],
        x[7],
        x[5],
        x[10],
        x[4],
        x[6]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  } else {
    float2 temp[12] = {
        x[0],
        x[2],
        x[4],
        x[6],
        x[8],
        x[10],
        x[1],
        x[3],
        x[5],
        x[7],
        x[9],
        x[11]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  }

  // Radix-2 combine. Each y[t] is written before x[t] is re-read for
  // y[t + 6]; statement order matters if y aliases x.
  y[0] = x[0] + x[6];
  y[6] = x[0] - x[6];
  for (int t = 1; t < 6; t++) {
    float2 a = complex_mul(x[t + 6], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 6] = x[t] - a;
  }
}
289
+
290
// Radix-13 DFT via Rader's algorithm: the 13-point DFT of x[0..12] is
// reduced to a length-12 cyclic convolution, evaluated with two radix-12
// codelets. Writes y[0..12]. NOTE: x[1..12] is overwritten (used as scratch
// by both radix12 calls); x[0] is only read.
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
  // Rader's algorithm
  // Elementwise multiply by (1/12, -1/12) = conjugate and divide by 12, which
  // completes the inverse FFT done via the conjugation trick below.
  float2 inv = {1 / 12.0, -1 / 12.0};

  // fft
  // radix12<true> gathers x[1..12] in generator order (powers of 2 mod 13).
  radix12<true>(x + 1, y + 1);

  // y[1] is now the sum of x[1..12], so this is output bin 0.
  y[0] = y[1] + x[0];

  // b_q
  // Multiply by the precomputed Rader kernel coefficients and conjugate. The
  // coefficients are presumably the 12-point DFT of the permuted twiddles
  // (all except the first have magnitude ~sqrt(13)); they are not derived here.
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
  y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
  y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
  y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
  y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
  y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
  y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
  y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
  y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
  y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
  y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));

  // ifft
  // Forward FFT of conjugated data; conjugating and scaling by 1/12 (inv)
  // afterwards yields the inverse FFT.
  radix12<false>(y + 1, x + 1);

  // Scatter to inverse-generator order: output index 7^q mod 13, plus the
  // x[0] term that every non-DC bin includes.
  y[1] = x[1] * inv + x[0];
  y[7] = x[2] * inv + x[0];
  y[10] = x[3] * inv + x[0];
  y[5] = x[4] * inv + x[0];
  y[9] = x[5] * inv + x[0];
  y[11] = x[6] * inv + x[0];
  y[12] = x[7] * inv + x[0];
  y[6] = x[8] * inv + x[0];
  y[3] = x[9] * inv + x[0];
  y[8] = x[10] * inv + x[0];
  y[4] = x[11] * inv + x[0];
  y[2] = x[12] * inv + x[0];
}