mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/fast_primitives.h
ADDED
@@ -0,0 +1,427 @@
+// Copyright © 2024 Apple Inc.
+
+#include <optional>
+#include <variant>
+
+#include "mlx/primitives.h"
+
+namespace mlx::core::fast {
+
+// Custom primitive accepts a fallback function which it uses for
+// transformations. Transformations are virtual so that derived classes may
+// override the default behavior.
+class Custom : public Primitive {
+ public:
+  explicit Custom(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback)
+      : Primitive(stream), fallback_(std::move(fallback)) {}
+
+  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  virtual std::vector<array> jvp(
+      const std::vector<array>& primals,
+      const std::vector<array>& tangents,
+      const std::vector<int>& argnums) override;
+
+  virtual std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+ protected:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+};
+
+class RMSNorm : public Custom {
+ public:
+  RMSNorm(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float eps)
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
+
+  static bool use_fallback(Stream stream);
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  DEFINE_NAME(RMSNorm)
+  bool is_equivalent(const Primitive& other) const override;
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+  auto state() const {
+    return std::make_pair(nullptr, eps_);
+  }
+
+ private:
+  float eps_;
+};
+
+class RMSNormVJP : public Custom {
+ public:
+  RMSNormVJP(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float eps)
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_NAME(RMSNormVJP)
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return std::make_pair(nullptr, eps_);
+  }
+
+ private:
+  float eps_;
+};
+
+class LayerNorm : public Custom {
+ public:
+  LayerNorm(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float eps)
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
+
+  static bool use_fallback(Stream s);
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  DEFINE_NAME(LayerNorm)
+  bool is_equivalent(const Primitive& other) const override;
+  DEFINE_INPUT_OUTPUT_SHAPE()
+  auto state() const {
+    return std::make_pair(nullptr, eps_);
+  }
+
+ private:
+  float eps_;
+};
+
+class LayerNormVJP : public Custom {
+ public:
+  LayerNormVJP(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float eps)
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_NAME(LayerNormVJP)
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return std::make_pair(nullptr, eps_);
+  }
+
+ private:
+  float eps_;
+};
+
+class RoPE : public Custom {
+ public:
+  RoPE(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int dims,
+      bool traditional,
+      float base,
+      float scale,
+      bool forward)
+      : Custom(stream, std::move(fallback)),
+        dims_(dims),
+        traditional_(traditional),
+        base_(base),
+        scale_(scale),
+        forward_(forward) {}
+
+  static bool use_fallback(Stream s);
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  DEFINE_NAME(RoPE)
+  bool is_equivalent(const Primitive& other) const override;
+  DEFINE_INPUT_OUTPUT_SHAPE()
+  auto state() const {
+    return std::make_tuple(
+        nullptr, dims_, traditional_, base_, scale_, forward_);
+  }
+
+ private:
+  int dims_;
+  bool traditional_;
+  float base_;
+  float scale_;
+  bool forward_;
+};
+
+class ScaledDotProductAttention : public Custom {
+ public:
+  ScaledDotProductAttention(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float scale,
+      bool do_causal,
+      bool has_sinks,
+      bool output_logsumexp)
+      : Custom(stream, std::move(fallback)),
+        scale_(scale),
+        do_causal_(do_causal),
+        has_sinks_(has_sinks),
+        output_logsumexp_(output_logsumexp) {}
+
+  static bool use_fallback(
+      const array& q,
+      const array& k,
+      const array& v,
+      bool has_mask,
+      bool has_arr_mask,
+      bool do_causal,
+      bool is_training,
+      bool output_logsumexp,
+      Stream s);
+  static bool supports_bool_mask();
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  bool is_equivalent(const Primitive& other) const override;
+
+  DEFINE_NAME(ScaledDotProductAttention);
+  DEFINE_INPUT_OUTPUT_SHAPE()
+  auto state() const {
+    return std::make_tuple(
+        nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_);
+  }
+
+ private:
+  float scale_;
+  bool do_causal_;
+  bool has_sinks_;
+  bool output_logsumexp_;
+};
+
+class ScaledDotProductAttentionVJP : public Custom {
+ public:
+  ScaledDotProductAttentionVJP(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float scale,
+      bool do_causal,
+      bool has_sinks)
+      : Custom(stream, std::move(fallback)),
+        scale_(scale),
+        do_causal_(do_causal),
+        has_sinks_(has_sinks) {}
+
+  static bool use_fallback(const array& q, Stream s);
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_NAME(ScaledDotProductAttentionVJP);
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
+  }
+
+ private:
+  float scale_;
+  bool do_causal_;
+  bool has_sinks_;
+};
+
+class ConvertFP8 : public Primitive {
+ public:
+  explicit ConvertFP8(Stream stream, bool to_fp8)
+      : Primitive(stream), to_fp8_(to_fp8) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  const char* name() const override {
+    if (to_fp8_) {
+      return "ToFP8";
+    } else {
+      return "FromFP8";
+    }
+  }
+  bool state() const {
+    return to_fp8_;
+  };
+
+  bool is_equivalent(const Primitive& other) const override;
+  DEFINE_INPUT_OUTPUT_SHAPE();
+
+ private:
+  bool to_fp8_;
+};
+
+class Quantize : public Custom {
+ public:
+  explicit Quantize(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int group_size,
+      int bits,
+      QuantizationMode mode,
+      bool dequantize)
+      : Custom(stream, std::move(fallback)),
+        group_size_(group_size),
+        bits_(bits),
+        mode_(mode),
+        dequantize_(dequantize) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_NAME(Quantize);
+
+  bool is_equivalent(const Primitive& other) const override;
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  auto state() const {
+    return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_);
+  }
+
+ private:
+  int group_size_;
+  int bits_;
+  QuantizationMode mode_;
+  bool dequantize_;
+};
+
+using ScalarArg = std::variant<bool, int, float>;
+
+class CustomKernel : public Primitive {
+ public:
+  CustomKernel(
+      Stream stream,
+      std::string name,
+      std::string source,
+      std::tuple<int, int, int> grid,
+      std::tuple<int, int, int> threadgroup,
+      std::vector<std::tuple<bool, bool, bool>> shape_infos,
+      bool ensure_row_contiguous,
+      std::optional<float> init_value,
+      std::vector<ScalarArg> scalar_arguments,
+      bool is_precompiled,
+      int shared_memory)
+      : Primitive(stream),
+        name_(std::move(name)),
+        source_(std::move(source)),
+        grid_(grid),
+        threadgroup_(threadgroup),
+        shape_infos_(std::move(shape_infos)),
+        ensure_row_contiguous_(ensure_row_contiguous),
+        init_value_(init_value),
+        scalar_arguments_(std::move(scalar_arguments)),
+        is_precompiled_(is_precompiled),
+        shared_memory_(shared_memory) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("Custom kernels only run on GPU.");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_NAME(CustomKernel);
+  auto state() const {
+    return std::make_tuple(
+        name_,
+        source_,
+        grid_,
+        threadgroup_,
+        shape_infos_,
+        ensure_row_contiguous_,
+        init_value_,
+        scalar_arguments_,
+        is_precompiled_,
+        shared_memory_);
+  }
+
+ private:
+  std::string name_;
+  std::string source_;
+  std::tuple<int, int, int> grid_;
+  std::tuple<int, int, int> threadgroup_;
+  std::vector<std::tuple<bool, bool, bool>> shape_infos_;
+  bool ensure_row_contiguous_;
+  std::optional<float> init_value_;
+  std::vector<ScalarArg> scalar_arguments_;
+  bool is_precompiled_;
+  int shared_memory_;
+};
+
+} // namespace mlx::core::fast
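All of the fast primitives above share one dispatch pattern: the op builds a fallback closure that expresses the computation with ordinary MLX ops, and only records the specialized primitive when use_fallback says the stream can run the fused kernel (note the eval_cpu paths simply throw "NYI"). Below is a minimal sketch of that pattern for an rms_norm-like op; rms_norm_sketch is a hypothetical name, and the real mlx::core::fast::rms_norm does considerably more input validation and dtype handling.

// Sketch only: mirrors the fallback/primitive dispatch used by the classes
// above; not the library's actual implementation.
#include "mlx/mlx.h"
using namespace mlx::core;

array rms_norm_sketch(
    const array& x,
    const array& w,
    float eps,
    StreamOrDevice s = {}) {
  auto fallback = [eps, s](std::vector<array> inputs) -> std::vector<array> {
    auto& x = inputs[0];
    // 1 / sqrt(mean(x^2) + eps) over the feature axis, scaled by the weight.
    auto n = rsqrt(
        add(mean(square(x, s), -1, /*keepdims=*/true, s),
            array(eps, x.dtype()),
            s),
        s);
    return {multiply(multiply(x, n, s), inputs[1], s)};
  };
  auto stream = to_stream(s);
  if (fast::RMSNorm::use_fallback(stream)) {
    return fallback({x, w})[0];
  }
  // Record the fused primitive lazily in the graph; eval_gpu runs the kernel.
  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<fast::RMSNorm>(stream, std::move(fallback), eps),
      {x, w});
}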
mlx/include/mlx/fence.h
ADDED
@@ -0,0 +1,39 @@
+// Copyright © 2024 Apple Inc.
+
+#include <vector>
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+/* A fence to be used for synchronizing work between streams.
+ *
+ * Calls to `wait` wait in the given stream until all previous calls to update
+ * are complete on their given stream.
+ *
+ * The array passed to `update` is computed and visible after the call to
+ * `wait` returns. The array passed to `wait` will not be read until all
+ * previous calls to `update` have completed.
+ *
+ * Note, calls to `update` should always be from the same thread or explicitly
+ * synchronized so that they occur in sequence. Calls to `wait` can be on any
+ * thread.
+ *
+ * For the Metal back-end the fence supports slow (default) and fast mode.
+ * Fast mode requires setting the environment variable
+ * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+,
+ * iOS 18+).
+ */
+class Fence {
+ public:
+  Fence() {};
+  explicit Fence(Stream stream);
+
+  void update(Stream stream, const array& x, bool cross_device);
+  void wait(Stream stream, const array& x);
+
+ private:
+  std::shared_ptr<void> fence_{nullptr};
+};
+
+} // namespace mlx::core
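The doc comment above describes an update/wait protocol rather than a public API; Fence is internal and is normally driven by the scheduler. A minimal sketch of how a producer stream could hand an array to a consumer stream, assuming two streams on the default device:

// Sketch of the update/wait protocol described in the Fence doc comment.
#include "mlx/mlx.h"
using namespace mlx::core;

void handoff_sketch(const array& x) {
  Stream producer = new_stream(default_device());
  Stream consumer = new_stream(default_device());

  Fence fence(producer);
  // Producer side: mark x as ready once prior work on `producer` completes.
  fence.update(producer, x, /*cross_device=*/false);
  // Consumer side: block work on `consumer` until the update has completed;
  // after this point, reading x's buffer from `consumer` is safe.
  fence.wait(consumer, x);
}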
mlx/include/mlx/fft.h
ADDED
@@ -0,0 +1,167 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <variant>
+
+#include "array.h"
+#include "device.h"
+#include "utils.h"
+
+namespace mlx::core::fft {
+
+/** Compute the n-dimensional Fourier Transform. */
+array fftn(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
+array fftn(const array& a, StreamOrDevice s = {});
+
+/** Compute the n-dimensional inverse Fourier Transform. */
+array ifftn(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array ifftn(
+    const array& a,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array ifftn(const array& a, StreamOrDevice s = {});
+
+/** Compute the one-dimensional Fourier Transform. */
+inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+  return fftn(a, {n}, {axis}, s);
+}
+inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+  return fftn(a, {axis}, s);
+}
+
+/** Compute the one-dimensional inverse Fourier Transform. */
+inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+  return ifftn(a, {n}, {axis}, s);
+}
+inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+  return ifftn(a, {axis}, s);
+}
+
+/** Compute the two-dimensional Fourier Transform. */
+inline array fft2(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {}) {
+  return fftn(a, n, axes, s);
+}
+inline array fft2(
+    const array& a,
+    const std::vector<int>& axes = {-2, -1},
+    StreamOrDevice s = {}) {
+  return fftn(a, axes, s);
+}
+
+/** Compute the two-dimensional inverse Fourier Transform. */
+inline array ifft2(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {}) {
+  return ifftn(a, n, axes, s);
+}
+inline array ifft2(
+    const array& a,
+    const std::vector<int>& axes = {-2, -1},
+    StreamOrDevice s = {}) {
+  return ifftn(a, axes, s);
+}
+
+/** Compute the n-dimensional Fourier Transform on a real input. */
+array rfftn(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array rfftn(
+    const array& a,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array rfftn(const array& a, StreamOrDevice s = {});
+
+/** Compute the n-dimensional inverse of `rfftn`. */
+array irfftn(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array irfftn(
+    const array& a,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+array irfftn(const array& a, StreamOrDevice s = {});
+
+/** Compute the one-dimensional Fourier Transform on a real input. */
+inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+  return rfftn(a, {n}, {axis}, s);
+}
+inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+  return rfftn(a, {axis}, s);
+}
+/** Compute the one-dimensional inverse of `rfft`. */
+inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+  return irfftn(a, {n}, {axis}, s);
+}
+inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+  return irfftn(a, {axis}, s);
+}
+
+/** Compute the two-dimensional Fourier Transform on a real input. */
+inline array rfft2(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {}) {
+  return rfftn(a, n, axes, s);
+}
+inline array rfft2(
+    const array& a,
+    const std::vector<int>& axes = {-2, -1},
+    StreamOrDevice s = {}) {
+  return rfftn(a, axes, s);
+}
+
+/** Compute the two-dimensional inverse of `rfft2`. */
+inline array irfft2(
+    const array& a,
+    const Shape& n,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {}) {
+  return irfftn(a, n, axes, s);
+}
+inline array irfft2(
+    const array& a,
+    const std::vector<int>& axes = {-2, -1},
+    StreamOrDevice s = {}) {
+  return irfftn(a, axes, s);
+}
+/** Shift the zero-frequency component to the center of the spectrum. */
+array fftshift(const array& a, StreamOrDevice s = {});
+
+/** Shift the zero-frequency component to the center of the spectrum along
+ * specified axes. */
+array fftshift(
+    const array& a,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+
+/** The inverse of fftshift. */
+array ifftshift(const array& a, StreamOrDevice s = {});
+
+/** The inverse of fftshift along specified axes. */
+array ifftshift(
+    const array& a,
+    const std::vector<int>& axes,
+    StreamOrDevice s = {});
+
+} // namespace mlx::core::fft
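Everything in this header funnels into fftn/ifftn (and the rfftn/irfftn pair for real input); the 1-D and 2-D forms are thin inline wrappers. A short usage example, assuming the usual real-FFT convention that a length-n real axis transforms to n/2 + 1 complex bins:

#include "mlx/mlx.h"
using namespace mlx::core;

int main() {
  auto x = random::normal({8, 16});  // real float32 input
  auto X = fft::rfft(x);             // complex64, shape {8, 9} == {8, 16/2 + 1}
  auto y = fft::irfft(X);            // round-trips back to shape {8, 16}
  auto X2 = fft::fft2(x);            // full 2-D transform over axes {-2, -1}
  eval(X, y, X2);                    // force the lazy graph to compute
  return 0;
}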
mlx/include/mlx/graph_utils.h
ADDED
@@ -0,0 +1,66 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <unordered_map>
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+struct NodeNamer {
+  std::unordered_map<std::uintptr_t, std::string> names;
+
+  const std::string& get_name(const array& x);
+  void set_name(const array& x, std::string n);
+};
+
+void print_graph(
+    std::ostream& os,
+    NodeNamer namer,
+    const std::vector<array>& outputs);
+
+inline void print_graph(std::ostream& os, const std::vector<array>& outputs) {
+  print_graph(os, NodeNamer{}, outputs);
+}
+
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+inline void print_graph(std::ostream& os, Arrays&&... outputs) {
+  print_graph(
+      os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
+}
+
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+inline void
+print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
+  print_graph(
+      os,
+      std::move(namer),
+      std::vector<array>{std::forward<Arrays>(outputs)...});
+}
+
+void export_to_dot(
+    std::ostream& os,
+    NodeNamer namer,
+    const std::vector<array>& outputs);
+
+inline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
+  export_to_dot(os, NodeNamer{}, outputs);
+}
+
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+inline void export_to_dot(std::ostream& os, Arrays&&... outputs) {
+  export_to_dot(
+      os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
+}
+
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+inline void
+export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
+  export_to_dot(
+      os,
+      std::move(namer),
+      std::vector<array>{std::forward<Arrays>(outputs)...});
+}
+
+} // namespace mlx::core
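Both entry points above are variadic over arrays, so dumping a small graph is a one-liner each. A minimal example (the output formats are the helpers' own: a text listing and Graphviz DOT source):

#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;

int main() {
  auto a = array({1.0f, 2.0f, 3.0f});
  auto b = multiply(add(a, array(1.0f)), a);
  print_graph(std::cout, b);    // human-readable node listing
  export_to_dot(std::cout, b);  // DOT source, e.g. pipe through `dot -Tpng`
  return 0;
}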