mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_integer>
|
|
6
|
+
#include <metal_math>
|
|
7
|
+
|
|
8
|
+
// Element-wise sum functor: yields x + y converted back to T.
struct Add {
  template <typename T>
  T operator()(T x, T y) {
    T sum = x + y;
    return sum;
  }
};
|
|
14
|
+
|
|
15
|
+
// Floor division functor: the quotient x / y rounded toward negative infinity.
// This matches the floored-modulo convention of `Remainder` (a nonzero
// remainder takes the sign of y), so the pair used by `DivMod` satisfies
// x == q * y + r. Division by zero is not guarded, same as the builtin `/`.
struct FloorDivide {
  // Unsigned integers (including bool): truncation already equals flooring.
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x / y;
  }

  // Signed integers: the builtin `/` truncates toward zero, so step the
  // quotient down by one when the division is inexact and the signs differ
  // (e.g. -7 / 2 -> -4, not -3).
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    T q = x / y;
    if ((x % y) != 0 && ((x < 0) != (y < 0))) {
      q -= 1;
    }
    return q;
  }

  // Non-integral types without a dedicated specialization below
  // (e.g. complex64_t): plain division.
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    return x / y;
  }

  // Floating-point types: floor (not trunc) the true quotient so that
  // negative results round toward -inf.
  template <>
  float operator()(float x, float y) {
    return floor(x / y);
  }
  template <>
  half operator()(half x, half y) {
    return floor(x / y);
  }
  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return floor(x / y);
  }
};
|
|
33
|
+
|
|
34
|
+
// Element-wise true-division functor: yields x / y converted back to T.
struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    T quotient = x / y;
    return quotient;
  }
};
|
|
40
|
+
|
|
41
|
+
// Element-wise remainder functor. For signed integers and non-integral real
// types, a nonzero remainder whose sign differs from y's is shifted by y, so
// the result takes the sign of the divisor. complex64_t defers to its own `%`.
struct Remainder {
  // Unsigned integers (including bool): the builtin % is used directly.
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x % y;
  }

  // Signed integers: adjust the truncated remainder toward the divisor's sign.
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    auto rem = x % y;
    if (rem == 0 || (rem < 0) == (y < 0)) {
      return rem;
    }
    rem += y;
    return rem;
  }

  // Non-integral types: same sign adjustment applied to fmod's result.
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T rem = fmod(x, y);
    if (rem == 0 || (rem < 0) == (y < 0)) {
      return rem;
    }
    rem += y;
    return rem;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};
