mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_stdlib>
|
|
6
|
+
|
|
7
|
+
using namespace metal;
|
|
8
|
+
|
|
9
|
+
struct complex64_t;
|
|
10
|
+
|
|
11
|
+
// Trait: true for any type other than complex64_t itself that is convertible
// to float. Used to constrain the converting constructors of complex64_t.
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
    (!is_same_v<T, complex64_t>) && (is_convertible_v<T, float>);
|
|
14
|
+
|
|
15
|
+
// Trait: true for any type other than complex64_t itself that can be produced
// from a float or from a bfloat16_t. Used to constrain the conversion
// operators of complex64_t.
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
    (!is_same_v<T, complex64_t>) &&
    ((is_convertible_v<float, T>) || (is_convertible_v<bfloat16_t, T>));
|
|
19
|
+
|
|
20
|
+
// Complex number stored as two floats: a real part and an imaginary part.
//
// The converting constructors and the conversion operators are declared once
// per address-space qualifier (thread, threadgroup, device, constant).
struct complex64_t {
  float real;
  float imag;

  // Constructors
  constexpr complex64_t(float re, float im) : real(re), imag(im) {}
  constexpr complex64_t() : real(0), imag(0) {}
  constexpr complex64_t() threadgroup : real(0), imag(0) {}

