mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <optional>
|
|
6
|
+
|
|
7
|
+
#include "mlx/array.h"
|
|
8
|
+
|
|
9
|
+
namespace mlx::core {
|
|
10
|
+
|
|
11
|
+
void async_eval(std::vector<array> outputs);
|
|
12
|
+
|
|
13
|
+
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
14
|
+
void async_eval(Arrays&&... outputs) {
|
|
15
|
+
async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
void eval(std::vector<array> outputs);
|
|
19
|
+
|
|
20
|
+
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
21
|
+
void eval(Arrays&&... outputs) {
|
|
22
|
+
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
/**
|
|
26
|
+
* Computes the output and vector-Jacobian product (VJP) of a function.
|
|
27
|
+
*
|
|
28
|
+
* Computes the vector-Jacobian product of the vector of cotangents with the
|
|
29
|
+
* Jacobian of the function evaluated at the primals. Returns a pair of
|
|
30
|
+
* vectors of output arrays and VJP arrays.
|
|
31
|
+
**/
|
|
32
|
+
std::pair<std::vector<array>, std::vector<array>> vjp(
|
|
33
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
34
|
+
const std::vector<array>& primals,
|
|
35
|
+
const std::vector<array>& cotangents);
|
|
36
|
+
|
|
37
|
+
/**
|
|
38
|
+
* Computes the output and vector-Jacobian product (VJP) of a unary function.
|
|
39
|
+
*/
|
|
40
|
+
std::pair<array, array> vjp(
|
|
41
|
+
const std::function<array(const array&)>& fun,
|
|
42
|
+
const array& primal,
|
|
43
|
+
const array& cotangent);
|
|
44
|
+
|
|
45
|
+
/**
|
|
46
|
+
* Computes the output and Jacobian-vector product (JVP) of a function.
|
|
47
|
+
*
|
|
48
|
+
* Computes the Jacobian-vector product of the Jacobian of the function
|
|
49
|
+
* evaluated at the primals with the vector of tangents. Returns a pair of
|
|
50
|
+
* vectors of output arrays and JVP arrays.
|
|
51
|
+
**/
|
|
52
|
+
std::pair<std::vector<array>, std::vector<array>> jvp(
|
|
53
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
54
|
+
const std::vector<array>& primals,
|
|
55
|
+
const std::vector<array>& tangents);
|
|
56
|
+
|
|
57
|
+
/**
|
|
58
|
+
* Computes the output and Jacobian-vector product (JVP) of a unary function.
|
|
59
|
+
*/
|
|
60
|
+
std::pair<array, array> jvp(
|
|
61
|
+
const std::function<array(const array&)>& fun,
|
|
62
|
+
const array& primal,
|
|
63
|
+
const array& tangent);
|
|
64
|
+
|
|
65
|
+
// Return type of general value_and_grad: a function which takes an input
|
|
66
|
+
// vector of arrays and returns a pair of vectors of arrays one for the
|
|
67
|
+
// values and one for the gradients wrt the first value.
|
|
68
|
+
using ValueAndGradFn =
|
|
69
|
+
std::function<std::pair<std::vector<array>, std::vector<array>>(
|
|
70
|
+
const std::vector<array>&)>;
|
|
71
|
+
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
|
|
72
|
+
const std::vector<array>&)>;
|
|
73
|
+
|
|
74
|
+
/**
|
|
75
|
+
* Returns a function which computes the value and gradient of the input
|
|
76
|
+
* function with respect to a vector of input arrays.
|
|
77
|
+
**/
|
|
78
|
+
ValueAndGradFn value_and_grad(
|
|
79
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
80
|
+
const std::vector<int>& argnums);
|
|
81
|
+
|
|
82
|
+
/**
|
|
83
|
+
* Returns a function which computes the value and gradient of the input
|
|
84
|
+
* function with respect to a single input array.
|
|
85
|
+
**/
|
|
86
|
+
ValueAndGradFn inline value_and_grad(
|
|
87
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
88
|
+
int argnum = 0) {
|
|
89
|
+
return value_and_grad(fun, std::vector<int>{argnum});
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
/**
|
|
93
|
+
* Returns a function which computes the value and gradient of the unary
|
|
94
|
+
* input function.
|
|
95
|
+
**/
|
|
96
|
+
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
|
|
97
|
+
const std::function<array(const array&)>& fun) {
|
|
98
|
+
return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
/**
 * Value-and-gradient for a function mapping a vector of arrays to a single
 * array. The returned callable yields {value, gradients}, with gradients
 * taken with respect to the inputs listed in `argnums`.
 **/
inline SimpleValueAndGradFn value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [fun, argnums](auto inputs) {
    // Lift the single-output function to the vector-of-outputs form so the
    // general overload can be reused.
    std::function<std::vector<array>(const std::vector<array>&)> lifted =
        [fun](std::vector<array> xs) { return std::vector<array>{fun(xs)}; };

    auto [values, grads] = value_and_grad(lifted, argnums)(inputs);
    return std::make_pair(values[0], grads);
  };
}

/**
 * Single-index form of the overload above; `argnum` defaults to the first
 * input.
 **/
inline SimpleValueAndGradFn value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return value_and_grad(fun, std::vector<int>{argnum});
}
|
|
118
|
+
|
|
119
|
+
/**
|
|
120
|
+
* Returns a function which computes the gradient of the input function with
|
|
121
|
+
* respect to a vector of input arrays.
|
|
122
|
+
*
|
|
123
|
+
* The function being differentiated takes a vector of arrays and returns an
|
|
124
|
+
* array. The vector of `argnums` specifies which the arguments to compute
|
|
125
|
+
* the gradient with respect to. At least one argument must be specified.
|
|
126
|
+
**/
|
|
127
|
+
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
|
128
|
+
const std::function<array(const std::vector<array>&)>& fun,
|
|
129
|
+
const std::vector<int>& argnums) {
|
|
130
|
+
auto fn = value_and_grad(fun, argnums);
|
|
131
|
+
return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
/**
|
|
135
|
+
* Returns a function which computes the gradient of the input function with
|
|
136
|
+
* respect to a single input array.
|
|
137
|
+
*
|
|
138
|
+
* The function being differentiated takes a vector of arrays and returns an
|
|
139
|
+
* array. The optional `argnum` index specifies which the argument to compute
|
|
140
|
+
* the gradient with respect to and defaults to 0.
|
|
141
|
+
**/
|
|
142
|
+
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
|
143
|
+
const std::function<array(const std::vector<array>&)>& fun,
|
|
144
|
+
int argnum = 0) {
|
|
145
|
+
return grad(fun, std::vector<int>{argnum});
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
/**
|
|
149
|
+
* Returns a function which computes the gradient of the unary input function.
|
|
150
|
+
**/
|
|
151
|
+
std::function<array(const array&)> inline grad(
|
|
152
|
+
const std::function<array(const array&)>& fun) {
|
|
153
|
+
auto fn = value_and_grad(fun);
|
|
154
|
+
return [fn](const array& input) { return fn(input).second; };
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
/**
|
|
158
|
+
* Automatically vectorize a unary function over the requested axes.
|
|
159
|
+
*/
|
|
160
|
+
std::function<array(const array&)> vmap(
|
|
161
|
+
const std::function<array(const array&)>& fun,
|
|
162
|
+
int in_axis = 0,
|
|
163
|
+
int out_axis = 0);
|
|
164
|
+
|
|
165
|
+
/**
|
|
166
|
+
* Automatically vectorize a binary function over the requested axes.
|
|
167
|
+
*/
|
|
168
|
+
std::function<array(const array&, const array&)> vmap(
|
|
169
|
+
const std::function<array(const array&, const array&)>& fun,
|
|
170
|
+
int in_axis_a = 0,
|
|
171
|
+
int in_axis_b = 0,
|
|
172
|
+
int out_axis = 0);
|
|
173
|
+
|
|
174
|
+
/**
|
|
175
|
+
* Automatically vectorize a function over the requested axes.
|
|
176
|
+
*
|
|
177
|
+
* The input function to `vmap` takes as an argument a vector of arrays and
|
|
178
|
+
* returns a vector of arrays. Optionally specify the axes to vectorize over
|
|
179
|
+
* with `in_axes` and `out_axes`, otherwise a default of 0 is used.
|
|
180
|
+
* Returns a vectorized function with the same signature as the input
|
|
181
|
+
* function.
|
|
182
|
+
*/
|
|
183
|
+
std::function<std::vector<array>(const std::vector<array>&)> vmap(
|
|
184
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
185
|
+
const std::vector<int>& in_axes = {},
|
|
186
|
+
const std::vector<int>& out_axes = {});
|
|
187
|
+
|
|
188
|
+
/**
|
|
189
|
+
* Redefine the transformations of `fun` according to the provided functions.
|
|
190
|
+
*
|
|
191
|
+
* Namely when calling the vjp of `fun` then `fun_vjp` will be called,
|
|
192
|
+
* `fun_jvp` for the jvp and `fun_vmap` for vmap.
|
|
193
|
+
*
|
|
194
|
+
* If any transformation is not provided, then a default one is created by
|
|
195
|
+
* calling `vjp`, `jvp` and `vmap` on the function directly.
|
|
196
|
+
*/
|
|
197
|
+
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
|
|
198
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
|
199
|
+
std::optional<std::function<std::vector<array>(
|
|
200
|
+
const std::vector<array>&,
|
|
201
|
+
const std::vector<array>&,
|
|
202
|
+
const std::vector<array>&)>> fun_vjp = std::nullopt,
|
|
203
|
+
std::optional<std::function<std::vector<array>(
|
|
204
|
+
const std::vector<array>&,
|
|
205
|
+
const std::vector<array>&,
|
|
206
|
+
const std::vector<int>&)>> fun_jvp = std::nullopt,
|
|
207
|
+
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
|
|
208
|
+
const std::vector<array>&,
|
|
209
|
+
const std::vector<int>&)>> fun_vmap = std::nullopt);
|
|
210
|
+
|
|
211
|
+
/**
|
|
212
|
+
* Return a function that behaves exactly like `fun` but if the vjp of the
|
|
213
|
+
* results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` .
|
|
214
|
+
*/
|
|
215
|
+
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
|
216
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
|
217
|
+
std::function<std::vector<array>(
|
|
218
|
+
const std::vector<array>&,
|
|
219
|
+
const std::vector<array>&,
|
|
220
|
+
const std::vector<array>&)> fun_vjp);
|
|
221
|
+
|
|
222
|
+
/**
|
|
223
|
+
* Checkpoint the gradient of a function. Namely, discard all intermediate
|
|
224
|
+
* state and recalculate it when we need to compute the gradient.
|
|
225
|
+
*/
|
|
226
|
+
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
|
227
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun);
|
|
228
|
+
|
|
229
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
namespace mlx::core::detail {
|
|
6
|
+
|
|
7
|
+
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
|
8
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
9
|
+
const std::vector<array>& inputs,
|
|
10
|
+
const std::vector<int>& in_axes);
|
|
11
|
+
|
|
12
|
+
std::vector<array> vmap_replace(
|
|
13
|
+
const std::vector<array>& inputs,
|
|
14
|
+
const std::vector<array>& s_inputs,
|
|
15
|
+
const std::vector<array>& s_outputs,
|
|
16
|
+
const std::vector<int>& in_axes,
|
|
17
|
+
const std::vector<int>& out_axes);
|
|
18
|
+
|
|
19
|
+
// Create an InTracing object during tracing operations to signal to the rest
// of the codebase that a trace is in progress, so evals should not throw away
// the graph.
//
// The constructor pushes one {dynamic, grad} entry onto trace_stack() and adds
// `grad` to grad_counter; the destructor undoes the top entry. Instances are
// therefore expected to be strictly scoped (destroyed in LIFO order). A copy
// would run the destructor a second time without a matching push, popping an
// unrelated entry (or an empty stack) and unbalancing grad_counter, so
// copying is disabled. Moves fall back to the deleted copy operations.
struct InTracing {
  explicit InTracing(bool dynamic = false, bool grad = false) {
    grad_counter += grad;
    trace_stack().push_back({dynamic, grad});
  }
  ~InTracing() {
    // Reverse the top entry's contribution (assumes LIFO destruction).
    grad_counter -= trace_stack().back().second;
    trace_stack().pop_back();
  }

