mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,486 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// Metal FFT using Stockham's algorithm
|
|
4
|
+
//
|
|
5
|
+
// References:
|
|
6
|
+
// - VkFFT (https://github.com/DTolm/VkFFT)
|
|
7
|
+
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
|
8
|
+
|
|
9
|
+
#include <metal_common>
|
|
10
|
+
|
|
11
|
+
#include "mlx/backend/metal/kernels/fft/radix.h"
|
|
12
|
+
#include "mlx/backend/metal/kernels/fft/readwrite.h"
|
|
13
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
14
|
+
|
|
15
|
+
using namespace metal;
|
|
16
|
+
|
|
17
|
+
#define MAX_RADIX 13
|
|
18
|
+
// Reached when elems_per_thread_ = 6, max_radix = 13
|
|
19
|
+
// and some threads have to do 3 radix 6s requiring 18 float2s.
|
|
20
|
+
#define MAX_OUTPUT_SIZE 18
|
|
21
|
+
|
|
22
|
+
// Specialize for a particular value of N at runtime
|
|
23
|
+
STEEL_CONST bool inv_ [[function_constant(0)]];
|
|
24
|
+
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
|
|
25
|
+
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
|
|
26
|
+
// rader_m = n / rader_n
|
|
27
|
+
STEEL_CONST int rader_m_ [[function_constant(3)]];
|
|
28
|
+
// Stockham steps
|
|
29
|
+
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
|
|
30
|
+
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
|
|
31
|
+
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
|
|
32
|
+
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
|
|
33
|
+
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
|
|
34
|
+
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
|
|
35
|
+
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
|
|
36
|
+
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
|
|
37
|
+
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
|
|
38
|
+
// Rader steps
|
|
39
|
+
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
|
|
40
|
+
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
|
|
41
|
+
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
|
|
42
|
+
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
|
|
43
|
+
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
|
|
44
|
+
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
|
|
45
|
+
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
|
|
46
|
+
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
|
|
47
|
+
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
|
|
48
|
+
|
|
49
|
+
// See "radix.h" for radix codelets
|
|
50
|
+
typedef void (*RadixFunc)(thread float2*, thread float2*);
|
|
51
|
+
|
|
52
|
+
// Perform a single radix n butterfly with appropriate twiddles
|
|
53
|
+
template <int radix, RadixFunc radix_func>
|
|
54
|
+
METAL_FUNC void radix_butterfly(
|
|
55
|
+
int i,
|
|
56
|
+
int p,
|
|
57
|
+
thread float2* x,
|
|
58
|
+
thread short* indices,
|
|
59
|
+
thread float2* y) {
|
|
60
|
+
// i: the index in the overall DFT that we're processing.
|
|
61
|
+
// p: the size of the DFTs we're merging at this step.
|
|
62
|
+
// m: how many threads are working on this DFT.
|
|
63
|
+
int k, j;
|
|
64
|
+
|
|
65
|
+
// Use faster bitwise operations when working with powers of two
|
|
66
|
+
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
|
|
67
|
+
if (radix_p_2 && is_power_of_2_) {
|
|
68
|
+
constexpr short power = __builtin_ctz(radix);
|
|
69
|
+
k = i & (p - 1);
|
|
70
|
+
j = ((i - k) << power) + k;
|
|
71
|
+
} else {
|
|
72
|
+
k = i % p;
|
|
73
|
+
j = (i / p) * radix * p + k;
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
// Apply twiddles
|
|
77
|
+
if (p > 1) {
|
|
78
|
+
float2 twiddle_1 = get_twiddle(k, radix * p);
|
|
79
|
+
float2 twiddle = twiddle_1;
|
|
80
|
+
x[1] = complex_mul(x[1], twiddle);
|
|
81
|
+
|
|
82
|
+
STEEL_PRAGMA_UNROLL
|
|
83
|
+
for (int t = 2; t < radix; t++) {
|
|
84
|
+
twiddle = complex_mul(twiddle, twiddle_1);
|
|
85
|
+
x[t] = complex_mul(x[t], twiddle);
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
radix_func(x, y);
|
|
90
|
+
|
|
91
|
+
STEEL_PRAGMA_UNROLL
|
|
92
|
+
for (int t = 0; t < radix; t++) {
|
|
93
|
+
indices[t] = j + t * p;
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
// Perform all the radix steps required for a
|
|
98
|
+
// particular radix size n.
|
|
99
|
+
template <int radix, RadixFunc radix_func>
|
|
100
|
+
METAL_FUNC void radix_n_steps(
|
|
101
|
+
int i,
|
|
102
|
+
thread int* p,
|
|
103
|
+
int m,
|
|
104
|
+
int n,
|
|
105
|
+
int num_steps,
|
|
106
|
+
thread float2* inputs,
|
|
107
|
+
thread short* indices,
|
|
108
|
+
thread float2* values,
|
|
109
|
+
threadgroup float2* buf) {
|
|
110
|
+
int m_r = n / radix;
|
|
111
|
+
// When combining different sized radices, we have to do
|
|
112
|
+
// multiple butterflies in a single thread.
|
|
113
|
+
// E.g. n = 28 = 4 * 7
|
|
114
|
+
// 4 threads, 7 elems_per_thread
|
|
115
|
+
// All threads do 1 radix7 butterfly.
|
|
116
|
+
// 3 threads do 2 radix4 butterflies.
|
|
117
|
+
// 1 thread does 1 radix4 butterfly.
|
|
118
|
+
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
|
|
119
|
+
|
|
120
|
+
int index = 0;
|
|
121
|
+
int r_index = 0;
|
|
122
|
+
for (int s = 0; s < num_steps; s++) {
|
|
123
|
+
for (int t = 0; t < max_radices_per_thread; t++) {
|
|
124
|
+
index = i + t * m;
|
|
125
|
+
if (index < m_r) {
|
|
126
|
+
for (int r = 0; r < radix; r++) {
|
|
127
|
+
inputs[r] = buf[index + r * m_r];
|
|
128
|
+
}
|
|
129
|
+
radix_butterfly<radix, radix_func>(
|
|
130
|
+
index, *p, inputs, indices + t * radix, values + t * radix);
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
// Wait until all threads have read their inputs into thread local mem
|
|
135
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
136
|
+
|
|
137
|
+
for (int t = 0; t < max_radices_per_thread; t++) {
|
|
138
|
+
index = i + t * m;
|
|
139
|
+
if (index < m_r) {
|
|
140
|
+
for (int r = 0; r < radix; r++) {
|
|
141
|
+
r_index = t * radix + r;
|
|
142
|
+
buf[indices[r_index]] = values[r_index];
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// Wait until all threads have written back to threadgroup mem
|
|
148
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
149
|
+
*p *= radix;
|
|
150
|
+
}
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
#define RADIX_STEP(radix, radix_func, num_steps) \
|
|
154
|
+
radix_n_steps<radix, radix_func>( \
|
|
155
|
+
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
|
|
156
|
+
|
|
157
|
+
template <bool rader = false>
|
|
158
|
+
METAL_FUNC void
|
|
159
|
+
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
|
|
160
|
+
float2 inputs[MAX_RADIX];
|
|
161
|
+
short indices[MAX_OUTPUT_SIZE];
|
|
162
|
+
float2 values[MAX_OUTPUT_SIZE];
|
|
163
|
+
|
|
164
|
+
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
|
|
165
|
+
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
|
|
166
|
+
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
|
|
167
|
+
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
|
|
168
|
+
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
|
|
169
|
+
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
|
|
170
|
+
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
|
|
171
|
+
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
|
|
172
|
+
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
// Each FFT is computed entirely in shared GPU memory.
|
|
176
|
+
//
|
|
177
|
+
// N is decomposed into radix-n DFTs:
|
|
178
|
+
// e.g. 128 = 2 * 4 * 4 * 4
|
|
179
|
+
template <int tg_mem_size, typename in_T, typename out_T>
|
|
180
|
+
[[kernel]] void fft(
|
|
181
|
+
const device in_T* in [[buffer(0)]],
|
|
182
|
+
device out_T* out [[buffer(1)]],
|
|
183
|
+
constant const int& n,
|
|
184
|
+
constant const int& batch_size,
|
|
185
|
+
uint3 elem [[thread_position_in_grid]],
|
|
186
|
+
uint3 grid [[threads_per_grid]]) {
|
|
187
|
+
threadgroup float2 shared_in[tg_mem_size];
|
|
188
|
+
|
|
189
|
+
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
|
190
|
+
in,
|
|
191
|
+
&shared_in[0],
|
|
192
|
+
out,
|
|
193
|
+
n,
|
|
194
|
+
batch_size,
|
|
195
|
+
elems_per_thread_,
|
|
196
|
+
elem,
|
|
197
|
+
grid,
|
|
198
|
+
inv_);
|
|
199
|
+
|
|
200
|
+
if (read_writer.out_of_bounds()) {
|
|
201
|
+
return;
|
|
202
|
+
};
|
|
203
|
+
read_writer.load();
|
|
204
|
+
|
|
205
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
206
|
+
|
|
207
|
+
int p = 1;
|
|
208
|
+
int fft_idx = elem.z; // Thread index in DFT
|
|
209
|
+
int m = grid.z; // Threads per DFT
|
|
210
|
+
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
|
211
|
+
threadgroup float2* buf = &shared_in[tg_idx];
|
|
212
|
+
|
|
213
|
+
perform_fft(fft_idx, &p, m, n, buf);
|
|
214
|
+
|
|
215
|
+
read_writer.write();
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
template <int tg_mem_size, typename in_T, typename out_T>
|
|
219
|
+
[[kernel]] void rader_fft(
|
|
220
|
+
const device in_T* in [[buffer(0)]],
|
|
221
|
+
device out_T* out [[buffer(1)]],
|
|
222
|
+
const device float2* raders_b_q [[buffer(2)]],
|
|
223
|
+
const device short* raders_g_q [[buffer(3)]],
|
|
224
|
+
const device short* raders_g_minus_q [[buffer(4)]],
|
|
225
|
+
constant const int& n,
|
|
226
|
+
constant const int& batch_size,
|
|
227
|
+
constant const int& rader_n,
|
|
228
|
+
uint3 elem [[thread_position_in_grid]],
|
|
229
|
+
uint3 grid [[threads_per_grid]]) {
|
|
230
|
+
// Use Rader's algorithm to compute fast FFTs
|
|
231
|
+
// when a prime factor `p` of `n` is greater than 13 but
|
|
232
|
+
// has `p - 1` Stockham decomposable into to prime factors <= 13.
|
|
233
|
+
//
|
|
234
|
+
// E.g. n = 102
|
|
235
|
+
// = 2 * 3 * 17
|
|
236
|
+
// . = 2 * 3 * RADER(16)
|
|
237
|
+
// . = 2 * 3 * RADER(4 * 4)
|
|
238
|
+
//
|
|
239
|
+
// In numpy:
|
|
240
|
+
// x_perm = x[g_q]
|
|
241
|
+
// y = np.fft.fft(x_perm) * b_q
|
|
242
|
+
// z = np.fft.ifft(y) + x[0]
|
|
243
|
+
// out = z[g_minus_q]
|
|
244
|
+
// out[0] = x[1:].sum()
|
|
245
|
+
//
|
|
246
|
+
// Where the g_q and g_minus_q are permutations formed
|
|
247
|
+
// by the group under multiplicative modulo N using the
|
|
248
|
+
// primitive root of N and b_q is a constant.
|
|
249
|
+
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
|
|
250
|
+
//
|
|
251
|
+
// Rader's uses fewer operations than Bluestein's and so
|
|
252
|
+
// is more accurate. It's also faster in most cases.
|
|
253
|
+
threadgroup float2 shared_in[tg_mem_size];
|
|
254
|
+
|
|
255
|
+
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
|
256
|
+
in,
|
|
257
|
+
&shared_in[0],
|
|
258
|
+
out,
|
|
259
|
+
n,
|
|
260
|
+
batch_size,
|
|
261
|
+
elems_per_thread_,
|
|
262
|
+
elem,
|
|
263
|
+
grid,
|
|
264
|
+
inv_);
|
|
265
|
+
|
|
266
|
+
if (read_writer.out_of_bounds()) {
|
|
267
|
+
return;
|
|
268
|
+
};
|
|
269
|
+
read_writer.load();
|
|
270
|
+
|
|
271
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
272
|
+
|
|
273
|
+
// The number of the threads we're using for each DFT
|
|
274
|
+
int m = grid.z;
|
|
275
|
+
|
|
276
|
+
int fft_idx = elem.z;
|
|
277
|
+
int tg_idx = elem.y * n;
|
|
278
|
+
threadgroup float2* buf = &shared_in[tg_idx];
|
|
279
|
+
|
|
280
|
+
// rader_m = n / rader_n;
|
|
281
|
+
int rader_m = rader_m_;
|
|
282
|
+
|
|
283
|
+
// We have to load two x_0s for each thread since sometimes
|
|
284
|
+
// elems_per_thread_ crosses a boundary.
|
|
285
|
+
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
|
|
286
|
+
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
|
|
287
|
+
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
|
288
|
+
short x_0_index =
|
|
289
|
+
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
|
|
290
|
+
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
|
|
291
|
+
|
|
292
|
+
// Do the Rader permutation in shared memory
|
|
293
|
+
float2 temp[MAX_RADIX];
|
|
294
|
+
int max_index = n - rader_m - 1;
|
|
295
|
+
for (int e = 0; e < elems_per_thread_; e++) {
|
|
296
|
+
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
|
297
|
+
short g_q = raders_g_q[index / rader_m];
|
|
298
|
+
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
302
|
+
|
|
303
|
+
for (int e = 0; e < elems_per_thread_; e++) {
|
|
304
|
+
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
|
305
|
+
buf[index + rader_m] = temp[e];
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
309
|
+
|
|
310
|
+
// Rader FFT on x[rader_m:]
|
|
311
|
+
int p = 1;
|
|
312
|
+
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
|
313
|
+
|
|
314
|
+
// x_1 + ... + x_n is computed for us in the first FFT step so
|
|
315
|
+
// we save it in the first rader_m indices of the array for later.
|
|
316
|
+
int x_sum_index = metal::min(fft_idx, rader_m - 1);
|
|
317
|
+
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
|
|
318
|
+
|
|
319
|
+
float2 inv = {1.0f, -1.0f};
|
|
320
|
+
for (int e = 0; e < elems_per_thread_; e++) {
|
|
321
|
+
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
|
322
|
+
short interleaved_index =
|
|
323
|
+
index / rader_m + (index % rader_m) * (rader_n - 1);
|
|
324
|
+
temp[e] = complex_mul(
|
|
325
|
+
buf[rader_m + interleaved_index],
|
|
326
|
+
raders_b_q[interleaved_index % (rader_n - 1)]);
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
330
|
+
|
|
331
|
+
for (int e = 0; e < elems_per_thread_; e++) {
|
|
332
|
+
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
|
333
|
+
buf[rader_m + index] = temp[e] * inv;
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
337
|
+
|
|
338
|
+
// Rader IFFT on x[rader_m:]
|
|
339
|
+
p = 1;
|
|
340
|
+
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
|
341
|
+
|
|
342
|
+
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
|
|
343
|
+
|
|
344
|
+
for (int e = 0; e < elems_per_thread_; e++) {
|
|
345
|
+
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
|
|
346
|
+
short diff_index = index / (rader_n - 1) - x_0_index;
|
|
347
|
+
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
// Use the sum of elements that was computed in the first FFT
|
|
351
|
+
float2 x_sum = buf[x_0_index] + x_0[0];
|
|
352
|
+
|
|
353
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
354
|
+
|
|
355
|
+
for (int e = 0; e < elems_per_thread_; e++) {
|
|
356
|
+
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
|
357
|
+
short g_q_index = index % (rader_n - 1);
|
|
358
|
+
short g_q = raders_g_minus_q[g_q_index];
|
|
359
|
+
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
|
|
360
|
+
buf[out_index] = temp[e];
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
buf[x_0_index * rader_n] = x_sum;
|
|
364
|
+
|
|
365
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
366
|
+
|
|
367
|
+
p = rader_n;
|
|
368
|
+
perform_fft(fft_idx, &p, m, n, buf);
|
|
369
|
+
|
|
370
|
+
read_writer.write();
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
template <int tg_mem_size, typename in_T, typename out_T>
|
|
374
|
+
[[kernel]] void bluestein_fft(
|
|
375
|
+
const device in_T* in [[buffer(0)]],
|
|
376
|
+
device out_T* out [[buffer(1)]],
|
|
377
|
+
const device float2* w_q [[buffer(2)]],
|
|
378
|
+
const device float2* w_k [[buffer(3)]],
|
|
379
|
+
constant const int& length,
|
|
380
|
+
constant const int& n,
|
|
381
|
+
constant const int& batch_size,
|
|
382
|
+
uint3 elem [[thread_position_in_grid]],
|
|
383
|
+
uint3 grid [[threads_per_grid]]) {
|
|
384
|
+
// Computes arbitrary length FFTs with Bluestein's algorithm
|
|
385
|
+
//
|
|
386
|
+
// In numpy:
|
|
387
|
+
// bluestein_n = next_power_of_2(2*n - 1)
|
|
388
|
+
// out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
|
|
389
|
+
//
|
|
390
|
+
// Where w_k and w_q are precomputed on CPU in high precision as:
|
|
391
|
+
// w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
|
|
392
|
+
// w_q = np.fft.fft(1/w_k[-n:])
|
|
393
|
+
threadgroup float2 shared_in[tg_mem_size];
|
|
394
|
+
|
|
395
|
+
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
|
396
|
+
in,
|
|
397
|
+
&shared_in[0],
|
|
398
|
+
out,
|
|
399
|
+
n,
|
|
400
|
+
batch_size,
|
|
401
|
+
elems_per_thread_,
|
|
402
|
+
elem,
|
|
403
|
+
grid,
|
|
404
|
+
inv_);
|
|
405
|
+
|
|
406
|
+
if (read_writer.out_of_bounds()) {
|
|
407
|
+
return;
|
|
408
|
+
};
|
|
409
|
+
read_writer.load_padded(length, w_k);
|
|
410
|
+
|
|
411
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
412
|
+
|
|
413
|
+
int p = 1;
|
|
414
|
+
int fft_idx = elem.z; // Thread index in DFT
|
|
415
|
+
int m = grid.z; // Threads per DFT
|
|
416
|
+
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
|
417
|
+
threadgroup float2* buf = &shared_in[tg_idx];
|
|
418
|
+
|
|
419
|
+
// fft
|
|
420
|
+
perform_fft(fft_idx, &p, m, n, buf);
|
|
421
|
+
|
|
422
|
+
float2 inv = float2(1.0f, -1.0f);
|
|
423
|
+
for (int t = 0; t < elems_per_thread_; t++) {
|
|
424
|
+
int index = fft_idx + t * m;
|
|
425
|
+
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
429
|
+
|
|
430
|
+
// ifft
|
|
431
|
+
p = 1;
|
|
432
|
+
perform_fft(fft_idx, &p, m, n, buf);
|
|
433
|
+
|
|
434
|
+
read_writer.write_padded(length, w_k);
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
template <
    int tg_mem_size,
    typename in_T,
    typename out_T,
    int step,
    bool real = false>
[[kernel]] void four_step_fft(
    const device in_T* in [[buffer(0)]],
    device out_T* out [[buffer(1)]],
    constant const int& n1,
    constant const int& n2,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // One pass of the fast four step FFT for powers of 2. A transform of
  // total length n1 * n2 is handled in passes selected by `step`:
  //   step == 0 : DFTs of length n1 over elements spaced n2 apart;
  //   otherwise : DFTs of length n2 over elements spaced n1 apart.
  int total_size = n1 * n2;
  int fft_size;
  int elem_stride;
  if (step == 0) {
    fft_size = n1;
    elem_stride = n2;
  } else {
    fft_size = n2;
    elem_stride = n1;
  }

  // How many threads cooperate on each DFT, and this thread's slot among them.
  int threads_per_fft = grid.z;
  int fft_idx = elem.z;

  // Threadgroup scratch; this thread's DFT works on the slice starting at
  // elem.y * fft_size.
  threadgroup float2 shared_in[tg_mem_size];
  threadgroup float2* buf = &shared_in[elem.y * fft_size];

  using read_writer_t = ReadWriter<in_T, out_T, step, real>;
  read_writer_t read_writer = read_writer_t(
      in,
      &shared_in[0],
      out,
      fft_size,
      batch_size,
      elems_per_thread_,
      elem,
      grid,
      inv_);

  // NOTE(review): threads returning here skip the threadgroup_barrier below.
  // This assumes out_of_bounds() is uniform across a threadgroup -- confirm
  // against ReadWriter and the host-side dispatch.
  if (read_writer.out_of_bounds()) {
    return;
  }
  read_writer.load_strided(elem_stride, total_size);

  // Wait until every thread's strided load into threadgroup memory is done
  // before the cooperative FFT reads it.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  int p = 1;
  perform_fft(fft_idx, &p, threads_per_fft, fft_size, buf);

  read_writer.write_strided(elem_stride, total_size);
}
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
// FP4 (E2M1) code -> float lookup table. Bit 3 of the index is the sign bit;
// bits 0-2 select the magnitude from {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
constexpr constant static float FP4_LUT[16] = {
    // Codes 0x0-0x7: sign bit clear.
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    // Codes 0x8-0xF: the same magnitudes with the sign bit set.
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
|
|
20
|
+
|
|
21
|
+
struct fp4_e2m1 {
  // Round a float to the nearest FP4 E2M1 code (sign bit 0x8, magnitude code
  // 0x0-0x7 for {0, 0.5, 1, 1.5, 2, 3, 4, 6}). Exact midpoints go to the even
  // code (e.g. 0.25 -> 0, 0.75 -> 1.0). Magnitudes above 5 saturate to 6, and
  // NaN maps to code 0x7 (+6) with the sign bit clear.
  fp4_e2m1(float x) {
    if (metal::isnan(x)) {
      bits = 0x7;
      return;
    }

    const uint8_t sign_bit = metal::signbit(x) ? 0x8 : 0x0;
    const float mag = metal::abs(x);

    // The thresholds are the midpoints between adjacent representable values
    // and are increasing, so the magnitude code is simply how many of them
    // `mag` has passed. `>=` vs `>` picks the even neighbour on a tie.
    const uint8_t code = uint8_t(
        (mag > 0.25f) + (mag >= 0.75f) + (mag > 1.25f) + (mag >= 1.75f) +
        (mag > 2.5f) + (mag >= 3.5f) + (mag > 5.0f));

    bits = sign_bit | code;
  }

  // Decode via a half: the code's 2 exponent bits land in the low end of the
  // half exponent field and its mantissa bit becomes the half's top mantissa
  // bit. Scaling by 2^14 converts from half's exponent bias (15) to E2M1's (1).
  operator float() {
    half magnitude = as_type<half>(ushort((bits & 7) << 9));
    magnitude *= 16384.0;
    if (bits & 8) {
      magnitude = -magnitude;
    }
    return magnitude;
  }

  uint8_t bits;
};
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
struct fp8_e4m3 {
  // Round a value (first converted to float) to the nearest FP8 E4M3 code.
  // Algorithm from PyTorch:
  // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
  // Unlike that reference, any magnitude >= 448 (the largest finite E4M3
  // value) saturates to code 0x7E. That includes inf and NaN.
  // NOTE(review): NaN therefore encodes as +/-448, not the E4M3 NaN code
  // 0x7F -- confirm this is intended.
  template <typename T>
  fp8_e4m3(T f) {
    // 448.0f as float bits (0x43E00000): the saturation threshold.
    const uint32_t saturate_bits = 543 << 21;
    // 16384.0f (2^14) as float bits (0x46800000). Its ULP is 2^-9, the E4M3
    // subnormal step.
    const uint32_t subnormal_offset = 141 << 23;

    uint32_t mag = as_type<uint32_t>(static_cast<float>(f));
    const uint32_t sign = mag & 0x80000000;
    mag ^= sign; // |f| as float bits

    uint8_t code;
    if (mag >= saturate_bits) {
      // Default behavior saturates to min/max
      code = 0x7E;
    } else if (mag < (121 << 23)) {
      // |f| < 2^-6, the smallest normal E4M3 value. Adding 2^14 lets the
      // float adder round |f| to a multiple of 2^-9, leaving the subnormal
      // code in the low mantissa bits.
      mag = as_type<uint32_t>(
          as_type<float>(mag) + as_type<float>(subnormal_offset));
      code = static_cast<uint8_t>(mag - subnormal_offset);
    } else {
      // Normal range: rebias the exponent (127 -> 7) and round the 20 dropped
      // mantissa bits to nearest, ties to even (add 0x7FFFF plus the kept LSB).
      const uint8_t lsb_is_odd = (mag >> 20) & 1;
      mag += (static_cast<uint32_t>(7 - 127) << 23) + 0x7FFFF;
      mag += lsb_is_odd;
      code = static_cast<uint8_t>(mag >> 20);
    }
    bits = code | static_cast<uint8_t>(sign >> 24);
  }

  // Widen the E4M3 code to float. Every E4M3 value is exactly representable.
  // From PyTorch:
  // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
  operator float() {
    // Place the code in the top byte of a 32-bit word.
    const uint32_t word = static_cast<uint32_t>(bits) << 24;
    const uint32_t sign = word & 0x80000000;
    const uint32_t nonsign = word & 0x7FFFFFFF;

    // The E4M3 exponent occupies bits 30-27 of `nonsign`. More than 4 leading
    // zeros means the exponent is 0 (subnormal), so shift left to normalize.
    uint32_t norm_shift = metal::clz(nonsign);
    norm_shift = norm_shift > 4 ? norm_shift - 4 : 0;

    // Only code 0x7F (all non-sign bits set, NaN) carries into bit 31. The
    // arithmetic shift then fills a full float exponent, yielding NaN.
    const int32_t inf_nan_mask =
        (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & 0x7F800000;
    // All ones when the code is +/-0, otherwise zero.
    const int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
    // Move exponent+mantissa into float position and rebias the exponent by
    // 0x78 (= 127 - 7), minus the normalization shift.
    const uint32_t result = sign |
        ((((nonsign << norm_shift >> 4) + ((0x78 - norm_shift) << 23)) |
          inf_nan_mask) &
         ~zero_mask);
    return as_type<float>(result);
  }

  uint8_t bits;
};
|
|
54
|
+
|
|
55
|
+
struct fp8_e8m0 {
  // Encode a float as an E8M0 power-of-two code.
  //
  // bits = round(log2(x)) + 127, with the exponent clamped to [-127, 127]
  // (codes 0x00-0xFE). Non-finite inputs produce 0xFF, the NaN code. Zero and
  // negative inputs cannot be represented and produce 0x00, the smallest code
  // (2^-127).
  fp8_e8m0(float x) {
    if (!metal::isfinite(x)) {
      bits = 0xFF;
      return;
    }
    // `<=` (not `<`) so that +0.0f and -0.0f take this path too. Otherwise
    // log2(0) == -inf would reach the float->int conversion below, and an
    // out-of-range float->int conversion is undefined in C++.
    if (x <= 0.0f) {
      bits = 0x00;
      return;
    }
    float le = metal::log2(x);
    int n = int(metal::round(le));

    n = n < -127 ? -127 : n;
    n = n > 127 ? 127 : n;
    bits = static_cast<uint8_t>(n + 127);
  }

  // Decode to the exact power of two 2^(bits - 127) by writing the code into
  // the bfloat16 exponent field.
  // - Code 0x00 (2^-127) is below bfloat16's normal range, so it is emitted as
  //   the subnormal 0x0040.
  // - Code 0xFF is the NaN code (see the constructor), so it is decoded as a
  //   bfloat16 quiet NaN. Shifting it into the exponent field would give +inf.
  operator bfloat16_t() {
    uint16_t out;
    if (bits == 0xFF) {
      out = 0x7FC0;
    } else if (bits == 0) {
      out = 0x40;
    } else {
      out = static_cast<uint16_t>(bits) << 7;
    }
    return as_type<bfloat16_t>(out);
  }

  // Decode to float via the exact bfloat16 value.
  operator float() {
    return static_cast<float>(this->operator bfloat16_t());
  }

  uint8_t bits;
};
|