  // Conversions to complex64_t: the incoming value becomes the real part and
  // the imaginary part is set to zero.
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T value) thread : real(value), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T value) threadgroup : real(value), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T value) device : real(value), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T value) constant : real(value), imag(0) {}

  // Conversions from complex64_t: only the real part is converted; the
  // imaginary part is discarded.
  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const thread {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const threadgroup {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const device {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const constant {
    return static_cast<T>(real);
  }
};
|
|
79
|
+
|
|
80
|
+
// Unary negation: negates both the real and the imaginary part.
constexpr complex64_t operator-(complex64_t x) {
  return complex64_t(-x.real, -x.imag);
}
|
|
83
|
+
|
|
84
|
+
// Lexicographic ">=": the real parts decide unless they compare equal, in
// which case the imaginary parts are compared with ">=".
constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real == b.real) ? (a.imag >= b.imag) : (a.real > b.real);
}
|
|
87
|
+
|
|
88
|
+
// Lexicographic ">": the real parts decide unless they compare equal, in
// which case the imaginary parts are compared with ">".
constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real == b.real) ? (a.imag > b.imag) : (a.real > b.real);
}
|
|
91
|
+
|
|
92
|
+
// "a <= b" is defined as "b >= a".
constexpr bool operator<=(complex64_t a, complex64_t b) {
  return b >= a;
}
|
|
95
|
+
|
|
96
|
+
// "a < b" is defined as "b > a".
constexpr bool operator<(complex64_t a, complex64_t b) {
  return b > a;
}
|
|
99
|
+
|
|
100
|
+
// Equality: true only when both the real and the imaginary parts compare
// equal.
constexpr bool operator==(complex64_t a, complex64_t b) {
  return !((a.real != b.real) || (a.imag != b.imag));
}
|
|
103
|
+
|
|
104
|
+
// Component-wise sum of two complex values.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return complex64_t(a.real + b.real, a.imag + b.imag);
}
|
|
107
|
+
|
|
108
|
+
// In-place sum for an accumulator in the thread address space; returns the
// updated accumulator.
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
  a.real = a.real + b.real;
  a.imag = a.imag + b.imag;
  return a;
}
|
|
113
|
+
|
|
114
|
+
// In-place sum for an accumulator in the threadgroup address space; returns
// the updated accumulator.
constexpr threadgroup complex64_t& operator+=(
    threadgroup complex64_t& a,
    complex64_t b) {
  a.real = a.real + b.real;
  a.imag = a.imag + b.imag;
  return a;
}
|
|
121
|
+
|
|
122
|
+
// In-place sum for an accumulator in the device address space; returns the
// updated accumulator.
constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
  a.real = a.real + b.real;
  a.imag = a.imag + b.imag;
  return a;
}
|
|
127
|
+
|
|
128
|
+
// float + complex: the float is added to the real part only.
constexpr complex64_t operator+(float a, complex64_t b) {
  return complex64_t(a + b.real, b.imag);
}
|
|
131
|
+
// complex + float: the float is added to the real part only.
constexpr complex64_t operator+(complex64_t a, float b) {
  return complex64_t(a.real + b, a.imag);
}
|
|
134
|
+
|
|
135
|
+
// Component-wise difference of two complex values.
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return complex64_t(a.real - b.real, a.imag - b.imag);
}
|
|
138
|
+
// float - complex: the real part is a - b.real and the imaginary part is
// the negated imaginary part of b.
constexpr complex64_t operator-(float a, complex64_t b) {
  return complex64_t(a - b.real, -b.imag);
}
|
|
141
|
+
// complex - float: the float is subtracted from the real part only.
constexpr complex64_t operator-(complex64_t a, float b) {
  return complex64_t(a.real - b, a.imag);
}
|
|
144
|
+
|
|
145
|
+
// Complex product: (a.re*b.re - a.im*b.im) + i*(a.re*b.im + a.im*b.re).
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  const float prod_re = a.real * b.real - a.imag * b.imag;
  const float prod_im = a.real * b.imag + a.imag * b.real;
  return complex64_t(prod_re, prod_im);
}
|
|
148
|
+
|
|
149
|
+
// Complex quotient a / b, computed as a * conj(b) divided by |b|^2.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  // Numerator a * conj(b), split into its two components.
  const float num_re = a.real * b.real + a.imag * b.imag;
  const float num_im = a.imag * b.real - a.real * b.imag;
  // Squared magnitude of the divisor.
  const float norm_sq = b.real * b.real + b.imag * b.imag;
  return complex64_t(num_re / norm_sq, num_im / norm_sq);
}
|
|
155
|
+
|
|
156
|
+
// Real-by-complex quotient a / b, computed as a * conj(b) divided by |b|^2.
constexpr complex64_t operator/(float a, complex64_t b) {
  const float num_re = a * b.real;
  const float num_im = (-a) * b.imag;
  // Squared magnitude of the divisor.
  const float norm_sq = b.real * b.real + b.imag * b.imag;
  return complex64_t(num_re / norm_sq, num_im / norm_sq);
}
|
|
162
|
+
|
|
163
|
+
// Component-wise remainder: the real and imaginary parts are reduced
// independently by the matching component of b. Each remainder is first
// computed from the quotient truncated to int64_t; a non-zero remainder whose
// sign differs from the divisor component's sign is then shifted by that
// divisor component.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  const int64_t quot_re = static_cast<int64_t>(a.real / b.real);
  const int64_t quot_im = static_cast<int64_t>(a.imag / b.imag);

  float rem_re = a.real - (b.real * quot_re);
  float rem_im = a.imag - (b.imag * quot_im);

  const bool re_sign_differs = (rem_re < 0) != (b.real < 0);
  if (rem_re != 0 && re_sign_differs) {
    rem_re += b.real;
  }

  const bool im_sign_differs = (rem_im < 0) != (b.imag < 0);
  if (rem_im != 0 && im_sign_differs) {
    rem_im += b.imag;
  }

  return complex64_t(rem_re, rem_im);
}
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// Broadcast copy: every destination element handled by this thread receives
// static_cast<U>(src[0]). A thread covers N consecutive elements starting at
// N * index; when that span would extend past `size`, only the elements below
// `size` are written.
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  const uint first = index * N;
  const bool partial = N > 1 && first + N > size;
  if (!partial) {
    // Whole span: fixed trip count of N.
    for (int k = 0; k < N; ++k) {
      dst[first + k] = static_cast<U>(src[0]);
    }
    return;
  }
  // Partial span: stop at `size`.
  for (int k = 0; first + k < size; ++k) {
    dst[first + k] = static_cast<U>(src[0]);
  }
}
|
|
20
|
+
|
|
21
|
+
// Element-wise copy with conversion: dst[i] = static_cast<U>(src[i]) using
// the same linear index on both sides. A thread covers N consecutive elements
// starting at N * index; when that span would extend past `size`, only the
// elements below `size` are written.
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  const uint first = index * N;
  const bool partial = N > 1 && first + N > size;
  if (!partial) {
    // Whole span: fixed trip count of N.
    for (int k = 0; k < N; ++k) {
      dst[first + k] = static_cast<U>(src[first + k]);
    }
    return;
  }
  // Partial span: stop at `size`.
  for (int k = 0; first + k < size; ++k) {
    dst[first + k] = static_cast<U>(src[first + k]);
  }
}
|
|
38
|
+
|
|
39
|
+
// Broadcast copy over a 2D grid with a 64-bit element offset: every
// destination element handled by this thread receives static_cast<U>(src[0]).
// The thread's first element is N * (index.x + grid_dim.x * index.y); when
// the span of N elements would extend past `size`, only the elements below
// `size` are written.
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const int64_t first = N * (index.x + grid_dim.x * int64_t(index.y));
  const bool partial = N > 1 && first + N > size;
  if (!partial) {
    // Whole span: fixed trip count of N.
    for (int k = 0; k < N; ++k) {
      dst[first + k] = static_cast<U>(src[0]);
    }
    return;
  }
  // Partial span: stop at `size`.
  for (int k = 0; first + k < size; ++k) {
    dst[first + k] = static_cast<U>(src[0]);
  }
}
|
|
57
|
+
|
|
58
|
+
// Element-wise copy with conversion over a 2D grid with a 64-bit element
// offset: dst[i] = static_cast<U>(src[i]). The thread's first element is
// N * (index.x + grid_dim.x * index.y); when the span of N elements would
// extend past `size`, only the elements below `size` are written.
template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const int64_t first = N * (index.x + grid_dim.x * int64_t(index.y));
  const bool partial = N > 1 && first + N > size;
  if (!partial) {
    // Whole span: fixed trip count of N.
    for (int k = 0; k < N; ++k) {
      dst[first + k] = static_cast<U>(src[first + k]);
    }
    return;
  }
  // Partial span: stop at `size`.
  for (int k = 0; first + k < size; ++k) {
    dst[first + k] = static_cast<U>(src[first + k]);
  }
}
|
|
76
|
+
|
|
77
|
+
// 1D copy from a strided source into a linearly indexed destination:
// the source location comes from elem_to_loc_1(index, src_stride) and the
// converted value is stored at dst[index].
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  const auto src_loc = elem_to_loc_1<IdxT>(index, src_stride);
  dst[index] = static_cast<U>(src[src_loc]);
}
|
|
86
|
+
|
|
87
|
+
// 2D copy from a strided source into a linearly indexed destination:
// the source location comes from elem_to_loc_2(index, src_strides) and the
// destination location is index.x + grid_dim.x * index.y.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const auto src_loc = elem_to_loc_2<IdxT>(index, src_strides);
  const IdxT dst_loc = index.x + IdxT(grid_dim.x) * index.y;
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
|
|
98
|
+
|
|
99
|
+
// 3D copy from a strided source into a linearly indexed destination:
// the source location comes from elem_to_loc_3(index, src_strides) and the
// destination location is
// index.x + grid_dim.x * (index.y + grid_dim.y * index.z).
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const auto src_loc = elem_to_loc_3<IdxT>(index, src_strides);
  const IdxT dst_loc =
      index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