  InTracing(const InTracing&) = delete;
  InTracing& operator=(const InTracing&) = delete;

  // True while any InTracing is alive.
  static bool in_tracing() {
    return !trace_stack().empty();
  }
  // True when the outer-most active trace was created with dynamic = true.
  static bool in_dynamic_tracing() {
    // compile is always and only the outer-most transform
    return in_tracing() && trace_stack().front().first;
  }

  // True while at least one live InTracing was created with grad = true.
  static bool in_grad_tracing() {
    return grad_counter > 0;
  }

 private:
  static int grad_counter;
  static std::vector<std::pair<char, char>>& trace_stack();
};
|
|
48
|
+
|
|
49
|
+
// Scoped guard: retain_graph() reports true while at least one RetainGraph
// is alive. Each constructor increments tracing_counter and each destructor
// decrements it. A copy would decrement without a matching increment and
// drive the counter negative, so copying is disabled. Moves fall back to the
// deleted copy operations.
struct RetainGraph {
  RetainGraph() {
    tracing_counter++;
  }
  ~RetainGraph() {
    tracing_counter--;
  }

  RetainGraph(const RetainGraph&) = delete;
  RetainGraph& operator=(const RetainGraph&) = delete;

  static bool retain_graph() {
    return tracing_counter > 0;
  }

