mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <cstring>

#include "mlx/backend/cpu/simd/type.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core::simd {
|
|
8
|
+
|
|
9
|
+
// Single-precision positive infinity; exp() below uses it to saturate
// results for large inputs.
constexpr float inf{std::numeric_limits<float>::infinity()};
|
|
10
|
+
|
|
11
|
+
/**
|
|
12
|
+
* Compute exp(x) in an optimizer friendly way as follows:
|
|
13
|
+
*
|
|
14
|
+
* First change the problem to computing 2**y where y = x / ln(2).
|
|
15
|
+
*
|
|
16
|
+
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
|
17
|
+
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
|
18
|
+
* shifting and for the fractional part we use a polynomial approximation.
|
|
19
|
+
*
|
|
20
|
+
* The algorithm and constants of the polynomial taken from
|
|
21
|
+
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
|
22
|
+
* from Cephes math library.
|
|
23
|
+
*
|
|
24
|
+
* Note: The implementation below is a general fast exp. There could be faster
|
|
25
|
+
* implementations for numbers strictly < 0.
|
|
26
|
+
*/
|
|
27
|
+
template <typename T, int N>
|
|
28
|
+
Simd<T, N> exp(Simd<T, N> in) {
|
|
29
|
+
if constexpr (is_complex<T>) {
|
|
30
|
+
return Simd<T, 1>{std::exp(in.value)};
|
|
31
|
+
} else {
|
|
32
|
+
Simd<float, N> x_init = in;
|
|
33
|
+
auto x = x_init * 1.442695f; // multiply with log_2(e)
|
|
34
|
+
Simd<float, N> ipart, fpart;
|
|
35
|
+
ipart = floor(x + 0.5);
|
|
36
|
+
fpart = x - ipart;
|
|
37
|
+
|
|
38
|
+
x = 1.535336188319500e-4f;
|
|
39
|
+
x = fma(x, fpart, 1.339887440266574e-3f);
|
|
40
|
+
x = fma(x, fpart, 9.618437357674640e-3f);
|
|
41
|
+
x = fma(x, fpart, 5.550332471162809e-2f);
|
|
42
|
+
x = fma(x, fpart, 2.402264791363012e-1f);
|
|
43
|
+
x = fma(x, fpart, 6.931472028550421e-1f);
|
|
44
|
+
x = fma(x, fpart, 1.000000000000000f);
|
|
45
|
+
|
|
46
|
+
// generate 2**ipart in the floating point representation using integer
|
|
47
|
+
// bitshifting
|
|
48
|
+
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
|
|
49
|
+
|
|
50
|
+
// Deal with NaN and Inf
|
|
51
|
+
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
|
|
52
|
+
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
|
|
53
|
+
result = select(x_init < -88.0f, Simd<float, N>(0), result);
|
|
54
|
+
return Simd<T, N>(result);
|
|
55
|
+
}
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
/* Implementation from:
|
|
59
|
+
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
|
|
60
|
+
* which originally came from the Cephes math library.
|
|
61
|
+
*/
|
|
62
|
+
template <bool Sine, typename T, int N>
|
|
63
|
+
Simd<T, N> sincos(Simd<T, N> in) {
|
|
64
|
+
auto sign_mask_sin = in < 0;
|
|
65
|
+
in = abs(in);
|
|
66
|
+
Simd<float, N> x = in;
|
|
67
|
+
|
|
68
|
+
// scale by 4/Pi
|
|
69
|
+
auto y = x * 1.27323954473516f;
|
|
70
|
+
|
|
71
|
+
// store the integer part of y in mm0
|
|
72
|
+
Simd<uint32_t, N> emm2 = y;
|
|
73
|
+
|
|
74
|
+
// j=(j+1) & (~1) (see the cephes sources)
|
|
75
|
+
emm2 = emm2 + 1;
|
|
76
|
+
emm2 = emm2 & ~1;
|
|
77
|
+
|
|
78
|
+
y = emm2;
|
|
79
|
+
|
|
80
|
+
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
|
|
81
|
+
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
|
|
82
|
+
auto poly_mask = (emm2 & 2) != 0;
|
|
83
|
+
|
|
84
|
+
// The magic pass: "Extended precision modular arithmetic"
|
|
85
|
+
// x = ((x - y * DP1) - y * DP2) - y * DP3
|
|
86
|
+
x = fma(y, Simd<float, N>(-0.78515625f), x);
|
|
87
|
+
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
|
|
88
|
+
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
|
|
89
|
+
|
|
90
|
+
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
|
|
91
|
+
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
|
|
92
|
+
|
|
93
|
+
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
|
|
94
|
+
// and the second polynom (Pi/4 <= x <= 0) in y2
|
|
95
|
+
auto z = x * x;
|
|
96
|
+
|
|
97
|
+
auto y1 =
|
|
98
|
+
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
|
|
99
|
+
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
|
|
100
|
+
y1 = fma(y1, z, 4.166664568298827e-2f);
|
|
101
|
+
y2 = fma(y2, z, -1.6666654611e-1f);
|
|
102
|
+
y1 = y1 * z;
|
|
103
|
+
y2 = y2 * z;
|
|
104
|
+
y1 = y1 * z;
|
|
105
|
+
y2 = fma(x, y2, x);
|
|
106
|
+
y1 = fma(z, Simd<float, N>(-0.5f), y1);
|
|
107
|
+
y1 = y1 + 1.0f;
|
|
108
|
+
|
|
109
|
+
if constexpr (Sine) {
|
|
110
|
+
auto ys = select(poly_mask, y1, y2);
|
|
111
|
+
return select(sign_mask_sin, -ys, ys);
|
|
112
|
+
} else {
|
|
113
|
+
auto yc = select(poly_mask, y2, y1);
|
|
114
|
+
return select(sign_mask_cos, yc, -yc);
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
template <typename T, int N>
|
|
119
|
+
Simd<T, N> sin(Simd<T, N> x) {
|
|
120
|
+
if constexpr (is_complex<T>) {
|
|
121
|
+
return std::sin(x.value);
|
|
122
|
+
} else {
|
|
123
|
+
return sincos<true>(x);
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
template <typename T, int N>
|
|
128
|
+
Simd<T, N> cos(Simd<T, N> x) {
|
|
129
|
+
if constexpr (is_complex<T>) {
|
|
130
|
+
return std::cos(x.value);
|
|
131
|
+
} else {
|
|
132
|
+
return sincos<false>(x);
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
template <typename T, int N>
|
|
137
|
+
Simd<T, N> erf(Simd<T, N> x) {
|
|
138
|
+
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
|
|
139
|
+
Simd<float, N> v = x;
|
|
140
|
+
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
|
|
141
|
+
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
|
|
142
|
+
r = fma(r, t, 1.421413741f);
|
|
143
|
+
r = fma(r, t, -0.284496736f);
|
|
144
|
+
r = fma(r, t, 0.254829592f);
|
|
145
|
+
auto e = -exp(-v * v);
|
|
146
|
+
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
|
|
147
|
+
return select(x > 0, result, -result);
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
template <typename T, int N>
|
|
151
|
+
Simd<T, N> erfinv(Simd<T, N> a_) {
|
|
152
|
+
Simd<float, N> a = a_;
|
|
153
|
+
auto t = fma(a, 0.0f - a, 1.0f);
|
|
154
|
+
t = log(t);
|
|
155
|
+
auto lhs = [](auto t) {
|
|
156
|
+
Simd<float, N> p;
|
|
157
|
+
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
|
158
|
+
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
|
159
|
+
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
|
160
|
+
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
|
161
|
+
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
|
162
|
+
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
|
163
|
+
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
|
164
|
+
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
|
165
|
+
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
|
166
|
+
};
|
|
167
|
+
auto rhs = [](auto t) {
|
|
168
|
+
Simd<float, N> p;
|
|
169
|
+
p = 5.43877832e-9f; // 0x1.75c000p-28
|
|
170
|
+
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
|
171
|
+
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
|
172
|
+
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
|
173
|
+
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
|
174
|
+
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
|
175
|
+
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
|
176
|
+
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
|
177
|
+
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
|
178
|
+
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
|
179
|
+
};
|
|
180
|
+
auto thresh = 6.125f;
|
|
181
|
+
// Compute both branches and select if N > 1
|
|
182
|
+
if constexpr (N == 1) {
|
|
183
|
+
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
|
|
184
|
+
return a * lhs(t);
|
|
185
|
+
} else { // maximum ulp error = 2.35002
|
|
186
|
+
return a * rhs(t);
|
|
187
|
+
}
|
|
188
|
+
} else {
|
|
189
|
+
return a * select(abs(t) > thresh, lhs(t), rhs(t));
|
|
190
|
+
}
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
} // namespace mlx::core::simd
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <arm_neon.h>
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/cpu/simd/base_simd.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core::simd {
|
|
8
|
+
|
|
9
|
+
// Lane count for the float16 specialization below: a float16x8_t NEON
// register holds eight 16-bit lanes.
constexpr int N{8};
|
|
10
|
+
|
|
11
|
+
template <>
|
|
12
|
+
struct Simd<float16_t, N> {
|
|
13
|
+
static constexpr int size = N;
|
|
14
|
+
using scalar_t = float16_t;
|
|
15
|
+
|
|
16
|
+
Simd<float16_t, N>() {}
|
|
17
|
+
|
|
18
|
+
template <typename U>
|
|
19
|
+
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
|
|
20
|
+
|
|
21
|
+
Simd<float16_t, N>(float16x8_t v) : value(v){};
|
|
22
|
+
|
|
23
|
+
Simd<float16_t, N>(Simd<float, N> other) {
|
|
24
|
+
auto f32x4_a = *(float32x4_t*)(&other);
|
|
25
|
+
auto f32x4_b = *((float32x4_t*)(&other) + 1);
|
|
26
|
+
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
|
|
27
|
+
};
|
|
28
|
+
|
|
29
|
+
Simd<float16_t, N>(Simd<uint16_t, N> other) {
|
|
30
|
+
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
|
|
31
|
+
};
|
|
32
|
+
|
|
33
|
+
operator Simd<int16_t, N>() {
|
|
34
|
+
auto v = vcvtq_s16_f16(value);
|
|
35
|
+
return load<int16_t, N>((int16_t*)&v);
|
|
36
|
+
};
|
|
37
|
+
|
|
38
|
+
operator Simd<float, N>() {
|
|
39
|
+
float32x4x2_t v;
|
|
40
|
+
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
|
|
41
|
+
v.val[1] = vcvt_high_f32_f16(value);
|
|
42
|
+
return load<float, N>((float*)&v);
|
|
43
|
+
}
|
|
44
|
+
float16_t operator[](int idx) const {
|
|
45
|
+
return reinterpret_cast<const float16_t*>(&value)[idx];
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
float16_t& operator[](int idx) {
|
|
49
|
+
return reinterpret_cast<float16_t*>(&value)[idx];
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
float16x8_t value;
|
|
53
|
+
};
|
|
54
|
+
|
|
55
|
+
#define DEFINE_NEON_UNARY_OP(name, op) \
|
|
56
|
+
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
|
|
57
|
+
return Simd<float16_t, N>{op(a.value)}; \
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
|
|
61
|
+
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
|
|
62
|
+
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
|
|
63
|
+
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
|
|
64
|
+
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
|
|
65
|
+
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
|
|
66
|
+
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
|
|
67
|
+
|
|
68
|
+
// Generates three overloads of the lane-wise binary operation `name` from the
// NEON intrinsic `op`: vector-vector, vector-T and T-vector. The non-vector
// operand is first converted with the Simd<float16_t, N> constructor
// (a scalar is broadcast to all lanes).
#define DEFINE_NEON_BINARY_OP(name, op)                                        \
  inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
    return op(a.value, b.value);                                               \
  }                                                                            \
  template <typename T>                                                        \
  Simd<float16_t, N> name(Simd<float16_t, N> a, T b) {                         \
    Simd<float16_t, N> rhs(b);                                                 \
    return op(a.value, rhs.value);                                             \
  }                                                                            \
  template <typename T>                                                        \
  Simd<float16_t, N> name(T a, Simd<float16_t, N> b) {                         \
    Simd<float16_t, N> lhs(a);                                                 \
    return op(lhs.value, b.value);                                             \
  }
|
|
80
|
+
|
|
81
|
+
// Lane-wise logical NOT: a lane is set when the input lane equals zero.
//
// vceqzq_f16 yields a uint16x8_t mask (all ones where the lane is zero). The
// mask is loaded lane by lane with load<uint16_t, N>, the same pattern the
// Simd<float16_t, N> conversion operators use. Dereferencing it as a single
// uint16_t would read lane 0 only and broadcast that lane's result
// everywhere. The Simd<uint16_t, N> result is then converted to the float16
// return type through the Simd<float16_t, N>(Simd<uint16_t, N>) constructor.
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
  uint16x8_t mask = vceqzq_f16(v.value);
  return load<uint16_t, N>(reinterpret_cast<uint16_t*>(&mask));
}
|
|
85
|
+
|
|
86
|
+
// Lane-wise negation.
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
  float16x8_t negated = vnegq_f16(v.value);
  return negated;
}
|
|
89
|
+
|
|
90
|
+
// Lane-wise arithmetic; each line expands to the vector-vector, vector-T and
// T-vector overloads.
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)

// Lane-wise extrema.
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
|
|
96
|
+
|
|
97
|
+
#define DEFINE_NEON_COMPARISON(Op, op) \
|
|
98
|
+
template <typename T> \
|
|
99
|
+
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
|
|
100
|
+
auto out = op(a.value, Simd<float16_t, N>(b).value); \
|
|
101
|
+
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
|
102
|
+
} \
|
|
103
|
+
template <typename T> \
|
|
104
|
+
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
|
|
105
|
+
auto out = op(Simd<float16_t, N>(a).value, b.value); \
|
|
106
|
+
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
|
107
|
+
} \
|
|
108
|
+
inline Simd<bool, N> operator Op( \
|
|
109
|
+
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
|
110
|
+
auto out = op(a.value, b.value); \
|
|
111
|
+
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
DEFINE_NEON_COMPARISON(==, vceqq_f16)
|
|
115
|
+
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
|
|
116
|
+
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
|
|
117
|
+
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
|
|
118
|
+
DEFINE_NEON_COMPARISON(<, vcltq_f16)
|
|
119
|
+
|
|
120
|
+
template <typename T>
|
|
121
|
+
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
|
|
122
|
+
return !(a == b);
|
|
123
|
+
}
|
|
124
|
+
template <typename T>
|
|
125
|
+
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
|
|
126
|
+
return !(a == b);
|
|
127
|
+
}
|
|
128
|
+
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
|
|
129
|
+
return !(a == b);
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
// Lane-wise logical OR: each operand is tested against zero, the two results
// are combined with ||, and the combination is routed through
// Simd<uint16_t, N> into the float16 return type.
inline Simd<float16_t, N> operator||(
    Simd<float16_t, N> a,
    Simd<float16_t, N> b) {
  auto lhs_nonzero = a != 0;
  auto rhs_nonzero = b != 0;
  return Simd<uint16_t, N>(lhs_nonzero || rhs_nonzero);
}
template <typename T>
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
  auto lhs_nonzero = a != 0;
  auto rhs_nonzero = b != 0;
  return Simd<uint16_t, N>(lhs_nonzero || rhs_nonzero);
}
template <typename T>
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
  auto lhs_nonzero = a != 0;
  auto rhs_nonzero = b != 0;
  return Simd<uint16_t, N>(lhs_nonzero || rhs_nonzero);
}
|
|
145
|
+
// Lane-wise logical AND: each operand is tested against zero, the two results
// are combined with &&, and the combination is routed through
// Simd<uint16_t, N> into the float16 return type.
inline Simd<float16_t, N> operator&&(
    Simd<float16_t, N> a,
    Simd<float16_t, N> b) {
  auto lhs_nonzero = a != 0;
  auto rhs_nonzero = b != 0;
  return Simd<uint16_t, N>(lhs_nonzero && rhs_nonzero);
}
template <typename T>
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
  auto lhs_nonzero = a != 0;
  auto rhs_nonzero = b != 0;
  return Simd<uint16_t, N>(lhs_nonzero && rhs_nonzero);
}
template <typename T>
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
  auto lhs_nonzero = a != 0;
  auto rhs_nonzero = b != 0;
  return Simd<uint16_t, N>(lhs_nonzero && rhs_nonzero);
}
|
|
158
|
+
|
|
159
|
+
template <>
|
|
160
|
+
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
|
|
161
|
+
return v != v;
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
template <>
|
|
165
|
+
inline Simd<float16_t, N>
|
|
166
|
+
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
|
|
167
|
+
return minimum(maximum(v, min), max);
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
template <typename T>
|
|
171
|
+
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
|
|
172
|
+
return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
template <typename MaskT>
|
|
176
|
+
Simd<float16_t, N>
|
|
177
|
+
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
|
|
178
|
+
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
// Reductions
|
|
182
|
+
inline float16_t max(Simd<float16_t, N> x) {
|
|
183
|
+
float16x4_t y;
|
|
184
|
+
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
|
185
|
+
y = vpmax_f16(y, y);
|
|
186
|
+
y = vpmax_f16(y, y);
|
|
187
|
+
return vget_lane_f16(y, 0);
|
|
188
|
+
}
|
|
189
|
+
inline float16_t min(Simd<float16_t, N> x) {
|
|
190
|
+
float16x4_t y;
|
|
191
|
+
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
|
192
|
+
y = vpmin_f16(y, y);
|
|
193
|
+
y = vpmin_f16(y, y);
|
|
194
|
+
return vget_lane_f16(y, 0);
|
|
195
|
+
}
|
|
196
|
+
inline float16_t sum(Simd<float16_t, N> x) {
|
|
197
|
+
float16x4_t y;
|
|
198
|
+
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
|
199
|
+
y = vpadd_f16(y, y);
|
|
200
|
+
y = vpadd_f16(y, y);
|
|
201
|
+
return vget_lane_f16(y, 0);
|
|
202
|
+
}
|
|
203
|
+
inline float16_t prod(Simd<float16_t, N> x) {
|
|
204
|
+
auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
|
205
|
+
auto out = hx[0];
|
|
206
|
+
hx[0] *= hx[1];
|
|
207
|
+
hx[0] *= hx[2];
|
|
208
|
+
hx[0] *= hx[3];
|
|
209
|
+
return hx[0];
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
} // namespace mlx::core::simd
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/cpu/simd/base_simd.h"
|
|
4
|
+
|
|
5
|
+
#ifdef MLX_USE_ACCELERATE
|
|
6
|
+
#if defined(__x86_64__)
|
|
7
|
+
// the accelerate_simd implementation requires NEON -- use the base implementation
|
|
8
|
+
#else
|
|
9
|
+
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
|
10
|
+
#endif
|
|
11
|
+
#endif
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/array.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
std::tuple<int64_t, Strides> prepare_slice(
|
|
10
|
+
const array& in,
|
|
11
|
+
const Shape& start_indices,
|
|
12
|
+
const Shape& strides);
|
|
13
|
+
|
|
14
|
+
void shared_buffer_slice(
|
|
15
|
+
const array& in,
|
|
16
|
+
const Strides& out_strides,
|
|
17
|
+
size_t data_offset,
|
|
18
|
+
size_t data_size,
|
|
19
|
+
array& out);
|
|
20
|
+
|
|
21
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
#include "mlx/array.h"
|
|
5
|
+
#include "mlx/backend/common/ternary.h"
|
|
6
|
+
#include "mlx/backend/common/utils.h"
|
|
7
|
+
#include "mlx/backend/cpu/encoder.h"
|
|
8
|
+
|
|
9
|
+
namespace mlx::core {
|
|
10
|
+
|
|
11
|
+
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
|
|
12
|
+
void ternary_op_dims(
|
|
13
|
+
const T1* a,
|
|
14
|
+
const T2* b,
|
|
15
|
+
const T3* c,
|
|
16
|
+
U* out,
|
|
17
|
+
Op op,
|
|
18
|
+
const Shape& shape,
|
|
19
|
+
const Strides& a_strides,
|
|
20
|
+
const Strides& b_strides,
|
|
21
|
+
const Strides& c_strides,
|
|
22
|
+
const Strides& out_strides,
|
|
23
|
+
int axis) {
|
|
24
|
+
auto stride_a = a_strides[axis];
|
|
25
|
+
auto stride_b = b_strides[axis];
|
|
26
|
+
auto stride_c = c_strides[axis];
|
|
27
|
+
auto stride_out = out_strides[axis];
|
|
28
|
+
auto N = shape[axis];
|
|
29
|
+
|
|
30
|
+
for (int i = 0; i < N; i++) {
|
|
31
|
+
if constexpr (D > 1) {
|
|
32
|
+
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
|
|
33
|
+
a,
|
|
34
|
+
b,
|
|
35
|
+
c,
|
|
36
|
+
out,
|
|
37
|
+
op,
|
|
38
|
+
shape,
|
|
39
|
+
a_strides,
|
|
40
|
+
b_strides,
|
|
41
|
+
c_strides,
|
|
42
|
+
out_strides,
|
|
43
|
+
axis + 1);
|
|
44
|
+
} else {
|
|
45
|
+
*out = op(*a, *b, *c);
|
|
46
|
+
}
|
|
47
|
+
a += stride_a;
|
|
48
|
+
b += stride_b;
|
|
49
|
+
c += stride_c;
|
|
50
|
+
out += stride_out;
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
|
55
|
+
void ternary_op_dispatch_dims(
|
|
56
|
+
const T1* a_ptr,
|
|
57
|
+
const T2* b_ptr,
|
|
58
|
+
const T3* c_ptr,
|
|
59
|
+
U* out_ptr,
|
|
60
|
+
Op op,
|
|
61
|
+
size_t size,
|
|
62
|
+
Shape& shape,
|
|
63
|
+
std::vector<Strides>& strides) {
|
|
64
|
+
const auto& a_strides = strides[0];
|
|
65
|
+
const auto& b_strides = strides[1];
|
|
66
|
+
const auto& c_strides = strides[2];
|
|
67
|
+
const auto& out_strides = strides[3];
|
|
68
|
+
int ndim = shape.size();
|
|
69
|
+
switch (ndim) {
|
|
70
|
+
case 1:
|
|
71
|
+
ternary_op_dims<T1, T2, T3, U, Op, 1>(
|
|
72
|
+
a_ptr,
|
|
73
|
+
b_ptr,
|
|
74
|
+
c_ptr,
|
|
75
|
+
out_ptr,
|
|
76
|
+
op,
|
|
77
|
+
shape,
|
|
78
|
+
a_strides,
|
|
79
|
+
b_strides,
|
|
80
|
+
c_strides,
|
|
81
|
+
out_strides,
|
|
82
|
+
0);
|
|
83
|
+
return;
|
|
84
|
+
case 2:
|
|
85
|
+
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
|
86
|
+
a_ptr,
|
|
87
|
+
b_ptr,
|
|
88
|
+
c_ptr,
|
|
89
|
+
out_ptr,
|
|
90
|
+
op,
|
|
91
|
+
shape,
|
|
92
|
+
a_strides,
|
|
93
|
+
b_strides,
|
|
94
|
+
c_strides,
|
|
95
|
+
out_strides,
|
|
96
|
+
0);
|
|
97
|
+
return;
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
|
101
|
+
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
|
102
|
+
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
|
103
|
+
auto stride = out_strides[ndim - 3];
|
|
104
|
+
for (size_t elem = 0; elem < size; elem += stride) {
|
|
105
|
+
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
|
106
|
+
a_ptr + a_it.loc,
|
|
107
|
+
b_ptr + b_it.loc,
|
|
108
|
+
c_ptr + c_it.loc,
|
|
109
|
+
out_ptr + elem,
|
|
110
|
+
op,
|
|
111
|
+
shape,
|
|
112
|
+
a_strides,
|
|
113
|
+
b_strides,
|
|
114
|
+
c_strides,
|
|
115
|
+
out_strides,
|
|
116
|
+
ndim - 2);
|
|
117
|
+
a_it.step();
|
|
118
|
+
b_it.step();
|
|
119
|
+
c_it.step();
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
|
124
|
+
void ternary_op(
|
|
125
|
+
const array& a,
|
|
126
|
+
const array& b,
|
|
127
|
+
const array& c,
|
|
128
|
+
array& out,
|
|
129
|
+
Op op,
|
|
130
|
+
TernaryOpType topt) {
|
|
131
|
+
const T1* a_ptr = a.data<T1>();
|
|
132
|
+
const T2* b_ptr = b.data<T2>();
|
|
133
|
+
const T3* c_ptr = c.data<T3>();
|
|
134
|
+
U* out_ptr = out.data<U>();
|
|
135
|
+
|
|
136
|
+
if (topt == TernaryOpType::ScalarScalarScalar) {
|
|
137
|
+
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
|
138
|
+
} else if (topt == TernaryOpType::VectorVectorVector) {
|
|
139
|
+
for (size_t i = 0; i < out.size(); ++i) {
|
|
140
|
+
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
|
141
|
+
a_ptr++;
|
|
142
|
+
b_ptr++;
|
|
143
|
+
c_ptr++;
|
|
144
|
+
out_ptr++;
|
|
145
|
+
}
|
|
146
|
+
} else {
|
|
147
|
+
auto [shape, strides] = collapse_contiguous_dims(
|
|
148
|
+
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
|
|
149
|
+
ternary_op_dispatch_dims<T1, T2, T3, U>(
|
|
150
|
+
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <cstdint>
|
|
6
|
+
#include <utility>
|
|
7
|
+
|
|
8
|
+
namespace mlx::core::random {
|
|
9
|
+
|
|
10
|
+
/**
 * Threefry-2x32 hash (declared here; the definition lives elsewhere).
 *
 * Based on the JAX counter-based, splittable PRNG design:
 * https://github.com/google/jax/blob/main/docs/jep/263-prng.md
 *
 * Original Threefry reference:
 * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
 *
 * @param key   The two 32-bit key words.
 * @param count The two 32-bit words to hash (the counter, per its name).
 * @return The two 32-bit output words.
 */
std::pair<uint32_t, uint32_t> threefry2x32_hash(
    const std::pair<uint32_t, uint32_t>& key,
    std::pair<uint32_t, uint32_t> count);
|
|
20
|
+
|
|
21
|
+
} // namespace mlx::core::random
|