|
|
69
|
+
|
|
70
|
+
// Element-wise equality predicate: true when x == y.
struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    const bool same = x == y;
    return same;
  }
};
|
|
76
|
+
|
|
77
|
+
// Equality predicate that treats NaN as equal to NaN.
// Real types: equal if x == y or both are NaN.
// complex64_t: equal if x == y, or if each component pair is either equal or
// NaN on both sides.
struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    if (x == y) {
      return true;
    }
    return metal::isnan(x) && metal::isnan(y);
  }

  template <>
  bool operator()(complex64_t x, complex64_t y) {
    if (x == y) {
      return true;
    }
    const bool real_both_nan = metal::isnan(x.real) && metal::isnan(y.real);
    const bool imag_both_nan = metal::isnan(x.imag) && metal::isnan(y.imag);
    return (real_both_nan && imag_both_nan) ||
        (x.real == y.real && imag_both_nan) ||
        (real_both_nan && x.imag == y.imag);
  }
};
|
|
91
|
+
|
|
92
|
+
// Element-wise comparison predicate: true when x > y.
struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    const bool result = x > y;
    return result;
  }
};
|
|
98
|
+
|
|
99
|
+
// Element-wise comparison predicate: true when x >= y.
struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    const bool result = x >= y;
    return result;
  }
};
|
|
105
|
+
|
|
106
|
+
// Element-wise comparison predicate: true when x < y.
struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    const bool result = x < y;
    return result;
  }
};
|
|
112
|
+
|
|
113
|
+
// Element-wise comparison predicate: true when x <= y.
struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    const bool result = x <= y;
    return result;
  }
};
|
|
119
|
+
|
|
120
|
+
// Numerically stable log(exp(x) + exp(y)).
// Real types: NaN if either input is NaN; if the smaller input is -inf or the
// larger is +inf, the larger input is returned as-is; otherwise
// max + log1p(exp(min - max)).
// complex64_t: operands are ordered with complex64_t's `<` / `>`, and
// exp(min - max) is formed from its magnitude and phase before log1p.
struct LogAddExp {
  template <typename T>
  T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits<T>::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits<T>::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    // Infinite operands skip the log1p path, which would compute inf - inf.
    return (minval == -inf || maxval == inf)
        ? maxval
        : (maxval + log1p(metal::exp(minval - maxval)));
  }

  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
        metal::isnan(y.imag)) {
      // NaN in both components, as Power does for NaN inputs, rather than
      // implicitly converting a scalar float NaN to complex64_t.
      auto nan = metal::numeric_limits<float>::quiet_NaN();
      return {nan, nan};
    }
    constexpr float inf = metal::numeric_limits<float>::infinity();
    complex64_t maxval = x > y ? x : y;
    complex64_t minval = x < y ? x : y;
    if (minval.real == -inf || maxval.real == inf) {
      return maxval;
    }
    // exp(minval - maxval) written in polar form.
    float m = metal::exp(minval.real - maxval.real);
    complex64_t dexp{
        m * metal::cos(minval.imag - maxval.imag),
        m * metal::sin(minval.imag - maxval.imag),
    };
    return maxval + log1p(dexp);
  }
};
|
|
152
|
+
|
|
153
|
+
// Element-wise maximum functor.
// Integral types use metal::max. Non-integral types return x when x is NaN
// (or any component is NaN for complex64_t), and otherwise return x if x > y,
// else y.
struct Maximum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    const bool keep_x = metal::isnan(x) || x > y;
    return keep_x ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    const bool x_has_nan = metal::isnan(x.real) || metal::isnan(x.imag);
    return (x_has_nan || x > y) ? x : y;
  }
};
|
|
175
|
+
|
|
176
|
+
// Element-wise minimum functor.
// Integral types use metal::min. Non-integral types return x when x is NaN
// (or any component is NaN for complex64_t), and otherwise return x if x < y,
// else y.
struct Minimum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    const bool keep_x = metal::isnan(x) || x < y;
    return keep_x ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    const bool x_has_nan = metal::isnan(x.real) || metal::isnan(x.imag);
    return (x_has_nan || x < y) ? x : y;
  }
};
|
|
198
|
+
|
|
199
|
+
// Element-wise product functor: yields x * y converted back to T.
struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    T product = x * y;
    return product;
  }
};
|
|
205
|
+
|
|
206
|
+
// Element-wise inequality predicate. complex64_t values differ when either
// component differs.
struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    const bool differ = x != y;
    return differ;
  }

  template <>
  bool operator()(complex64_t x, complex64_t y) {
    const bool same = x.real == y.real && x.imag == y.imag;
    return !same;
  }
};
|
|
216
|
+
|
|
217
|
+
// Element-wise power functor, base^exp.
// - Non-integral types: metal::pow.
// - Integral types: exponentiation by squaring; a negative exponent yields 0.
// - complex64_t: exp(y * log(x)) evaluated in polar form; a zero base yields
//   {0, 0}, or {NaN, NaN} when y has a NaN component.
struct Power {
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    // Raising an integer to a negative power is undefined; return 0.
    if (exp < 0) {
      return 0;
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Square only while exponent bits remain: the squared value after the
      // last bit is never used and could overflow a signed T.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (x.real == 0 && x.imag == 0) {
      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
        auto nan = metal::numeric_limits<float>::quiet_NaN();
        return {nan, nan};
      }
      return {0.0f, 0.0f};
    }
    // log(x) = ln|x| + i * arg(x); ln|x| = 0.5 * ln(re^2 + im^2).
    auto x_theta = metal::atan2(x.imag, x.real);
    auto x_ln_r = 0.5f * metal::log(x.real * x.real + x.imag * x.imag);
    // x^y = exp(y * log(x)) = mag * (cos(phase) + i * sin(phase)).
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};
|
|
257
|
+
|
|
258
|
+
// Element-wise difference functor: yields x - y converted back to T.
struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    T difference = x - y;
    return difference;
  }
};
|
|
264
|
+
|
|
265
|
+
// Element-wise logical AND: the bool result of x && y, converted to T.
struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    const bool both = x && y;
    return both;
  }
};
|
|
271
|
+
|
|
272
|
+
// Element-wise logical OR: the bool result of x || y, converted to T.
struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    const bool either = x || y;
    return either;
  }
};
|
|
278
|
+
|
|
279
|
+
// Element-wise bitwise AND functor: yields x & y converted back to T.
struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    T bits = x & y;
    return bits;
  }
};
|
|
285
|
+
|
|
286
|
+
// Element-wise bitwise OR functor: yields x | y converted back to T.
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    T bits = x | y;
    return bits;
  }
};
|
|
292
|
+
|
|
293
|
+
// Element-wise bitwise XOR functor: yields x ^ y converted back to T.
struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    T bits = x ^ y;
    return bits;
  }
};
|
|
299
|
+
|
|
300
|
+
// Element-wise left shift functor: yields x << y converted back to T.
// The shift amount is not range-checked.
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    T shifted = x << y;
    return shifted;
  }
};
|
|
306
|
+
|
|
307
|
+
// Element-wise right shift functor: yields x >> y converted back to T.
// The shift amount is not range-checked.
struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    T shifted = x >> y;
    return shifted;
  }
};
|
|
313
|
+
|
|
314
|
+
// Element-wise two-argument arctangent via metal::precise::atan2.
// Argument order follows atan2: the first operand is y, the second is x.
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    T angle = metal::precise::atan2(y, x);
    return angle;
  }
};
|
|
320
|
+
|
|
321
|
+
// Two-output functor: returns {FloorDivide(x, y), Remainder(x, y)} as a
// metal::array, with the quotient in slot 0 and the remainder in slot 1.
struct DivMod {
  template <typename T>
  metal::array<T, 2> operator()(T x, T y) {
    T quotient = FloorDivide{}(x, y);
    T remainder = Remainder{}(x, y);
    return {quotient, remainder};
  }
};
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// Scalar-scalar kernel for two-output ops: every thread evaluates
// Op(a[0], b[0]) and writes result [0] to c[index] and [1] to d[index].
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto pair = Op()(a[0], b[0]);
  c[index] = pair[0];
  d[index] = pair[1];
}
|
|
14
|
+
|
|
15
|
+
// Scalar-vector kernel for two-output ops: a[0] is combined with b[i].
// Thread `index` handles the N consecutive elements starting at index * N
// (N defaults to WorkPerThread<T>::n, defined elsewhere); results [0] and [1]
// go to c and d at the same position.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  const uint start = index * N;
  if (N <= 1 || start + N <= size) {
    // Full chunk (N == 1 is not bounds-checked).
    for (int i = 0; i < N; ++i) {
      auto pair = Op()(a[0], b[start + i]);
      c[start + i] = pair[0];
      d[start + i] = pair[1];
    }
    return;
  }
  // Final partial chunk: stop at `size`.
  for (uint j = start; j < size; ++j) {
    auto pair = Op()(a[0], b[j]);
    c[j] = pair[0];
    d[j] = pair[1];
  }
}
|
|
38
|
+
|
|
39
|
+
// Vector-scalar kernel for two-output ops: a[i] is combined with b[0].
// Thread `index` handles the N consecutive elements starting at index * N
// (N defaults to WorkPerThread<T>::n, defined elsewhere); results [0] and [1]
// go to c and d at the same position.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  const uint start = index * N;
  if (N <= 1 || start + N <= size) {
    // Full chunk (N == 1 is not bounds-checked).
    for (int i = 0; i < N; ++i) {
      auto pair = Op()(a[start + i], b[0]);
      c[start + i] = pair[0];
      d[start + i] = pair[1];
    }
    return;
  }
  // Final partial chunk: stop at `size`.
  for (uint j = start; j < size; ++j) {
    auto pair = Op()(a[j], b[0]);
    c[j] = pair[0];
    d[j] = pair[1];
  }
}
|
|
62
|
+
|
|
63
|
+
// Vector-vector kernel for two-output ops: a[i] is combined with b[i].
// Thread `index` handles the N consecutive elements starting at index * N
// (N defaults to WorkPerThread<T>::n, defined elsewhere); results [0] and [1]
// go to c and d at the same position.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  const uint start = index * N;
  if (N <= 1 || start + N <= size) {
    // Full chunk (N == 1 is not bounds-checked).
    for (int i = 0; i < N; ++i) {
      auto pair = Op()(a[start + i], b[start + i]);
      c[start + i] = pair[0];
      d[start + i] = pair[1];
    }
    return;
  }
  // Final partial chunk: stop at `size`.
  for (uint j = start; j < size; ++j) {
    auto pair = Op()(a[j], b[j]);
    c[j] = pair[0];
    d[j] = pair[1];
  }
}
|
|
86
|
+
|
|
87
|
+
// Scalar-vector kernel (2D grid, 64-bit size) for two-output ops: a[0] is
// combined with b[i]. The thread's first element is
// N * (index.x + grid_dim.x * index.y), computed in int64_t; it handles N
// consecutive elements from there, writing results [0] and [1] to c and d.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N <= 1 || offset + N <= size) {
    // Full chunk (N == 1 is not bounds-checked).
    for (int i = 0; i < N; ++i) {
      auto pair = Op()(a[0], b[offset + i]);
      c[offset + i] = pair[0];
      d[offset + i] = pair[1];
    }
    return;
  }
  // Final partial chunk: stop at `size`.
  for (int64_t j = offset; j < size; ++j) {
    auto pair = Op()(a[0], b[j]);
    c[j] = pair[0];
    d[j] = pair[1];
  }
}
|
|
111
|
+
|
|
112
|
+
// Vector-scalar kernel (2D grid, 64-bit size) for two-output ops: a[i] is
// combined with b[0]. The thread's first element is
// N * (index.x + grid_dim.x * index.y), computed in int64_t; it handles N
// consecutive elements from there, writing results [0] and [1] to c and d.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N <= 1 || offset + N <= size) {
    // Full chunk (N == 1 is not bounds-checked).
    for (int i = 0; i < N; ++i) {
      auto pair = Op()(a[offset + i], b[0]);
      c[offset + i] = pair[0];
      d[offset + i] = pair[1];
    }
    return;
  }
  // Final partial chunk: stop at `size`.
  for (int64_t j = offset; j < size; ++j) {
    auto pair = Op()(a[j], b[0]);
    c[j] = pair[0];
    d[j] = pair[1];
  }
}
|
|
136
|
+
|
|
137
|
+
// Vector-vector kernel (2D grid, 64-bit size) for two-output ops: a[i] is
// combined with b[i]. The thread's first element is
// N * (index.x + grid_dim.x * index.y), computed in int64_t; it handles N
// consecutive elements from there, writing results [0] and [1] to c and d.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N <= 1 || offset + N <= size) {
    // Full chunk (N == 1 is not bounds-checked).
    for (int i = 0; i < N; ++i) {
      auto pair = Op()(a[offset + i], b[offset + i]);
      c[offset + i] = pair[0];
      d[offset + i] = pair[1];
    }
    return;
  }
  // Final partial chunk: stop at `size`.
  for (int64_t j = offset; j < size; ++j) {
    auto pair = Op()(a[j], b[j]);
    c[j] = pair[0];
    d[j] = pair[1];
  }
}
|
|
161
|
+
|
|
162
|
+
// General-strided 1D kernel for two-output ops: input offsets come from
// elem_to_loc_1 (defined elsewhere) with each input's stride; outputs are
// written contiguously at c[index] and d[index].
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int64_t& a_stride,
    constant const int64_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_loc = elem_to_loc_1<IdxT>(index, a_stride);
  auto b_loc = elem_to_loc_1<IdxT>(index, b_stride);
  auto pair = Op()(a[a_loc], b[b_loc]);
  c[index] = pair[0];
  d[index] = pair[1];
}
|
|
177
|
+
|
|
178
|
+
// General-strided 2D kernel for two-output ops: input offsets come from
// elem_to_loc_2 (defined elsewhere); outputs are contiguous, row-major over
// the grid, at index.x + grid_dim.x * index.y.
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int64_t a_strides[2],
    constant const int64_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_loc = elem_to_loc_2<IdxT>(index, a_strides);
  auto b_loc = elem_to_loc_2<IdxT>(index, b_strides);
  IdxT dst = index.x + IdxT(grid_dim.x) * index.y;
  auto pair = Op()(a[a_loc], b[b_loc]);
  c[dst] = pair[0];
  d[dst] = pair[1];
}
|
|
195
|
+
|
|
196
|
+
// General-strided 3D kernel for two-output ops: input offsets come from
// elem_to_loc_3 (defined elsewhere); outputs are contiguous, row-major over
// the grid, at index.x + grid_dim.x * (index.y + grid_dim.y * index.z).
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int64_t a_strides[3],
    constant const int64_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_loc = elem_to_loc_3<IdxT>(index, a_strides);
  auto b_loc = elem_to_loc_3<IdxT>(index, b_strides);
  IdxT dst = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
  auto pair = Op()(a[a_loc], b[b_loc]);
  c[dst] = pair[0];
  d[dst] = pair[1];
}
|
|
213
|
+
|
|
214
|
+
// General-strided N-dimensional kernel for two-output ops.
// Thread (x, y, z) starts at last-axis position N * x. elem_to_loc_2_nd
// (defined elsewhere) gives the starting offsets into a (.x) and b (.y).
// The thread then walks up to N elements along the last axis, stepping each
// input by its last-axis stride and stopping at shape[ndim - 1]. Outputs are
// contiguous with row length shape[ndim - 1].
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto in_loc = elem_to_loc_2_nd<IdxT>(
      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
  const int last = ndim - 1;
  auto row_len = shape[last];
  IdxT a_step = a_strides[last];
  IdxT b_step = b_strides[last];
  IdxT dst = N * index.x + row_len * (index.y + IdxT(grid_dim.y) * index.z);
  const int x_begin = int(N * index.x);
  for (int i = 0; i < N && x_begin + i < row_len; ++i) {
    auto pair = Op()(a[in_loc.x], b[in_loc.y]);
    c[dst] = pair[0];
    d[dst] = pair[1];
    ++dst;
    in_loc.x += a_step;
    in_loc.y += b_step;
  }
}
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
// Copyright © 2008-2013 NVIDIA Corporation
|
|
3
|
+
// Copyright © 2013 Filipe RNC Maia
|
|
4
|
+
//
|
|
5
|
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
// you may not use this file except in compliance with the License.
|
|
7
|
+
// You may obtain a copy of the License at
|
|
8
|
+
//
|
|
9
|
+
// http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
//
|
|
11
|
+
// Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
// See the License for the specific language governing permissions and
|
|
15
|
+
// limitations under the License.
|
|
16
|
+
//
|
|
17
|
+
// Forked from
|
|
18
|
+
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
|
19
|
+
|
|
20
|
+
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
|
21
|
+
// cannot be used in JIT.
|
|
22
|
+
|
|
23
|
+
#pragma once
|
|
24
|
+
|
|
25
|
+
#include <metal_math>
|
|
26
|
+
|
|
27
|
+
// Bit-level view of a 32-bit float. Writing one member and reading the other
// reinterprets the IEEE-754 encoding without any numeric conversion.
typedef union {
  float value;   // the float as a number
  uint32_t word; // the same 32 bits as an unsigned integer
} ieee_float_shape_type;
|
|
31
|
+
|
|
32
|
+
// Copies the raw IEEE-754 bit pattern of `d` into `i` (unsigned overload).
inline void get_float_word(thread uint32_t& i, float d) {
  // Brace-initialization sets the first union member, `value`.
  const ieee_float_shape_type pun = {d};
  i = pun.word;
}
|
|
37
|
+
|
|
38
|
+
// Copies the raw IEEE-754 bit pattern of `d` into `i` (signed overload); the
// 32-bit word is converted from uint32_t to int32_t on assignment.
inline void get_float_word(thread int32_t& i, float d) {
  // Brace-initialization sets the first union member, `value`.
  const ieee_float_shape_type pun = {d};
  i = static_cast<int32_t>(pun.word);
}
|
|
43
|
+
|
|
44
|
+
// Overwrites `d` with the float whose IEEE-754 bit pattern is `i`.
inline void set_float_word(thread float& d, uint32_t i) {
  ieee_float_shape_type pun;
  pun.word = i;
  d = pun.value;
}
|
|
49
|
+
|
|
50
|
+
// Evaluates exp(x) in split form for x too large for a plain exp() call.
// Returns r (its biased exponent field forced to 0x7f + 127) and stores *expt
// such that exp(x) == r * 2^(*expt). Assumes exp(x - kln2) is a normal float —
// TODO confirm for all callers (ldexp_cexpf only uses it for x in [~88.7, ~192]).
inline float frexp_expf(float x, thread int* expt) {
  // Pre-scale by 2^-k, with kln2 == k * ln(2), to keep exp() in range.
  const uint32_t k = 235;
  const float kln2 = 162.88958740F;

  float scaled = metal::exp(x - kln2);
  uint32_t bits;
  get_float_word(bits, scaled);

  // Biased exponent of `scaled`, minus the bias re-applied below, plus k.
  *expt = (bits >> 23) - (0x7f + 127) + k;

  // Keep the 23 mantissa bits and replace the exponent field with 0x7f + 127.
  set_float_word(scaled, (bits & 0x7fffff) | ((0x7f + 127) << 23));
  return scaled;
}
|
|
63
|
+
|
|
64
|
+
// Computes exp(z) * 2^expt without intermediate overflow. exp(real(z)) is
// obtained in split form from frexp_expf, and the total power of two is
// applied as two separate factors.
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
  int ex_expt;
  const float exp_x = frexp_expf(z.real, &ex_expt);
  expt += ex_expt;

  // (0x7f + h) << 23 is the bit pattern of 2^h; splitting expt into two halves
  // keeps each factor's exponent field smaller than a single 2^expt would need.
  const int lo_half = expt / 2;
  const int hi_half = expt - lo_half;
  float scale1, scale2;
  set_float_word(scale1, (0x7f + lo_half) << 23);
  set_float_word(scale2, (0x7f + hi_half) << 23);

  const float y = z.imag;
  return complex64_t{
      metal::cos(y) * exp_x * scale1 * scale2,
      metal::sin(y) * exp_x * scale1 * scale2};
}
|
|
82
|
+
|
|
83
|
+
// Complex exponential exp(z) for complex64_t. Special cases (zero parts,
// infinite or NaN imaginary part) are handled explicitly. Large positive real
// parts are routed through ldexp_cexpf to avoid spurious overflow.
inline complex64_t cexpf(const thread complex64_t& z) {
  // Float bit patterns of the thresholds: 0x42b17218 ~= 88.72 (exp(x) overflows
  // above it) and 0x43400074 ~= 192.0 (beyond it scaling cannot help).
  const uint32_t exp_ovfl = 0x42b17218;
  const uint32_t cexp_ovfl = 0x43400074;

  const float x = z.real;
  const float y = z.imag;

  uint32_t hx, hy;
  get_float_word(hy, y);
  hy &= 0x7fffffff; // |y| bits
  get_float_word(hx, x);
  const uint32_t abs_hx = hx & 0x7fffffff; // |x| bits

  // cexp(x + I 0) = exp(x) + I 0 (y is returned as-is, keeping its sign)
  if (hy == 0) {
    return complex64_t{metal::exp(x), y};
  }
  // cexp(0 + I y) = cos(y) + I sin(y)
  if (abs_hx == 0) {
    return complex64_t{metal::cos(y), metal::sin(y)};
  }
  // y is +-Inf or NaN.
  if (hy >= 0x7f800000) {
    if (abs_hx != 0x7f800000) {
      // cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN
      return complex64_t{y - y, y - y};
    }
    // From here x is +-Inf.
    if (hx & 0x80000000) {
      // cexp(-Inf +- I Inf|NaN) = 0 + I 0
      return complex64_t{0.0, 0.0};
    }
    // cexp(+Inf +- I Inf|NaN) = Inf + I NaN
    return complex64_t{x, y - y};
  }

  // hx keeps the sign bit, so negative x (hx >= 0x80000000) never satisfies
  // this unsigned range test: only positive x in [~88.72, ~192.0] is scaled.
  const bool needs_scaling = hx >= exp_ovfl && hx <= cexp_ovfl;
  if (needs_scaling) {
    return ldexp_cexpf(z, 0);
  }

  // Common case, plus x beyond cexp_ovfl (overflows regardless), x = +-Inf,
  // and x = NaN: plain exp() is used directly.
  const float exp_x = metal::exp(x);
  return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
}
|