|
|
111
|
+
|
|
112
|
+
// N-dimensional copy from a strided source into a linearly indexed
// destination. The source location is resolved by elem_to_loc from
// {N * index.x, index.y, index.z}, src_shape, src_strides and ndim.
//
// With N == 1 each thread copies one element. Otherwise each thread copies up
// to N consecutive elements along the last axis, stepping the source by the
// last stride and stopping at the end of that axis (src_shape[ndim - 1]).
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_loc = elem_to_loc<IdxT>(
      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
  if (N == 1) {
    const IdxT dst_loc =
        index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
    dst[dst_loc] = static_cast<U>(src[src_loc]);
  } else {
    const int row_len = src_shape[ndim - 1];
    const IdxT dst_loc =
        N * index.x + row_len * (index.y + IdxT(grid_dim.y) * index.z);
    const int64_t src_step = src_strides[ndim - 1];
    // First position along the last axis handled by this thread.
    const int x_begin = int(N * index.x);
    for (int k = 0; k < N && (x_begin + k) < row_len; ++k) {
      dst[dst_loc + k] = static_cast<U>(src[src_loc]);
      src_loc += src_step;
    }
  }
}
|
|
137
|
+
|
|
138
|
+
// 1D copy between a strided source and a strided destination: both locations
// come from elem_to_loc_1 applied to the thread index and the respective
// stride.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  const auto src_loc = elem_to_loc_1<IdxT>(index, src_stride);
  const auto dst_loc = elem_to_loc_1<IdxT>(index, dst_stride);
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
|
|
149
|
+
|
|
150
|
+
// 2D copy between a strided source and a strided destination: both locations
// come from elem_to_loc_2 applied to the thread index and the respective
// strides.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  const auto src_loc = elem_to_loc_2<IdxT>(index, src_strides);
  const auto dst_loc = elem_to_loc_2<IdxT>(index, dst_strides);
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
|
|
161
|
+
|
|
162
|
+
// 3D copy between a strided source and a strided destination: both locations
// come from elem_to_loc_3 applied to the thread index and the respective
// strides.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  const auto src_loc = elem_to_loc_3<IdxT>(index, src_strides);
  const auto dst_loc = elem_to_loc_3<IdxT>(index, dst_strides);
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
|
|
173
|
+
|
|
174
|
+
// N-dimensional copy between a strided source and a strided destination.
// elem_to_loc_2_nd resolves both locations at once; its .x component is used
// as the source location and its .y component as the destination location.
//
// With N == 1 each thread copies one element. Otherwise each thread copies up
// to N consecutive elements along the last axis, stepping source and
// destination by their last strides and stopping at the end of that axis
// (src_shape[ndim - 1]).
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  auto loc = elem_to_loc_2_nd<IdxT>(
      {N * index.x, index.y, index.z},
      src_shape,
      src_strides,
      dst_strides,
      ndim);
  if (N == 1) {
    dst[loc.y] = static_cast<U>(src[loc.x]);
  } else {
    const IdxT src_step = src_strides[ndim - 1];
    const IdxT dst_step = dst_strides[ndim - 1];
    const int row_len = src_shape[ndim - 1];
    // First position along the last axis handled by this thread.
    const int x_begin = int(N * index.x);
    for (int k = 0; k < N && (x_begin + k) < row_len; ++k) {
      dst[loc.y] = static_cast<U>(src[loc.x]);
      loc.x += src_step;
      loc.y += dst_step;
    }
  }
}
|
|
202
|
+
|
|
203
|
+
// 1D copy between a strided source and a strided destination, with an extra
// element offset added to each side (src_offset, dst_offset). Both strided
// locations come from elem_to_loc_1 applied to the thread index and the
// respective stride.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
  auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
  // Convert explicitly to U, as the non-dynamic copy kernels do, instead of
  // relying on an implicit T -> U conversion at the assignment.
  dst[dst_idx + dst_offset] = static_cast<U>(src[src_idx + src_offset]);
}
|
|
216
|
+
|
|
217
|
+
// 2D copy between a strided source and a strided destination, with an extra
// element offset added to each side (src_offset, dst_offset). Both strided
// locations come from elem_to_loc_2 applied to the thread index and the
// respective strides.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
  auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
  // Convert explicitly to U, as the non-dynamic copy kernels do, instead of
  // relying on an implicit T -> U conversion at the assignment.
  dst[dst_idx + dst_offset] = static_cast<U>(src[src_idx + src_offset]);
}
|
|
230
|
+
|
|
231
|
+
// 3D copy between a strided source and a strided destination, with an extra
// element offset added to each side (src_offset, dst_offset). Both strided
// locations come from elem_to_loc_3 applied to the thread index and the
// respective strides.
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
  auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
  // Convert explicitly to U, as the non-dynamic copy kernels do, instead of
  // relying on an implicit T -> U conversion at the assignment.
  dst[dst_idx + dst_offset] = static_cast<U>(src[src_idx + src_offset]);
}
|
|
244
|
+
|
|
245
|
+
// N-dimensional copy between a strided source and a strided destination,
// with the base pointers first advanced by src_offset and dst_offset.
// elem_to_loc_2_nd resolves both locations at once; its .x component is used
// as the source location and its .y component as the destination location.
//
// With N == 1 each thread copies one element. Otherwise each thread copies up
// to N consecutive elements along the last axis, stepping source and
// destination by their last strides and stopping at the end of that axis
// (src_shape[ndim - 1]).
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint3 index [[thread_position_in_grid]]) {
  // Apply the offsets once to the base pointers.
  src += src_offset;
  dst += dst_offset;
  auto idx = elem_to_loc_2_nd<IdxT>(
      {N * index.x, index.y, index.z},
      src_shape,
      src_strides,
      dst_strides,
      ndim);
  if (N == 1) {
    // Convert explicitly to U, as the non-dynamic copy kernels do, instead of
    // relying on an implicit T -> U conversion at the assignment.
    dst[idx.y] = static_cast<U>(src[idx.x]);
    return;
  }
  IdxT src_xstride = src_strides[ndim - 1];
  IdxT dst_xstride = dst_strides[ndim - 1];
  auto xshape = src_shape[ndim - 1];
  // Copy at most N elements along the last axis without running past it.
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    dst[idx.y] = static_cast<U>(src[idx.x]);
    idx.x += src_xstride;
    idx.y += dst_xstride;
  }
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
// MTL_CONST expands to the `constant` address-space qualifier when either
// __METAL__ or MLX_METAL_JIT is defined, and to nothing otherwise, so the
// declarations below also compile when neither macro is defined.
#if defined(__METAL__) || defined(MLX_METAL_JIT)
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