 private:
  static int tracing_counter;
};
|
|
64
|
+
|
|
65
|
+
/** Return true if we are currently performing a function transformation in
|
|
66
|
+
* order to keep the graph when evaluating tracer arrays. */
|
|
67
|
+
inline bool in_tracing() {
|
|
68
|
+
return detail::InTracing::in_tracing();
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
/** Return true if we are in a dynamic (shapeless) trace used for compiling or
|
|
72
|
+
* exporting graphs with dynamic shapes. */
|
|
73
|
+
inline bool in_dynamic_tracing() {
|
|
74
|
+
return detail::InTracing::in_dynamic_tracing();
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/** Return true if we are in a gradient trace (vjp, jvp, etc). */
|
|
78
|
+
inline bool in_grad_tracing() {
|
|
79
|
+
return detail::InTracing::in_grad_tracing();
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
inline bool retain_graph() {
|
|
83
|
+
return detail::RetainGraph::retain_graph();
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
} // namespace mlx::core::detail
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <algorithm>
|
|
6
|
+
#include <cmath>
|
|
7
|
+
#include <cstdint>
|
|
8
|
+
#include <vector>
|
|
9
|
+
|
|
10
|
+
#define __MLX_BFLOAT_NAN__ 0x7FC0

namespace mlx::core {

namespace {
// Lets the conversions below view a float's bit pattern as a uint32_t.
union float_bits_bf16 {
  float f;
  uint32_t u;
};
} // namespace

// Software bfloat16: `bits_` holds the upper 16 bits of an IEEE-754 binary32
// value.
struct _MLX_BFloat16 {
  uint16_t bits_;

  // Default constructor. As with a builtin float, a default-initialized
  // value has indeterminate bits_.
  _MLX_BFloat16() = default;

  // Default copy constructor
  _MLX_BFloat16(_MLX_BFloat16 const&) = default;

  // Appease std::vector<bool> for being special: its non-const operator[]
  // returns a proxy object instead of bool&. Route the value through
  // bool -> float so that `true` becomes 1.0 (bits 0x3F80), consistent with
  // the bool overloads of the arithmetic operators below. Storing the bool
  // straight into bits_ would produce the subnormal pattern 0x0001
  // (~9.2e-41) instead.
  _MLX_BFloat16& operator=(std::vector<bool>::reference x) {
    return (*this = static_cast<float>(static_cast<bool>(x)));
  }

  _MLX_BFloat16& operator=(const float& x) {
    return (*this = _MLX_BFloat16(x));
  }

  // From float32. Every NaN maps to the canonical quiet NaN
  // __MLX_BFLOAT_NAN__; all other values are rounded to nearest, ties to even.
  _MLX_BFloat16(const float& x) {
    if (std::isnan(x)) {
      bits_ = __MLX_BFLOAT_NAN__;
    } else {
      float_bits_bf16 in;
      in.f = x;

      // Round to nearest even: adding 0x7FFF plus the lowest kept bit
      // carries into the kept half exactly when the dropped half is above
      // the midpoint, or at the midpoint with an odd kept half.
      in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);

      // Take upper 16 bits
      bits_ = in.u >> 16;
    }
  }

  // To float32
  operator float() const {
    float_bits_bf16 out;

    // Upper 16 bits are the data and lower 16 bits are 0s
    out.u = static_cast<uint32_t>(bits_) << 16;

    return out.f;
  }
};

// Defines `otype FN(atype lhs, btype rhs)`, computing `lhs OP rhs` in ctype.
#define bfloat_binop_base(OP, FN, otype, atype, btype, ctype)  \
  inline otype FN(atype lhs, btype rhs) {                      \
    return static_cast<ctype>(lhs) OP static_cast<ctype>(rhs); \
  }

// Defines FN for (bfloat16, itype) and (itype, bfloat16), computed in ctype.
#define bfloat_binop_helper(OP, FN, otype, itype, ctype)       \
  inline otype FN(_MLX_BFloat16 lhs, itype rhs) {              \
    return static_cast<ctype>(lhs) OP static_cast<ctype>(rhs); \
  }                                                            \
  inline otype FN(itype lhs, _MLX_BFloat16 rhs) {              \
    return static_cast<ctype>(lhs) OP static_cast<ctype>(rhs); \
  }

// Operators
#define bfloat_binop(OP, FN)                                          \
  bfloat_binop_base(                                                  \
      OP, FN, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float);    \
  bfloat_binop_helper(OP, FN, float, float, float);                   \
  bfloat_binop_helper(OP, FN, double, double, double);                \
  bfloat_binop_helper(OP, FN, _MLX_BFloat16, bool, float);            \
  bfloat_binop_helper(OP, FN, _MLX_BFloat16, int32_t, float);         \
  bfloat_binop_helper(OP, FN, _MLX_BFloat16, uint32_t, float);        \
  bfloat_binop_helper(OP, FN, _MLX_BFloat16, int64_t, float);         \
  bfloat_binop_helper(OP, FN, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

#undef bfloat_binop

// Comparison ops
#define bfloat_compop(OP, FN)                                    \
  bfloat_binop_base(OP, FN, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(OP, FN, bool, float, float);               \
  bfloat_binop_helper(OP, FN, bool, double, double);             \
  bfloat_binop_helper(OP, FN, bool, int32_t, float);             \
  bfloat_binop_helper(OP, FN, bool, uint32_t, float);            \
  bfloat_binop_helper(OP, FN, bool, int64_t, float);             \
  bfloat_binop_helper(OP, FN, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

#undef bfloat_compop

// The helper macros have no further users; keep them out of every file that
// includes this header.
#undef bfloat_binop_helper
#undef bfloat_binop_base

// Negative
inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
  return -static_cast<float>(lhs);
}

// Inplace ops
#define bfloat_inplace_op(OP, FN)                                       \
  inline _MLX_BFloat16& FN(_MLX_BFloat16& lhs, const float& rhs) {      \
    lhs = lhs OP rhs;                                                   \
    return lhs;                                                         \
  }                                                                     \
  inline float& FN(float& lhs, _MLX_BFloat16 rhs) {                     \
    lhs = lhs OP rhs;                                                   \
    return lhs;                                                         \
  }

bfloat_inplace_op(+, operator+=);
bfloat_inplace_op(-, operator-=);
bfloat_inplace_op(*, operator*=);
bfloat_inplace_op(/, operator/=);

#undef bfloat_inplace_op

// Bitwise ops (act on the raw bit pattern)

#define bfloat_bitop(OP, FN)                                            \
  inline _MLX_BFloat16 FN(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) {       \
    _MLX_BFloat16 out;                                                  \
    out.bits_ = lhs.bits_ OP rhs.bits_;                                 \
    return out;                                                         \
  }                                                                     \
  inline _MLX_BFloat16 FN(_MLX_BFloat16 lhs, uint16_t rhs) {            \
    _MLX_BFloat16 out;                                                  \
    out.bits_ = lhs.bits_ OP rhs;                                       \
    return out;                                                         \
  }                                                                     \
  inline _MLX_BFloat16 FN(uint16_t lhs, _MLX_BFloat16 rhs) {            \
    _MLX_BFloat16 out;                                                  \
    out.bits_ = lhs OP rhs.bits_;                                       \
    return out;                                                         \
  }

bfloat_bitop(|, operator|);
bfloat_bitop(&, operator&);
bfloat_bitop(^, operator^);

#undef bfloat_bitop

#define bfloat_inplace_bitop(OP, FN)                                    \
  inline _MLX_BFloat16& FN(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) {     \
    lhs.bits_ = lhs.bits_ OP rhs.bits_;                                 \
    return lhs;                                                         \
  }                                                                     \
  inline _MLX_BFloat16& FN(_MLX_BFloat16& lhs, uint16_t rhs) {          \
    lhs.bits_ = lhs.bits_ OP rhs;                                       \
    return lhs;                                                         \
  }

bfloat_inplace_bitop(|, operator|=);
bfloat_inplace_bitop(&, operator&=);
bfloat_inplace_bitop(^, operator^=);

#undef bfloat_inplace_bitop

} // namespace mlx::core
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
#include <cmath>
#include <complex>
#include "mlx/types/half_types.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
struct complex64_t;
struct complex128_t;

// Enables complex128_t's converting constructor for every type that is
// implicitly convertible to double, except complex128_t itself.
template <typename T>
inline constexpr bool can_convert_to_complex128 =
    !std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;

// Double-precision complex number layered on std::complex<double>.
struct complex128_t : public std::complex<double> {
  complex128_t() : std::complex<double>() {}
  complex128_t(double v, double u) : std::complex<double>(v, u) {}
  complex128_t(std::complex<double> v) : std::complex<double>(v) {}

  // Implicit conversion from any type accepted by can_convert_to_complex128;
  // the value is forwarded unchanged to the std::complex<double> constructor.
  template <
      typename T,
      typename = typename std::enable_if<can_convert_to_complex128<T>>::type>
  complex128_t(T x) : std::complex<double>(x) {}

  // Implicit conversion to float keeps only the real part (narrowed from
  // double); the imaginary part is discarded.
  operator float() const {
    return real();
  }
};

// Enables complex64_t's converting constructor for every type that is
// implicitly convertible to float, except complex64_t itself.
template <typename T>
inline constexpr bool can_convert_to_complex64 =
    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;

// Single-precision complex number layered on std::complex<float>.
struct complex64_t : public std::complex<float> {
  complex64_t() : std::complex<float>() {}
  complex64_t(float v, float u) : std::complex<float>(v, u) {}
  complex64_t(std::complex<float> v) : std::complex<float>(v) {}

  // Implicit conversion from any type accepted by can_convert_to_complex64;
  // the value is forwarded unchanged to the std::complex<float> constructor.
  template <
      typename T,
      typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
  complex64_t(T x) : std::complex<float>(x) {}

  // Implicit conversion to float keeps only the real part; the imaginary
  // part is discarded.
  operator float() const {
    return real();
  }
};

// Ordering for complex64_t is lexicographic: real parts are compared first
// and the imaginary parts only break ties.
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) ||
      (a.real() == b.real() && a.imag() >= b.imag());
}

inline bool operator>(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}

// Component-wise remainder. For each component, a truncated remainder
// x - y * trunc(x / y) is computed; if it is non-zero and its sign differs
// from the divisor's, the divisor is added so the result takes the sign of y.
//
// A zero divisor component leaves the corresponding dividend component
// unchanged (so e.g. dividing by a real-valued complex keeps the imaginary
// part). std::trunc is used instead of a cast to int64_t because converting a
// non-finite or out-of-range float (x / 0, or a quotient of magnitude >= 2^63)
// to an integer is undefined behavior. For in-range quotients the result is
// identical to the integer-cast formulation.
inline complex64_t operator%(complex64_t a, complex64_t b) {
  auto component_rem = [](float x, float y) -> float {
    if (y == 0) {
      return x;
    }
    float r = x - (y * std::trunc(x / y));
    if (r != 0 && ((r < 0) != (y < 0))) {
      r += y;
    }
    return r;
  };
  return {component_rem(a.real(), b.real()), component_rem(a.imag(), b.imag())};
}

inline bool operator<=(const complex64_t& a, const complex64_t& b) {
  return operator>=(b, a);
}

inline bool operator<(const complex64_t& a, const complex64_t& b) {
  return operator>(b, a);
}

// Unary minus: negates both components via std::complex<float>.
inline complex64_t operator-(const complex64_t& v) {
  return -static_cast<std::complex<float>>(v);
}
|
|
80
|
+
|
|
81
|
+
// clang-format off
|
|
82
|
+
#define complex_binop_helper(_op_, _operator_, itype) \
|
|
83
|
+
inline complex64_t _operator_(itype x, const complex64_t& y) { \
|
|
84
|
+
return static_cast<complex64_t>(x) _op_ y; \
|
|
85
|
+
} \
|
|
86
|
+
inline complex64_t _operator_(const complex64_t& x, itype y) { \
|
|
87
|
+
return x _op_ static_cast<complex64_t>(y); \
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
#define complex_binop(_op_, _operator_) \
|
|
91
|
+
inline complex64_t _operator_(const std::complex<float>& x, const complex64_t& y) { \
|
|
92
|
+
return x _op_ static_cast<std::complex<float>>(y); \
|
|
93
|
+
} \
|
|
94
|
+
inline complex64_t _operator_(const complex64_t& x, const std::complex<float>& y) { \
|
|
95
|
+
return static_cast<std::complex<float>>(x) _op_ y; \
|
|
96
|
+
} \
|
|
97
|
+
inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
|
|
98
|
+
return static_cast<std::complex<float>>(x) \
|
|
99
|
+
_op_ static_cast<std::complex<float>>(y); \
|
|
100
|
+
} \
|
|
101
|
+
complex_binop_helper(_op_, _operator_, bool) \
|
|
102
|
+
complex_binop_helper(_op_, _operator_, uint32_t) \
|
|
103
|
+
complex_binop_helper(_op_, _operator_, uint64_t) \
|
|
104
|
+
complex_binop_helper(_op_, _operator_, int32_t) \
|
|
105
|
+
complex_binop_helper(_op_, _operator_, int64_t) \
|
|
106
|
+
complex_binop_helper(_op_, _operator_, float16_t) \
|
|
107
|
+
complex_binop_helper(_op_, _operator_, bfloat16_t) \
|
|
108
|
+
complex_binop_helper(_op_, _operator_, float)
|
|
109
|
+
// clang-format on
|
|
110
|
+
|
|
111
|
+
complex_binop(+, operator+)
|
|
112
|
+
|
|
113
|
+
} // namespace mlx::core
|