mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/complex.h
@@ -0,0 +1,173 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <metal_stdlib>
+
+ using namespace metal;
+
+ struct complex64_t;
+
+ template <typename T>
+ static constexpr constant bool can_convert_to_complex64 =
+     !is_same_v<T, complex64_t> && is_convertible_v<T, float>;
+
+ template <typename T>
+ static constexpr constant bool can_convert_from_complex64 =
+     !is_same_v<T, complex64_t> &&
+     (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
+
+ struct complex64_t {
+   float real;
+   float imag;
+
+   // Constructors
+   constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
+   constexpr complex64_t() : real(0), imag(0) {};
+   constexpr complex64_t() threadgroup : real(0), imag(0) {};
+
+   // Conversions to complex64_t
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_to_complex64<T>>::type>
+   constexpr complex64_t(T x) thread : real(x), imag(0) {}
+
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_to_complex64<T>>::type>
+   constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
+
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_to_complex64<T>>::type>
+   constexpr complex64_t(T x) device : real(x), imag(0) {}
+
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_to_complex64<T>>::type>
+   constexpr complex64_t(T x) constant : real(x), imag(0) {}
+
+   // Conversions from complex64_t
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_from_complex64<T>>::type>
+   constexpr operator T() const thread {
+     return static_cast<T>(real);
+   }
+
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_from_complex64<T>>::type>
+   constexpr operator T() const threadgroup {
+     return static_cast<T>(real);
+   }
+
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_from_complex64<T>>::type>
+   constexpr operator T() const device {
+     return static_cast<T>(real);
+   }
+
+   template <
+       typename T,
+       typename = typename enable_if<can_convert_from_complex64<T>>::type>
+   constexpr operator T() const constant {
+     return static_cast<T>(real);
+   }
+ };
+
+ constexpr complex64_t operator-(complex64_t x) {
+   return {-x.real, -x.imag};
+ }
+
+ constexpr bool operator>=(complex64_t a, complex64_t b) {
+   return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
+ }
+
+ constexpr bool operator>(complex64_t a, complex64_t b) {
+   return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
+ }
+
+ constexpr bool operator<=(complex64_t a, complex64_t b) {
+   return operator>=(b, a);
+ }
+
+ constexpr bool operator<(complex64_t a, complex64_t b) {
+   return operator>(b, a);
+ }
+
+ constexpr bool operator==(complex64_t a, complex64_t b) {
+   return a.real == b.real && a.imag == b.imag;
+ }
+
+ constexpr complex64_t operator+(complex64_t a, complex64_t b) {
+   return {a.real + b.real, a.imag + b.imag};
+ }
+
+ constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
+   a.real += b.real;
+   a.imag += b.imag;
+   return a;
+ }
+
+ constexpr threadgroup complex64_t& operator+=(
+     threadgroup complex64_t& a,
+     complex64_t b) {
+   a.real += b.real;
+   a.imag += b.imag;
+   return a;
+ }
+
+ constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
+   a.real += b.real;
+   a.imag += b.imag;
+   return a;
+ }
+
+ constexpr complex64_t operator+(float a, complex64_t b) {
+   return {a + b.real, b.imag};
+ }
+ constexpr complex64_t operator+(complex64_t a, float b) {
+   return {a.real + b, a.imag};
+ }
+
+ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
+   return {a.real - b.real, a.imag - b.imag};
+ }
+ constexpr complex64_t operator-(float a, complex64_t b) {
+   return {a - b.real, -b.imag};
+ }
+ constexpr complex64_t operator-(complex64_t a, float b) {
+   return {a.real - b, a.imag};
+ }
+
+ constexpr complex64_t operator*(complex64_t a, complex64_t b) {
+   return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
+ }
+
+ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
+   auto denom = b.real * b.real + b.imag * b.imag;
+   auto x = a.real * b.real + a.imag * b.imag;
+   auto y = a.imag * b.real - a.real * b.imag;
+   return {x / denom, y / denom};
+ }
+
+ constexpr complex64_t operator/(float a, complex64_t b) {
+   auto denom = b.real * b.real + b.imag * b.imag;
+   auto x = a * b.real;
+   auto y = -a * b.imag;
+   return {x / denom, y / denom};
+ }
+
+ constexpr complex64_t operator%(complex64_t a, complex64_t b) {
+   auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
+   auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
+   if (real != 0 && (real < 0 != b.real < 0)) {
+     real += b.real;
+   }
+   if (imag != 0 && (imag < 0 != b.imag < 0)) {
+     imag += b.imag;
+   }
+   return {real, imag};
+ }
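The comparison operators above order complex numbers lexicographically (real part first, imaginary part as a tie-break), and operator% applies a floored, Python-style modulo to each component independently: truncate the quotient toward zero, then shift any nonzero remainder into the sign of the divisor. A minimal host-side C++ sketch of those two semantics, useful as a reference when checking kernel output; the c64 struct and helper names here are illustrative, not part of the package:

#include <cstdint>
#include <cstdio>

// Host-side stand-in for the kernel's complex64_t (illustrative only).
struct c64 {
  float real;
  float imag;
};

// Componentwise floored modulo, mirroring the kernel's operator%.
float floored_mod(float a, float b) {
  float r = a - b * static_cast<int64_t>(a / b); // truncated remainder
  if (r != 0 && ((r < 0) != (b < 0))) {
    r += b; // shift into the sign of the divisor
  }
  return r;
}

// Lexicographic "less than": real part first, imaginary as tie-break.
bool less_than(c64 a, c64 b) {
  return (a.real < b.real) || (a.real == b.real && a.imag < b.imag);
}

int main() {
  c64 a{-7.0f, 5.0f};
  c64 b{3.0f, -2.0f};
  std::printf("a %% b = (%g, %g)\n",
              floored_mod(a.real, b.real),  // -7 mod 3  -> 2
              floored_mod(a.imag, b.imag)); //  5 mod -2 -> -1
  std::printf("a < b: %d\n", less_than(a, b)); // 1
  return 0;
}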
mlx/include/mlx/backend/metal/kernels/copy.h
@@ -0,0 +1,276 @@
+ // Copyright © 2024 Apple Inc.
+
+ template <typename T, typename U, int N = WorkPerThread<U>::n>
+ [[kernel]] void copy_s(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant uint& size,
+     uint index [[thread_position_in_grid]]) {
+   index *= N;
+   if (N > 1 && index + N > size) {
+     for (int i = 0; index + i < size; ++i) {
+       dst[index + i] = static_cast<U>(src[0]);
+     }
+   } else {
+     for (int i = 0; i < N; ++i) {
+       dst[index + i] = static_cast<U>(src[0]);
+     }
+   }
+ }
+
+ template <typename T, typename U, int N = WorkPerThread<U>::n>
+ [[kernel]] void copy_v(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant uint& size,
+     uint index [[thread_position_in_grid]]) {
+   index *= N;
+   if (N > 1 && index + N > size) {
+     for (int i = 0; index + i < size; ++i) {
+       dst[index + i] = static_cast<U>(src[index + i]);
+     }
+   } else {
+     for (int i = 0; i < N; ++i) {
+       dst[index + i] = static_cast<U>(src[index + i]);
+     }
+   }
+ }
+
+ template <typename T, typename U, int N = WorkPerThread<U>::n>
+ [[kernel]] void copy_s2(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant int64_t& size,
+     uint2 index [[thread_position_in_grid]],
+     uint2 grid_dim [[threads_per_grid]]) {
+   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+   if (N > 1 && offset + N > size) {
+     for (int i = 0; offset + i < size; ++i) {
+       dst[offset + i] = static_cast<U>(src[0]);
+     }
+   } else {
+     for (int i = 0; i < N; ++i) {
+       dst[offset + i] = static_cast<U>(src[0]);
+     }
+   }
+ }
+
+ template <typename T, typename U, int N = WorkPerThread<U>::n>
+ [[kernel]] void copy_v2(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant int64_t& size,
+     uint2 index [[thread_position_in_grid]],
+     uint2 grid_dim [[threads_per_grid]]) {
+   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+   if (N > 1 && offset + N > size) {
+     for (int i = 0; offset + i < size; ++i) {
+       dst[offset + i] = static_cast<U>(src[offset + i]);
+     }
+   } else {
+     for (int i = 0; i < N; ++i) {
+       dst[offset + i] = static_cast<U>(src[offset + i]);
+     }
+   }
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_g_nd1(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t& src_stride [[buffer(3)]],
+     uint index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
+   dst[index] = static_cast<U>(src[src_idx]);
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_g_nd2(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     uint2 index [[thread_position_in_grid]],
+     uint2 grid_dim [[threads_per_grid]]) {
+   auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
+   IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
+   dst[dst_idx] = static_cast<U>(src[src_idx]);
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_g_nd3(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     uint3 index [[thread_position_in_grid]],
+     uint3 grid_dim [[threads_per_grid]]) {
+   auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
+   IdxT dst_idx =
+       index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
+   dst[dst_idx] = static_cast<U>(src[src_idx]);
+ }
+
+ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
+ [[kernel]] void copy_g(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int* src_shape [[buffer(2)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int& ndim [[buffer(5)]],
+     uint3 index [[thread_position_in_grid]],
+     uint3 grid_dim [[threads_per_grid]]) {
+   auto src_idx = elem_to_loc<IdxT>(
+       {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
+   if (N == 1) {
+     IdxT dst_idx =
+         index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
+     dst[dst_idx] = static_cast<U>(src[src_idx]);
+     return;
+   }
+   auto xshape = src_shape[ndim - 1];
+   IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
+   auto src_xstride = src_strides[ndim - 1];
+   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+     dst[dst_idx + i] = static_cast<U>(src[src_idx]);
+     src_idx += src_xstride;
+   }
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_nd1(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t& src_stride [[buffer(3)]],
+     constant const int64_t& dst_stride [[buffer(4)]],
+     uint index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
+   auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
+   dst[dst_idx] = static_cast<U>(src[src_idx]);
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_nd2(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int64_t* dst_strides [[buffer(4)]],
+     uint2 index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
+   auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
+   dst[dst_idx] = static_cast<U>(src[src_idx]);
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_nd3(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int64_t* dst_strides [[buffer(4)]],
+     uint3 index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
+   auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
+   dst[dst_idx] = static_cast<U>(src[src_idx]);
+ }
+
+ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
+ [[kernel]] void copy_gg(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int* src_shape [[buffer(2)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int64_t* dst_strides [[buffer(4)]],
+     constant const int& ndim [[buffer(5)]],
+     uint3 index [[thread_position_in_grid]]) {
+   auto idx = elem_to_loc_2_nd<IdxT>(
+       {N * index.x, index.y, index.z},
+       src_shape,
+       src_strides,
+       dst_strides,
+       ndim);
+   if (N == 1) {
+     dst[idx.y] = static_cast<U>(src[idx.x]);
+     return;
+   }
+   IdxT src_xstride = src_strides[ndim - 1];
+   IdxT dst_xstride = dst_strides[ndim - 1];
+   auto xshape = src_shape[ndim - 1];
+   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+     dst[idx.y] = static_cast<U>(src[idx.x]);
+     idx.x += src_xstride;
+     idx.y += dst_xstride;
+   }
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_dynamic_nd1(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t& src_stride [[buffer(3)]],
+     constant const int64_t& dst_stride [[buffer(4)]],
+     constant const int64_t& src_offset [[buffer(6)]],
+     constant const int64_t& dst_offset [[buffer(7)]],
+     uint index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
+   auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
+   dst[dst_idx + dst_offset] = src[src_idx + src_offset];
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_dynamic_nd2(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int64_t* dst_strides [[buffer(4)]],
+     constant const int64_t& src_offset [[buffer(6)]],
+     constant const int64_t& dst_offset [[buffer(7)]],
+     uint2 index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
+   auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
+   dst[dst_idx + dst_offset] = src[src_idx + src_offset];
+ }
+
+ template <typename T, typename U, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_dynamic_nd3(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int64_t* dst_strides [[buffer(4)]],
+     constant const int64_t& src_offset [[buffer(6)]],
+     constant const int64_t& dst_offset [[buffer(7)]],
+     uint3 index [[thread_position_in_grid]]) {
+   auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
+   auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
+   dst[dst_idx + dst_offset] = src[src_idx + src_offset];
+ }
+
+ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
+ [[kernel]] void copy_gg_dynamic(
+     device const T* src [[buffer(0)]],
+     device U* dst [[buffer(1)]],
+     constant const int* src_shape [[buffer(2)]],
+     constant const int64_t* src_strides [[buffer(3)]],
+     constant const int64_t* dst_strides [[buffer(4)]],
+     constant const int& ndim [[buffer(5)]],
+     constant const int64_t& src_offset [[buffer(6)]],
+     constant const int64_t& dst_offset [[buffer(7)]],
+     uint3 index [[thread_position_in_grid]]) {
+   src += src_offset;
+   dst += dst_offset;
+   auto idx = elem_to_loc_2_nd<IdxT>(
+       {N * index.x, index.y, index.z},
+       src_shape,
+       src_strides,
+       dst_strides,
+       ndim);
+   if (N == 1) {
+     dst[idx.y] = src[idx.x];
+     return;
+   }
+   IdxT src_xstride = src_strides[ndim - 1];
+   IdxT dst_xstride = dst_strides[ndim - 1];
+   auto xshape = src_shape[ndim - 1];
+   for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+     dst[idx.y] = src[idx.x];
+     idx.x += src_xstride;
+     idx.y += dst_xstride;
+   }
+ }
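The contiguous kernels above process N elements per thread (N = WorkPerThread<U>::n) and take a bounds-checked slow path only when a thread's chunk straddles the end of the buffer. A plain C++ emulation of copy_v's indexing, with the GPU grid replaced by a loop over thread IDs; WorkPerThread is Metal-side, so the fixed N = 4 below is an assumption for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

template <typename T, typename U, int N = 4> // N plays WorkPerThread<U>::n
void copy_v_host(const T* src, U* dst, uint32_t size) {
  uint32_t num_threads = (size + N - 1) / N;
  for (uint32_t tid = 0; tid < num_threads; ++tid) { // one iteration per GPU thread
    uint32_t index = tid * N;
    if (N > 1 && index + N > size) {
      for (int i = 0; index + i < size; ++i) { // tail: partial chunk
        dst[index + i] = static_cast<U>(src[index + i]);
      }
    } else {
      for (int i = 0; i < N; ++i) { // fast path: full chunk, no bounds checks
        dst[index + i] = static_cast<U>(src[index + i]);
      }
    }
  }
}

int main() {
  std::vector<float> src = {1, 2, 3, 4, 5, 6, 7};
  std::vector<double> dst(src.size());
  copy_v_host<float, double>(src.data(), dst.data(), src.size());
  assert(dst[6] == 7.0); // last element handled by the partial-chunk branch
  return 0;
}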
mlx/include/mlx/backend/metal/kernels/defines.h
@@ -0,0 +1,24 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #if defined __METAL__ || defined MLX_METAL_JIT
+ #define MTL_CONST constant
+ #else
+ #define MTL_CONST
+ #endif
+
+ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
+ static MTL_CONST constexpr int REDUCE_N_READS = 4;
+ static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
+ static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
+ static MTL_CONST constexpr int RMS_N_READS = 4;
+ static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
+
+ // Instantiate a templated kernel.
+ // Extra args are used as template parameters:
+ // e.g. instantiate_kernel(binary_int, binary, a, b) ->
+ // [[host_name(binary_int)]] [kernel] binary<a, b>
+ #define instantiate_kernel(name, func, ...) \
+   template [[host_name( \
+       name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
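instantiate_kernel stamps out an explicitly named specialization of a templated kernel so the host can look it up in the compiled library by string. A hypothetical use for the copy_v kernel from copy.h above; the name "copy_v_float16" is an invented example, not an instantiation taken from the package:

// Expands to:
//   template [[host_name("copy_v_float16")]] [[kernel]]
//       decltype(copy_v<float, half>) copy_v<float, half>;
instantiate_kernel("copy_v_float16", copy_v, float, half)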
mlx/include/mlx/backend/metal/kernels/erf.h
@@ -0,0 +1,69 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+ #include <metal_math>
+
+ /*
+  * Approximation to the error function.
+  * Based on code from:
+  * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
+  */
+ float erf(float a) {
+   float r, s, t, u;
+   t = metal::abs(a);
+   s = a * a;
+   if (t > 0.927734375f) {
+     // maximum error 0.99527 ulp
+     r = metal::fma(
+         -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
+     u = metal::fma(
+         -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
+     r = metal::fma(r, s, u);
+     r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
+     r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
+     r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
+     r = metal::fma(r, t, -t);
+     // TODO, replace with expm1 when implemented
+     r = 1.0f - metal::exp(r);
+     r = metal::copysign(r, a);
+   } else {
+     // maximum error 0.98929 ulp
+     r = -5.96761703e-4f; // -0x1.38e000p-11
+     r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
+     r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
+     r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
+     r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
+     r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
+     r = metal::fma(r, a, a);
+   }
+   return r;
+ }
+
+ float erfinv(float a) {
+   auto t = metal::fma(a, 0.0f - a, 1.0f);
+   t = metal::log(t);
+   float p;
+   if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
+     p = 3.03697567e-10f; // 0x1.4deb44p-32
+     p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
+     p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
+     p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
+     p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
+     p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
+     p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
+     p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
+     p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
+   } else { // maximum ulp error = 2.35002
+     p = 5.43877832e-9f; // 0x1.75c000p-28
+     p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
+     p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
+     p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
+     p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
+     p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
+     p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
+     p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
+     p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
+     p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
+   }
+   return a * p;
+ }
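For reference, the function both branches approximate and the structure they use (standard definitions, not taken from the package):

\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-u^2}\, du, \qquad \operatorname{erf}(-x) = -\operatorname{erf}(x)

For |a| \le 0.927734375 the else-branch returns a + a \, P(s) with P the six-coefficient polynomial in s = a^2; for larger |a| the first branch evaluates \operatorname{erf}(t) \approx 1 - e^{Q(t)} at t = |a| and restores the sign with metal::copysign, relying on the odd symmetry above. erfinv works the same way on t = \log(1 - a^2), switching polynomials at |t| = 6.125.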
mlx/include/mlx/backend/metal/kernels/expm1f.h
@@ -0,0 +1,90 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <metal_math>
+
+ // Original license copied below:
+ // Copyright (c) 2015-2023 Norbert Juffa
+ // All rights reserved.
+ //
+ // Redistribution and use in source and binary forms, with or without
+ // modification, are permitted provided that the following conditions
+ // are met:
+ //
+ // 1. Redistributions of source code must retain the above copyright
+ //    notice, this list of conditions and the following disclaimer.
+ //
+ // 2. Redistributions in binary form must reproduce the above copyright
+ //    notice, this list of conditions and the following disclaimer in the
+ //    documentation and/or other materials provided with the distribution.
+ //
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ /* Compute exponential base e minus 1. Maximum ulp error = 0.997458
+
+    i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
+    Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
+    With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
+    when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
+
+    NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
+ */
+ float expm1f_scaled_unchecked(float a, float b) {
+   float f, j, r, s, t, u, v, x, y;
+   int i;
+
+   // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
+   j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
+   j = j - 12582912.0f; // 0x1.8p23
+   i = (int)j;
+   f = fma(j, -6.93145752e-1f, a);
+
+   // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
+   s = f * f;
+   if (a == 0.0f)
+     s = a; // ensure -0 is passed through
+   // err = 0.997458 ulp1 = 11081805
+   r = 1.97350979e-4f; // 0x1.9de000p-13
+   r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
+   r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
+   r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
+   r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
+   r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
+   u = (j == 1) ? (f + 0.5f) : f;
+   v = fma(r, s, u);
+   s = 0.5f * b;
+   t = ldexp(s, i);
+   y = t - s;
+   x = (t - y) - s; // double-float canonicalization of difference
+   r = fma(v, t, x) + y;
+   r = r + r;
+   if (j == 0)
+     r = v;
+   if (j == 1)
+     r = v + v;
+   return r;
+ }
+
+ /* Compute exponential base e minus 1. max ulp err = 0.99746 */
+ float expm1f(float a) {
+   float r;
+
+   r = expm1f_scaled_unchecked(a, 1.0f);
+   /* handle severe overflow and underflow */
+   if (abs(a - 1.0f) > 88.0f) {
+     r = pow(2, a);
+     r = fma(r, r, -1.0f);
+   }
+   return r;
+ }
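A host-side C++ check of the range-reduction identity described in the header comment above; this uses the C standard library rather than the Metal kernel, and is only a sketch of the math:

#include <cmath>
#include <cstdio>

// Verify: i = rint(a / log 2), f = a - i*log 2,
// expm1(a) = 2^i * (expm1(f) + 1) - 1.
double expm1_decomposed(double a) {
  const double ln2 = 0.6931471805599453;
  double j = std::rint(a / ln2);
  double f = a - j * ln2; // reduced argument, |f| <= log(2)/2
  return std::ldexp(std::expm1(f) + 1.0, static_cast<int>(j)) - 1.0;
}

int main() {
  for (double a : {-5.0, -0.3, 0.0, 0.7, 10.0}) {
    std::printf("a=%5.2f  decomposed=%.9g  std::expm1=%.9g\n",
                a, expm1_decomposed(a), std::expm1(a));
  }
  return 0;
}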