// Compile-time integer constants.
// NOTE(review): judging by the names these are tuning parameters for the
// reduce, softmax and RMS kernels; their users are not visible here.
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
|
|
17
|
+
|
|
18
|
+
// Instantiate a templated kernel.
|
|
19
|
+
// Extra args are used as template parameters:
|
|
20
|
+
// e.g. instantiate_kernel(binary_int, binary, a, b) ->
|
|
21
|
+
// [[host_name(binary_int)]] [[kernel]] binary<a, b>
|
|
22
|
+
// Explicitly instantiates func<...> and exposes it to the host under `name`.
#define instantiate_kernel(name, func, ...)  \
  template [[host_name(name)]] [[kernel]]    \
  decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
#include <metal_math>
|
|
5
|
+
|
|
6
|
+
/*
|
|
7
|
+
* Approximation to the error function.
|
|
8
|
+
* Based on code from:
|
|
9
|
+
* https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
|
|
10
|
+
*/
|
|
11
|
+
float erf(float a) {
  const float mag = metal::abs(a);
  const float sq = a * a;

  if (mag > 0.927734375f) {
    // Large-argument branch: evaluate a polynomial in |a| and return
    // 1 - exp(poly) carrying the sign of a.
    // maximum error 0.99527 ulp
    const float odd = metal::fma(
        -1.72853470e-5f, mag, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    const float even = metal::fma(
        -3.88396438e-3f, mag, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    float poly = metal::fma(odd, sq, even);
    poly = metal::fma(poly, mag, -1.06777877e-1f); // -0x1.b55cb8p-4
    poly = metal::fma(poly, mag, -6.34846687e-1f); // -0x1.450aa0p-1
    poly = metal::fma(poly, mag, -1.28717512e-1f); // -0x1.079d0cp-3
    poly = metal::fma(poly, mag, -mag);
    // TODO, replace with expm1 when implemented
    const float tail = 1.0f - metal::exp(poly);
    return metal::copysign(tail, a);
  }

  // Small-argument branch: polynomial in a*a, then a + a * poly.
  // maximum error 0.98929 ulp
  float poly = -5.96761703e-4f; // -0x1.38e000p-11
  poly = metal::fma(poly, sq, 4.99119423e-3f); // 0x1.471a58p-8
  poly = metal::fma(poly, sq, -2.67681349e-2f); // -0x1.b691b2p-6
  poly = metal::fma(poly, sq, 1.12819925e-1f); // 0x1.ce1c44p-4
  poly = metal::fma(poly, sq, -3.76125336e-1f); // -0x1.812700p-2
  poly = metal::fma(poly, sq, 1.28379166e-1f); // 0x1.06eba8p-3
  return metal::fma(poly, a, a);
}
|
|
41
|
+
|
|
42
|
+
// Approximation to the inverse error function: a polynomial in
// log(1 - a*a), selected by the magnitude of that logarithm, scaled by a.
float erfinv(float a) {
  // 1 - a*a is formed with a single fma before taking the logarithm.
  const float w = metal::log(metal::fma(a, 0.0f - a, 1.0f));

  if (metal::abs(w) > 6.125f) { // maximum ulp error = 2.35793
    float poly = 3.03697567e-10f; // 0x1.4deb44p-32
    poly = metal::fma(poly, w, 2.93243101e-8f); // 0x1.f7c9aep-26
    poly = metal::fma(poly, w, 1.22150334e-6f); // 0x1.47e512p-20
    poly = metal::fma(poly, w, 2.84108955e-5f); // 0x1.dca7dep-16
    poly = metal::fma(poly, w, 3.93552968e-4f); // 0x1.9cab92p-12
    poly = metal::fma(poly, w, 3.02698812e-3f); // 0x1.8cc0dep-9
    poly = metal::fma(poly, w, 4.83185798e-3f); // 0x1.3ca920p-8
    poly = metal::fma(poly, w, -2.64646143e-1f); // -0x1.0eff66p-2
    poly = metal::fma(poly, w, 8.40016484e-1f); // 0x1.ae16a4p-1
    return a * poly;
  }

  // maximum ulp error = 2.35002
  float poly = 5.43877832e-9f; // 0x1.75c000p-28
  poly = metal::fma(poly, w, 1.43285448e-7f); // 0x1.33b402p-23
  poly = metal::fma(poly, w, 1.22774793e-6f); // 0x1.499232p-20
  poly = metal::fma(poly, w, 1.12963626e-7f); // 0x1.e52cd2p-24
  poly = metal::fma(poly, w, -5.61530760e-5f); // -0x1.d70bd0p-15
  poly = metal::fma(poly, w, -1.47697632e-4f); // -0x1.35be90p-13
  poly = metal::fma(poly, w, 2.31468678e-3f); // 0x1.2f6400p-9
  poly = metal::fma(poly, w, 1.15392581e-2f); // 0x1.7a1e50p-7
  poly = metal::fma(poly, w, -2.32015476e-1f); // -0x1.db2aeep-3
  poly = metal::fma(poly, w, 8.86226892e-1f); // 0x1.c5bf88p-1
  return a * poly;
}
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_math>
|
|
6
|
+
|
|
7
|
+
// Original license copied below:
|
|
8
|
+
// Copyright (c) 2015-2023 Norbert Juffa
|
|
9
|
+
// All rights reserved.
|
|
10
|
+
//
|
|
11
|
+
// Redistribution and use in source and binary forms, with or without
|
|
12
|
+
// modification, are permitted provided that the following conditions
|
|
13
|
+
// are met:
|
|
14
|
+
//
|
|
15
|
+
// 1. Redistributions of source code must retain the above copyright
|
|
16
|
+
// notice, this list of conditions and the following disclaimer.
|
|
17
|
+
//
|
|
18
|
+
// 2. Redistributions in binary form must reproduce the above copyright
|
|
19
|
+
// notice, this list of conditions and the following disclaimer in the
|
|
20
|
+
// documentation and/or other materials provided with the distribution.
|
|
21
|
+
//
|
|
22
|
+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
23
|
+
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
24
|
+
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
25
|
+
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
26
|
+
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
27
|
+
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
28
|
+
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
29
|
+
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
30
|
+
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
31
|
+
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
32
|
+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
33
|
+
|
|
34
|
+
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
|
|
35
|
+
|
|
36
|
+
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
|
|
37
|
+
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
|
|
38
|
+
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
|
|
39
|
+
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
|
|
40
|
+
|
|
41
|
+
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
|
|
42
|
+
*/
|
|
43
|
+
float expm1f_scaled_unchecked(float a, float b) {
|
|
44
|
+
float f, j, r, s, t, u, v, x, y;
|
|
45
|
+
int i;
|
|
46
|
+
|
|
47
|
+
// exp(a) = 2**i * exp(f); i = rintf (a / log(2))
|
|
48
|
+
j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
|
|
49
|
+
j = j - 12582912.0f; // 0x1.8p23
|
|
50
|
+
i = (int)j;
|
|
51
|
+
f = fma(j, -6.93145752e-1f, a);
|
|
52
|
+
|
|
53
|
+
// approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
|
|
54
|
+
s = f * f;
|
|
55
|
+
if (a == 0.0f)
|
|
56
|
+
s = a; // ensure -0 is passed through
|
|
57
|
+
// err = 0.997458 ulp1 = 11081805
|
|
58
|
+
r = 1.97350979e-4f; // 0x1.9de000p-13
|
|
59
|
+
r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
|
|
60
|
+
r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
|
|
61
|
+
r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
|
|
62
|
+
r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
|
|
63
|
+
r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
|
|
64
|
+
u = (j == 1) ? (f + 0.5f) : f;
|
|
65
|
+
v = fma(r, s, u);
|
|
66
|
+
s = 0.5f * b;
|
|
67
|
+
t = ldexp(s, i);
|
|
68
|
+
y = t - s;
|
|
69
|
+
x = (t - y) - s; // double-float canonicalization of difference
|
|
70
|
+
r = fma(v, t, x) + y;
|
|
71
|
+
r = r + r;
|
|
72
|
+
if (j == 0)
|
|
73
|
+
r = v;
|
|
74
|
+
if (j == 1)
|
|
75
|
+
r = v + v;
|
|
76
|
+
return r;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
|
|
80
|
+
float expm1f(float a) {
  /* handle severe overflow and underflow: for arguments this far from 1 the
     result is formed as 2**a * 2**a - 1 instead of using the scaled kernel */
  if (abs(a - 1.0f) > 88.0f) {
    float p = pow(2, a);
    return fma(p, p, -1.0f);
  }

  /* regular range: unscaled kernel (scale factor 1) */
  return expm1f_scaled_unchecked(a, 1.0f);
}
|