mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/fft/readwrite.h
@@ -0,0 +1,624 @@
// Copyright © 2024 Apple Inc.

#include <metal_common>

#include "mlx/backend/metal/kernels/fft/radix.h"

/* FFT helpers for reading and writing from/to device memory.

For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.

Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adjacent threads for optimal performance.

We implement specialized reading/writing for:
  - FFT
  - RFFT
  - IRFFT

Each with support for:
  - Contiguous reads
  - Padded reads
  - Strided reads
*/
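// A concrete reading of the claim above: `in` holds complex64 values as
// float2 (2 x 32-bit float), so the read_width = 2 loops in load()/write()
// below move two adjacent complex values (128 bits) per thread per iteration,
// and thread t starts at element 2 * t so that adjacent threads cover a
// contiguous run of device memory.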
#define MAX_RADIX 13

using namespace metal;

template <
    typename in_T,
    typename out_T,
    int step = 0,
    bool four_step_real = false>
struct ReadWriter {
  const device in_T* in;
  threadgroup float2* buf;
  device out_T* out;
  int n;
  int batch_size;
  int elems_per_thread;
  uint3 elem;
  uint3 grid;
  int threads_per_tg;
  bool inv;

  // Used for strided access
  int strided_device_idx = 0;
  int strided_shared_idx = 0;

  METAL_FUNC ReadWriter(
      const device in_T* in_,
      threadgroup float2* buf_,
      device out_T* out_,
      const short n_,
      const int batch_size_,
      const short elems_per_thread_,
      const uint3 elem_,
      const uint3 grid_,
      const bool inv_)
      : in(in_),
        buf(buf_),
        out(out_),
        n(n_),
        batch_size(batch_size_),
        elems_per_thread(elems_per_thread_),
        elem(elem_),
        grid(grid_),
        inv(inv_) {
    // Account for padding on last threadgroup
    threads_per_tg = elem.x == grid.x - 1
        ? (batch_size - (grid.x - 1) * grid.y) * grid.z
        : grid.y * grid.z;
  }

  // ifft(x) = 1/n * conj(fft(conj(x)))
  METAL_FUNC float2 post_in(float2 elem) const {
    return inv ? float2(elem.x, -elem.y) : elem;
  }
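  // Why that identity holds:
  //   ifft(x)_k = (1/n) * sum_j x_j * e^(+2*pi*i*j*k/n)
  //             = (1/n) * conj(sum_j conj(x_j) * e^(-2*pi*i*j*k/n))
  //             = (1/n) * conj(fft(conj(x))_k),
  // so post_in() conjugates on the way in, and pre_out() below applies the
  // remaining conjugate and the 1/n scale on the way out.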
  // Handle float case for generic RFFT alg
  METAL_FUNC float2 post_in(float elem) const {
    return float2(elem, 0);
  }

  METAL_FUNC float2 pre_out(float2 elem) const {
    return inv ? float2(elem.x / n, -elem.y / n) : elem;
  }

  METAL_FUNC float2 pre_out(float2 elem, int length) const {
    return inv ? float2(elem.x / length, -elem.y / length) : elem;
  }

  METAL_FUNC bool out_of_bounds() const {
    // Account for possible extra threadgroups
    int grid_index = elem.x * grid.y + elem.y;
    return grid_index >= batch_size;
  }

  METAL_FUNC void load() const {
    size_t batch_idx = size_t(elem.x * grid.y) * n;
    short tg_idx = elem.y * grid.z + elem.z;
    short max_index = grid.y * n - 2;

    // 2 complex64s = 128 bits
    constexpr int read_width = 2;
    for (short e = 0; e < (elems_per_thread / read_width); e++) {
      short index = read_width * tg_idx + read_width * threads_per_tg * e;
      index = metal::min(index, max_index);
      // vectorized reads
      buf[index] = post_in(in[batch_idx + index]);
      buf[index + 1] = post_in(in[batch_idx + index + 1]);
    }
    max_index += 1;
    if (elems_per_thread % 2 != 0) {
      short index = tg_idx +
          read_width * threads_per_tg * (elems_per_thread / read_width);
      index = metal::min(index, max_index);
      buf[index] = post_in(in[batch_idx + index]);
    }
  }
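  // Note on the metal::min clamps above: threads whose computed index would
  // run past the end of the sequence are redirected to the last valid slot
  // and redundantly process the same element. Since every clamped thread
  // stores the same value, this appears to trade a little duplicate work for
  // branch-free indexing.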
  METAL_FUNC void write() const {
    size_t batch_idx = size_t(elem.x * grid.y) * n;
    short tg_idx = elem.y * grid.z + elem.z;
    short max_index = grid.y * n - 2;

    constexpr int read_width = 2;
    for (short e = 0; e < (elems_per_thread / read_width); e++) {
      short index = read_width * tg_idx + read_width * threads_per_tg * e;
      index = metal::min(index, max_index);
      // vectorized writes
      out[batch_idx + index] = pre_out(buf[index]);
      out[batch_idx + index + 1] = pre_out(buf[index + 1]);
    }
    max_index += 1;
    if (elems_per_thread % 2 != 0) {
      short index = tg_idx +
          read_width * threads_per_tg * (elems_per_thread / read_width);
      index = metal::min(index, max_index);
      out[batch_idx + index] = pre_out(buf[index]);
    }
  }

  // Padded IO for Bluestein's algorithm
  METAL_FUNC void load_padded(int length, const device float2* w_k) const {
    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
    int fft_idx = elem.z;
    int m = grid.z;

    threadgroup float2* seq_buf = buf + elem.y * n;
    for (int e = 0; e < elems_per_thread; e++) {
      int index = metal::min(fft_idx + e * m, n - 1);
      if (index < length) {
        float2 elem = post_in(in[batch_idx + index]);
        seq_buf[index] = complex_mul(elem, w_k[index]);
      } else {
        seq_buf[index] = 0.0;
      }
    }
  }

  METAL_FUNC void write_padded(int length, const device float2* w_k) const {
    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
    int fft_idx = elem.z;
    int m = grid.z;
    float2 inv_factor = {1.0f / n, -1.0f / n};

    threadgroup float2* seq_buf = buf + elem.y * n;
    for (int e = 0; e < elems_per_thread; e++) {
      int index = metal::min(fft_idx + e * m, n - 1);
      if (index < length) {
        float2 elem = seq_buf[index + length - 1] * inv_factor;
        out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
      }
    }
  }
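  // Background for the padded IO above, following the standard chirp-z
  // construction (the header does not spell this out): Bluestein's algorithm
  // computes an arbitrary-length-`length` DFT as a convolution. Inputs are
  // pre-multiplied by a chirp, outputs are post-multiplied by it again, and
  // pre_out()/inv_factor supply the conjugation and scaling for the inverse
  // transform. The w_k buffer passed in is assumed to hold the precomputed
  // chirp, w_k = exp(-i * pi * k^2 / length), zero-padded to the FFT size n.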
  // Strided IO for four step FFT
  METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
    // Use the batch threadgroup dimension to coalesce memory accesses:
    // e.g. stride = 12
    // device      | shared mem
    //  0  1  2  3 |  0 12  -  -
    //  -  -  -  - |  1 13  -  -
    //  -  -  -  - |  2 14  -  -
    // 12 13 14 15 |  3 15  -  -
    int coalesce_width = grid.y;
    int tg_idx = elem.y * grid.z + elem.z;
    int outer_batch_size = stride / coalesce_width;

    int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
        overall_n * (elem.x / outer_batch_size);
    strided_device_idx = strided_batch_idx +
        tg_idx / coalesce_width * elems_per_thread * stride +
        tg_idx % coalesce_width;
    strided_shared_idx = (tg_idx % coalesce_width) * n +
        tg_idx / coalesce_width * elems_per_thread;
  }
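  // Worked example of the mapping above (illustrative numbers): with
  // stride = 12, coalesce_width = 4, elems_per_thread = 4 and elem.x = 0,
  // threads tg_idx = 0..3 get strided_device_idx 0..3 and strided_shared_idx
  // 0, n, 2n, 3n; tg_idx = 4 gets strided_device_idx 48 and shared offset 4.
  // In load_strided() below, each round of the loop then has threads 0..3
  // touching device elements {0, 1, 2, 3}, then {12, 13, 14, 15}, and so on,
  // so the accesses coalesce exactly as in the diagram.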
  // Four Step FFT First Step
  METAL_FUNC void load_strided(int stride, int overall_n) {
    compute_strided_indices(stride, overall_n);
    for (int e = 0; e < elems_per_thread; e++) {
      buf[strided_shared_idx + e] =
          post_in(in[strided_device_idx + e * stride]);
    }
  }

  METAL_FUNC void write_strided(int stride, int overall_n) {
    for (int e = 0; e < elems_per_thread; e++) {
      float2 output = buf[strided_shared_idx + e];
      int combined_idx = (strided_device_idx + e * stride) % overall_n;
      int ij = (combined_idx / stride) * (combined_idx % stride);
      // Apply four step twiddles at end of first step
      float2 twiddle = get_twiddle(ij, overall_n);
      out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
    }
  }
};
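// Background for the four step specializations below: a length-N FFT with
// N = n1 * n2 can be split into strided sub-FFTs over one factor, a pointwise
// multiply by the twiddles W_N^(i*j) (the get_twiddle(ij, overall_n) call in
// write_strided above), and a second round of contiguous sub-FFTs over the
// other factor. The `step` template parameter selects which half a given
// kernel instance performs.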
// Four Step FFT Second Step
template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  // Don't invert between steps
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
    int stride,
    int overall_n) {
  compute_strided_indices(stride, overall_n);
  for (int e = 0; e < elems_per_thread; e++) {
    float2 output = buf[strided_shared_idx + e];
    out[strided_device_idx + e * stride] = pre_out(output, overall_n);
  }
}

// For RFFT, we interleave batches of two real sequences into one complex one:
//
// z_k = x_k + j.y_k
// X_k = (Z_k + Z_(N-k)*) / 2
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
//
// This roughly doubles the throughput over the regular FFT.
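// Why those formulas hold: by linearity Z_k = X_k + j.Y_k, and since x and y
// are real their spectra are Hermitian, i.e. X_(N-k) = X_k*. Hence
// Z_(N-k)* = X_k - j.Y_k, so the sum and difference with Z_k isolate X_k and
// Y_k exactly as written above.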
template <>
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
  int grid_index = elem.x * grid.y + elem.y;
  // We pack two sequences into one for RFFTs
  return grid_index * 2 >= batch_size;
}

template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
  size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    seq_buf[index].x = in[batch_idx + index];
    seq_buf[index].y = in[batch_idx + index + next_in];
  }
}

template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
  short n_over_2 = (n / 2) + 1;

  size_t batch_idx =
      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;

  float2 conj = {1, -1};
  float2 minus_j = {0, -1};

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
    int index = metal::min(fft_idx + e * m, n_over_2 - 1);
    // x_0 = z_0.real
    // y_0 = z_0.imag
    if (index == 0) {
      out[batch_idx + index] = {seq_buf[index].x, 0};
      out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
    } else {
      float2 x_k = seq_buf[index];
      float2 x_n_minus_k = seq_buf[n - index] * conj;
      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
      out[batch_idx + index + next_out] =
          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
    int length,
    const device float2* w_k) const {
  size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    if (index < length) {
      float2 elem =
          float2(in[batch_idx + index], in[batch_idx + index + next_in]);
      seq_buf[index] = complex_mul(elem, w_k[index]);
    } else {
      seq_buf[index] = 0;
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float, float2>::write_padded(
    int length,
    const device float2* w_k) const {
  int length_over_2 = (length / 2) + 1;
  size_t batch_idx =
      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
      ? 0
      : length_over_2;

  float2 conj = {1, -1};
  float2 inv_factor = {1.0f / n, -1.0f / n};
  float2 minus_j = {0, -1};

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
    int index = metal::min(fft_idx + e * m, length_over_2 - 1);
    // x_0 = z_0.real
    // y_0 = z_0.imag
    if (index == 0) {
      float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
      out[batch_idx + index] = float2(elem.x, 0);
      out[batch_idx + index + next_out] = float2(elem.y, 0);
    } else {
      float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
      float2 x_n_minus_k = complex_mul(
          w_k[length - index], seq_buf[length - index] * inv_factor);
      x_n_minus_k *= conj;
      // The multiplication by w_k must happen before this extraction
      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
      out[batch_idx + index + next_out] =
          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
    }
  }
}

// For IRFFT, we do the opposite:
//
// Z_k = X_k + j.Y_k
// x_k = Re(Z_k)
// y_k = Imag(Z_k)
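// Note: only n/2 + 1 spectrum entries per sequence are stored, so load()
// below reconstructs the upper half of Z through the same Hermitian symmetry
// (the entries at n - k come from the conjugates of X_k and Y_k) before
// applying the usual conjugate-and-scale trick for the inverse transform.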
template <>
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
  int grid_index = elem.x * grid.y + elem.y;
  // We pack two sequences into one for IRFFTs
  return grid_index * 2 >= batch_size;
}

template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
  short n_over_2 = (n / 2) + 1;
  size_t batch_idx =
      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 conj = {1, -1};
  float2 plus_j = {0, 1};

  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
    int index = metal::min(fft_idx + t * m, n_over_2 - 1);
    float2 x = in[batch_idx + index];
    float2 y = in[batch_idx + index + next_in];
    // NumPy forces first input to be real
    bool first_val = index == 0;
    // NumPy forces last input on even irffts to be real
    bool last_val = n % 2 == 0 && index == n_over_2 - 1;
    if (first_val || last_val) {
      x = float2(x.x, 0);
      y = float2(y.x, 0);
    }
    seq_buf[index] = x + complex_mul(y, plus_j);
    seq_buf[index].y = -seq_buf[index].y;
    if (index > 0 && !last_val) {
      seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
      seq_buf[n - index].y = -seq_buf[n - index].y;
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float2, float>::write() const {
  size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    out[batch_idx + index] = seq_buf[index].x / n;
    out[batch_idx + index + next_out] = seq_buf[index].y / -n;
  }
}

template <>
METAL_FUNC void ReadWriter<float2, float>::load_padded(
    int length,
    const device float2* w_k) const {
  int n_over_2 = (n / 2) + 1;
  int length_over_2 = (length / 2) + 1;

  size_t batch_idx =
      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
      ? 0
      : length_over_2;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 conj = {1, -1};
  float2 plus_j = {0, 1};

  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
    int index = metal::min(fft_idx + t * m, n_over_2 - 1);
    float2 x = in[batch_idx + index];
    float2 y = in[batch_idx + index + next_in];
    if (index < length_over_2) {
      bool last_val = length % 2 == 0 && index == length_over_2 - 1;
      if (last_val) {
        x = float2(x.x, 0);
        y = float2(y.x, 0);
      }
      float2 elem1 = x + complex_mul(y, plus_j);
      seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
      if (index > 0 && !last_val) {
        float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
        seq_buf[length - index] =
            complex_mul(elem2 * conj, w_k[length - index]);
      }
    } else {
      short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
      seq_buf[pad_index] = 0;
      seq_buf[pad_index + 1] = 0;
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
    int length,
    const device float2* w_k) const {
  size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 inv_factor = {1.0f / n, -1.0f / n};
  for (int e = 0; e < elems_per_thread; e++) {
    int index = fft_idx + e * m;
    if (index < length) {
      float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
      out[batch_idx + index] = output.x / length;
      out[batch_idx + index + next_out] = output.y / -length;
    }
  }
}

// Four Step RFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  // Don't invert between steps
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
    int stride,
    int overall_n) {
  int overall_n_over_2 = overall_n / 2 + 1;
  int coalesce_width = grid.y;
  int tg_idx = elem.y * grid.z + elem.z;
  int outer_batch_size = stride / coalesce_width;

  int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
      overall_n_over_2 * (elem.x / outer_batch_size);
  strided_device_idx = strided_batch_idx +
      tg_idx / coalesce_width * elems_per_thread / 2 * stride +
      tg_idx % coalesce_width;
  strided_shared_idx = (tg_idx % coalesce_width) * n +
      tg_idx / coalesce_width * elems_per_thread / 2;
  for (int e = 0; e < elems_per_thread / 2; e++) {
    float2 output = buf[strided_shared_idx + e];
    out[strided_device_idx + e * stride] = output;
  }

  // Add on n/2 + 1 element
  if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
    out[strided_batch_idx + overall_n / 2] = buf[n / 2];
  }
}
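// Note on the tail write above: each thread emits elems_per_thread / 2
// outputs, covering half-spectrum indices 0 .. overall_n/2 - 1, while the
// RFFT of a real length-overall_n sequence has overall_n/2 + 1 outputs. One
// thread per batch therefore appends the final (Nyquist) bin separately.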
// Four Step IRFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
    int stride,
    int overall_n) {
  int overall_n_over_2 = overall_n / 2 + 1;
  auto conj = float2(1, -1);

  compute_strided_indices(stride, overall_n);
  // Translate indices in terms of N - k
  for (int e = 0; e < elems_per_thread; e++) {
    int device_idx = strided_device_idx + e * stride;
    int overall_batch = device_idx / overall_n;
    int overall_index = device_idx % overall_n;
    if (overall_index < overall_n_over_2) {
      device_idx -= overall_batch * (overall_n - overall_n_over_2);
      buf[strided_shared_idx + e] = in[device_idx] * conj;
    } else {
      int conj_idx = overall_n - overall_index;
      device_idx = overall_batch * overall_n_over_2 + conj_idx;
      buf[strided_shared_idx + e] = in[device_idx];
    }
  }
}

template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
    int stride,
    int overall_n) {
  compute_strided_indices(stride, overall_n);

  for (int e = 0; e < elems_per_thread; e++) {
    out[strided_device_idx + e * stride] =
        pre_out(buf[strided_shared_idx + e], overall_n).x;
  }
}