mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/common/binary.h
@@ -0,0 +1,97 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+enum class BinaryOpType {
+  ScalarScalar,
+  ScalarVector,
+  VectorScalar,
+  VectorVector,
+  General,
+};
+
+inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
+  BinaryOpType bopt;
+  if (a.data_size() == 1 && b.data_size() == 1) {
+    bopt = BinaryOpType::ScalarScalar;
+  } else if (a.data_size() == 1 && b.flags().contiguous) {
+    bopt = BinaryOpType::ScalarVector;
+  } else if (b.data_size() == 1 && a.flags().contiguous) {
+    bopt = BinaryOpType::VectorScalar;
+  } else if (
+      (a.flags().row_contiguous && b.flags().row_contiguous) ||
+      (a.flags().col_contiguous && b.flags().col_contiguous)) {
+    bopt = BinaryOpType::VectorVector;
+  } else {
+    bopt = BinaryOpType::General;
+  }
+  return bopt;
+}
+
+inline void set_binary_op_output_data(
+    const array& a,
+    const array& b,
+    array& out,
+    BinaryOpType bopt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+  bool b_donatable = is_donatable(b, out);
+  bool a_donatable = is_donatable(a, out);
+  switch (bopt) {
+    case BinaryOpType::ScalarScalar:
+      out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
+      break;
+    case BinaryOpType::ScalarVector:
+      if (b_donatable) {
+        out.copy_shared_buffer(b);
+      } else {
+        out.set_data(
+            mallocfn(b.data_size() * out.itemsize()),
+            b.data_size(),
+            b.strides(),
+            b.flags());
+      }
+      break;
+    case BinaryOpType::VectorScalar:
+      if (a_donatable) {
+        out.copy_shared_buffer(a);
+      } else {
+        out.set_data(
+            mallocfn(a.data_size() * out.itemsize()),
+            a.data_size(),
+            a.strides(),
+            a.flags());
+      }
+      break;
+    case BinaryOpType::VectorVector:
+      if (a_donatable) {
+        out.copy_shared_buffer(a);
+      } else if (b_donatable) {
+        out.copy_shared_buffer(b);
+      } else {
+        out.set_data(
+            mallocfn(a.data_size() * out.itemsize()),
+            a.data_size(),
+            a.strides(),
+            a.flags());
+      }
+      break;
+    case BinaryOpType::General:
+      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
+        out.copy_shared_buffer(a);
+      } else if (
+          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
+        out.copy_shared_buffer(b);
+      } else {
+        out.set_data(mallocfn(out.nbytes()));
+      }
+      break;
+  }
+}
+
+} // namespace mlx::core
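The helpers above classify an input pair and then allocate or donate the output buffer accordingly. Below is an illustrative sketch, not part of the wheel: it only calls get_binary_op_type on small arrays, assuming the shipped headers and libmlx.so are available to compile and link against (the names a, b, c are invented for the example).

// Sketch only: classify input pairs the way a CPU binary kernel would.
#include "mlx/backend/common/binary.h"
#include "mlx/mlx.h"

#include <iostream>

using namespace mlx::core;

int main() {
  array a(2.0f);        // scalar constant
  array b = ones({8});  // row-contiguous vector
  array c = ones({8});
  eval(a, b, c);        // realize the buffers before inspecting sizes/flags

  // Scalar vs. contiguous vector dispatches to ScalarVector; two
  // row-contiguous arrays of the same shape dispatch to VectorVector.
  std::cout << std::boolalpha
            << (get_binary_op_type(a, b) == BinaryOpType::ScalarVector) << "\n"
            << (get_binary_op_type(b, c) == BinaryOpType::VectorVector) << "\n";
  return 0;
}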
mlx/include/mlx/backend/common/buffer_cache.h
@@ -0,0 +1,157 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cassert>
+#include <functional>
+#include <map>
+
+namespace mlx::core {
+
+template <typename T>
+class BufferCache {
+ public:
+  BufferCache(
+      size_t page_size,
+      std::function<size_t(T*)> get_size,
+      std::function<void(T*)> free)
+      : page_size_(page_size),
+        get_size_(std::move(get_size)),
+        free_(std::move(free)) {}
+
+  ~BufferCache() {
+    clear();
+  }
+
+  BufferCache(const BufferCache&) = delete;
+  BufferCache& operator=(const BufferCache&) = delete;
+
+  T* reuse_from_cache(size_t size) {
+    // Find the closest buffer in pool.
+    auto it = buffer_pool_.lower_bound(size);
+    if (it == buffer_pool_.end() ||
+        it->first >= std::min(2 * size, size + 2 * page_size_)) {
+      return nullptr;
+    }
+
+    // Collect from the cache.
+    T* buf = it->second->buf;
+    pool_size_ -= it->first;
+
+    // Remove from record.
+    remove_from_list(it->second);
+    buffer_pool_.erase(it);
+    return buf;
+  }
+
+  void recycle_to_cache(T* buf) {
+    assert(buf);
+    // Add to cache.
+    BufferHolder* bh = new BufferHolder(buf);
+    add_at_head(bh);
+    size_t size = get_size_(buf);
+    pool_size_ += size;
+    buffer_pool_.emplace(size, bh);
+  }
+
+  int release_cached_buffers(size_t min_bytes_to_free) {
+    if (min_bytes_to_free >= 0.9 * pool_size_) {
+      return clear();
+    } else {
+      int n_release = 0;
+      size_t total_bytes_freed = 0;
+
+      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+        // Release buffer.
+        size_t size = get_size_(tail_->buf);
+        total_bytes_freed += size;
+        free_(tail_->buf);
+        n_release++;
+
+        // Remove from record.
+        auto its = buffer_pool_.equal_range(size);
+        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
+          return el.second == tail_;
+        });
+        assert(it != buffer_pool_.end());
+        buffer_pool_.erase(it);
+        remove_from_list(tail_);
+      }
+
+      pool_size_ -= total_bytes_freed;
+      return n_release;
+    }
+  }
+
+  int clear() {
+    int n_release = 0;
+    for (auto& [size, holder] : buffer_pool_) {
+      free_(holder->buf);
+      n_release++;
+      delete holder;
+    }
+    buffer_pool_.clear();
+    pool_size_ = 0;
+    head_ = nullptr;
+    tail_ = nullptr;
+    return n_release;
+  }
+
+  size_t cache_size() const {
+    return pool_size_;
+  }
+
+  size_t page_size() const {
+    return page_size_;
+  }
+
+ private:
+  struct BufferHolder {
+   public:
+    explicit BufferHolder(T* buf_) : buf(buf_) {}
+
+    BufferHolder* prev{nullptr};
+    BufferHolder* next{nullptr};
+    T* buf;
+  };
+
+  void add_at_head(BufferHolder* to_add) {
+    if (!head_) {
+      head_ = to_add;
+      tail_ = to_add;
+    } else {
+      head_->prev = to_add;
+      to_add->next = head_;
+      head_ = to_add;
+    }
+  }
+
+  void remove_from_list(BufferHolder* to_remove) {
+    if (to_remove->prev && to_remove->next) { // if middle
+      to_remove->prev->next = to_remove->next;
+      to_remove->next->prev = to_remove->prev;
+    } else if (to_remove->prev && to_remove == tail_) { // if tail
+      tail_ = to_remove->prev;
+      tail_->next = nullptr;
+    } else if (to_remove == head_ && to_remove->next) { // if head
+      head_ = to_remove->next;
+      head_->prev = nullptr;
+    } else if (to_remove == head_ && to_remove == tail_) { // if only element
+      head_ = nullptr;
+      tail_ = nullptr;
+    }
+
+    delete to_remove;
+  }
+
+  std::multimap<size_t, BufferHolder*> buffer_pool_;
+  BufferHolder* head_{nullptr};
+  BufferHolder* tail_{nullptr};
+  size_t pool_size_{0};
+
+  const size_t page_size_;
+  std::function<size_t(T*)> get_size_;
+  std::function<void(T*)> free_;
+};
+
+} // namespace mlx::core
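BufferCache is a header-only template keyed by buffer size: recycled buffers go into a multimap plus an LRU list, and reuse_from_cache serves a request only if the best fit is not wastefully large. The sketch below is not part of the wheel; RawBuffer and the two lambdas are invented to show the flow with a plain malloc/free backing.

// Sketch only: exercise BufferCache with a trivial malloc-backed buffer type.
#include "mlx/backend/common/buffer_cache.h"

#include <cstdlib>
#include <iostream>

struct RawBuffer {
  void* ptr;
  size_t size;
};

int main() {
  mlx::core::BufferCache<RawBuffer> cache(
      /* page_size */ 4096,
      /* get_size  */ [](RawBuffer* b) { return b->size; },
      /* free      */ [](RawBuffer* b) { std::free(b->ptr); delete b; });

  // Hand a 1 MiB buffer back to the cache instead of freeing it.
  auto* buf = new RawBuffer{std::malloc(1 << 20), 1 << 20};
  cache.recycle_to_cache(buf);
  std::cout << "cached bytes: " << cache.cache_size() << "\n";

  // A request of the same size is served from the cache; afterwards the
  // cache is empty, so a second request misses and returns nullptr.
  RawBuffer* reused = cache.reuse_from_cache(1 << 20);
  RawBuffer* miss = cache.reuse_from_cache(1 << 10);
  std::cout << std::boolalpha << (reused == buf) << " " << (miss == nullptr)
            << "\n";

  // The caller owns the reused buffer again and must release it itself.
  std::free(reused->ptr);
  delete reused;
  return 0;
}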
mlx/include/mlx/backend/common/compiled.h
@@ -0,0 +1,77 @@
+// Copyright © 2023-2024 Apple Inc.
+#pragma once
+
+#include <functional>
+#include <iomanip>
+
+#include "mlx/array.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+inline bool is_static_cast(const Primitive& p) {
+  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
+}
+
+std::string get_type_string(Dtype d);
+
+template <typename T>
+void print_float_constant(std::ostream& os, const array& x) {
+  auto old_precision = os.precision();
+  if constexpr (std::is_same_v<T, double>) {
+    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
+  } else {
+    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
+  }
+  os << x.item<T>() << std::setprecision(old_precision);
+}
+
+template <typename T>
+void print_int_constant(std::ostream& os, const array& x) {
+  os << x.item<T>();
+}
+
+template <typename T>
+void print_complex_constant(std::ostream& os, const array& x) {
+  auto old_precision = os.precision();
+  T constant = x.item<T>();
+
+  os << get_type_string(x.dtype()) << "("
+     << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+     << constant.real() << ", " << constant.imag() << ")"
+     << std::setprecision(old_precision);
+}
+
+void print_constant(std::ostream& os, const array& x);
+
+inline bool is_scalar(const array& x) {
+  return x.ndim() == 0;
+}
+
+// Check if we can use a contiguous operation given inputs and the output shape
+bool compiled_check_contiguity(
+    const std::vector<array>& inputs,
+    const Shape& shape);
+
+// Allocate space for the outputs possibly with input donation
+void compiled_allocate_outputs(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    const std::function<bool(size_t)>& is_constant,
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>& mallocfn =
+        allocator::malloc);
+
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous);
+
+} // namespace mlx::core
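The print_*_constant helpers above are used when emitting compiled kernel source: they write a scalar array value to a stream at full precision for its type. A minimal sketch, not part of the wheel, assuming the shipped headers and library are on the include and link paths:

// Sketch only: print a float32 scalar with full single-precision digits.
#include "mlx/backend/common/compiled.h"

#include <iostream>

using namespace mlx::core;

int main() {
  array x(0.1f);  // scalar float32 constant
  print_float_constant<float>(std::cout, x);  // writes the value, e.g. "0.1"
  std::cout << "\n";
  return 0;
}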
mlx/include/mlx/backend/common/copy.h
@@ -0,0 +1,50 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+enum class CopyType {
+  // Copy a raw scalar input into the full contiguous output
+  Scalar,
+
+  // Copy the raw input buffer contiguously into a raw output buffer of the same
+  // size
+  Vector,
+
+  // Copy the full virtual input to the full contiguous output
+  General,
+
+  // Copy the full virtual input to the full virtual output. We assume the
+  // input and output have the same shape.
+  GeneralGeneral
+};
+
+inline bool set_copy_output_data(
+    const array& in,
+    array& out,
+    CopyType ctype,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+  if (ctype == CopyType::Vector) {
+    // If the input is donateable, we are doing a vector copy and the types
+    // have the same size, then the input buffer can hold the output.
+    if (is_donatable(in, out)) {
+      out.copy_shared_buffer(in);
+      return true;
+    } else {
+      out.set_data(
+          mallocfn(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+      return false;
+    }
+  } else {
+    out.set_data(mallocfn(out.nbytes()));
+    return false;
+  }
+}
+
+} // namespace mlx::core
mlx/include/mlx/backend/common/hadamard.h
@@ -0,0 +1,109 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <map>
+
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+// From http://neilsloane.com/hadamard/
+constexpr std::string_view h12 = R"(
++-++++++++++
+--+-+-+-+-+-
++++-++----++
++---+--+-++-
++++++-++----
++-+---+--+-+
+++--+++-++--
++--++---+--+
+++----+++-++
++--+-++---+-
+++++----+++-
++-+--+-++---
+)";
+
+constexpr std::string_view h20 = R"(
++----+----++--++-++-
+-+----+---+++---+-++
+--+----+---+++-+-+-+
+---+----+---+++++-+-
+----+----++--++-++-+
+-+++++-----+--+++--+
++-+++-+---+-+--+++--
+++-++--+---+-+--+++-
++++-+---+---+-+--+++
+++++-----++--+-+--++
+--++-+-++-+-----++++
+---++-+-++-+---+-+++
++---++-+-+--+--++-++
+++---++-+----+-+++-+
+-++---++-+----+++++-
+-+--+--++-+----+----
++-+-----++-+----+---
+-+-+-+---+--+----+--
+--+-+++------+----+-
++--+--++------+----+
+)";
+
+constexpr std::string_view h28 = R"(
++------++----++-+--+-+--++--
+-+-----+++-----+-+--+-+--++-
+--+-----+++---+-+-+----+--++
+---+-----+++---+-+-+-+--+--+
+----+-----+++---+-+-+++--+--
+-----+-----++++--+-+--++--+-
+------++----++-+--+-+--++--+
+--++++-+-------++--+++-+--+-
+---++++-+-----+-++--+-+-+--+
++---+++--+----++-++--+-+-+--
+++---++---+----++-++--+-+-+-
++++---+----+----++-++--+-+-+
+++++--------+-+--++-++--+-+-
+-++++--------+++--++--+--+-+
+-+-++-++--++--+--------++++-
++-+-++--+--++--+--------++++
+-+-+-++--+--++--+----+---+++
++-+-+-++--+--+---+---++---++
+++-+-+-++--+------+--+++---+
+-++-+-+-++--+------+-++++---
++-++-+---++--+------+-++++--
+-++--++-+-++-+++----++------
++-++--++-+-++-+++-----+-----
+++-++---+-+-++-+++-----+----
+-++-++-+-+-+-+--+++-----+---
+--++-++++-+-+----+++-----+--
++--++-+-++-+-+----+++-----+-
+++--++-+-++-+-+----++------+
+)";
+
+inline const std::map<int, std::string_view> hadamard_matrices() {
+  return {{12, h12}, {20, h20}, {28, h28}};
+}
+
+inline std::pair<int, int> decompose_hadamard(int n) {
+  // n = m*2^k
+  int m = 1;
+  if (!is_power_of_2(n)) {
+    auto h_matrices = hadamard_matrices();
+    for (auto [factor, _] : h_matrices) {
+      if (n % factor == 0) {
+        m = factor;
+        n /= factor;
+        break;
+      }
+    }
+    if (m == 1) {
+      throw std::invalid_argument(
+          "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
+    }
+  }
+  if (n > (1 << 26)) {
+    throw std::invalid_argument(
+        "[hadamard] Only supports n = m*2^k where k <= 26");
+  }
+  return {n, m};
+}
+
+} // namespace mlx::core
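decompose_hadamard factors a transform length n into m * 2^k, where m is 1 or the size of one of the stored base matrices (12, 20, 28), and returns the power-of-two part first. A small worked sketch, not part of the wheel, with the input sizes invented for illustration:

// Sketch only: factor Hadamard transform lengths as m * 2^k.
#include "mlx/backend/common/hadamard.h"

#include <iostream>

int main() {
  using mlx::core::decompose_hadamard;

  auto [two_power, m] = decompose_hadamard(48);  // 48 = 12 * 4
  std::cout << two_power << " " << m << "\n";    // expected: "4 12"

  auto [two_power2, m2] = decompose_hadamard(64);  // already a power of two
  std::cout << two_power2 << " " << m2 << "\n";    // expected: "64 1"
  return 0;
}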
mlx/include/mlx/backend/common/matmul.h
@@ -0,0 +1,67 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/utils.h"
+
+#include <sstream>
+
+namespace mlx::core {
+
+inline std::tuple<Shape, Strides, Strides> collapse_batches(
+    const array& a,
+    const array& b) {
+  if (a.ndim() == 2) {
+    return {Shape{1}, Strides{0}, Strides{0}};
+  }
+
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] =
+      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
+
+  auto a_batch_strides = batch_strides[0];
+  auto b_batch_strides = batch_strides[1];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    a_batch_strides.push_back(0);
+    b_batch_strides.push_back(0);
+  }
+
+  return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
+}
+
+inline std::tuple<Shape, Strides, Strides, Strides>
+collapse_batches(const array& a, const array& b, const array& c) {
+  if (a.ndim() == 2) {
+    return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
+  }
+
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
+      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
+
+  auto A_batch_stride = batch_strides[0];
+  auto B_batch_stride = batch_strides[1];
+  auto C_batch_stride = batch_strides[2];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    A_batch_stride.push_back(0);
+    B_batch_stride.push_back(0);
+    C_batch_stride.push_back(0);
+  }
+
+  return std::make_tuple(
+      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
+}
+
+} // namespace mlx::core
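collapse_batches merges the leading (batch) dimensions of the operands whenever their strides allow it, so the GEMM driver can loop over one flattened batch axis instead of nested loops. An illustrative sketch, not part of the wheel; the shapes are invented, and the expected collapse assumes freshly allocated row-contiguous inputs:

// Sketch only: collapse the batch dimensions of two matmul operands.
#include "mlx/backend/common/matmul.h"
#include "mlx/mlx.h"

#include <iostream>

using namespace mlx::core;

int main() {
  // Batch shape (2, 3): a is (2, 3, 4, 8) and b is (2, 3, 8, 5).
  array a = ones({2, 3, 4, 8});
  array b = ones({2, 3, 8, 5});
  eval(a, b);  // realize the buffers so strides and flags are meaningful

  auto [batch_shape, a_strides, b_strides] = collapse_batches(a, b);

  // For row-contiguous inputs like these, the two batch axes should collapse
  // into a single axis of size 6, with one batch stride per operand.
  for (auto s : batch_shape) {
    std::cout << s << " ";
  }
  std::cout << "\n";
  return 0;
}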
mlx/include/mlx/backend/common/reduce.h
@@ -0,0 +1,59 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+enum ReductionOpType {
+  // Self-explanatory. Read everything and produce 1 output.
+  ContiguousAllReduce,
+
+  // The input is contiguous and the last axis is reduced
+  // N1xR1xN2xR2x...xNnxRn
+  ContiguousReduce,
+
+  // The input is contiguous and the last axis is not reduced
+  // R1xN1xR2xN2x...xRnxNn
+  ContiguousStridedReduce,
+
+  // The input is not contiguous but the last axis is and it is reduced so we
+  // need to figure out the offsets but we can call the contiguous reduce after
+  // that.
+  // N3xR1xN1xR4x...xRn
+  GeneralContiguousReduce,
+
+  // The input is not contiguous but the last reduction axis and the last axis
+  // are so we need to figure out the offset but we can call the strided reduce
+  // after that.
+  GeneralStridedReduce,
+
+  // The input is not contiguous after the reduction axis and it may contain
+  // 0-stride axes or transpositions. We could copy the strides and produce a
+  // transposed outcome or we can read the input out of order and write the
+  // output in order.
+  GeneralReduce
+};
+
+struct ReductionPlan {
+  ReductionOpType type;
+  Shape shape;
+  Strides strides;
+
+  ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
+      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
+  ReductionPlan(ReductionOpType type_) : type(type_) {}
+};
+
+ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
+
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);
+
+} // namespace mlx::core
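get_reduction_plan (implemented in the shipped library) maps an input layout and a set of reduced axes onto one of the enum cases above. The sketch below is illustrative only and not part of the wheel; the array shape and axis choice are invented, and it simply branches on the returned plan the way a kernel dispatcher might.

// Sketch only: request a reduction plan and branch on its type.
#include "mlx/backend/common/reduce.h"
#include "mlx/mlx.h"

#include <iostream>

using namespace mlx::core;

int main() {
  array x = ones({4, 5, 6});
  eval(x);  // the planner inspects the realized strides and flags

  // Reduce over the last axis of a row-contiguous input.
  ReductionPlan plan = get_reduction_plan(x, {2});

  switch (plan.type) {
    case ContiguousAllReduce:
      std::cout << "all-reduce\n";
      break;
    case ContiguousReduce:
    case GeneralContiguousReduce:
      std::cout << "reduced axis is contiguous\n";
      break;
    case ContiguousStridedReduce:
    case GeneralStridedReduce:
      std::cout << "reduced axis is strided\n";
      break;
    case GeneralReduce:
      std::cout << "general fallback\n";
      break;
  }
  return 0;
}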
mlx/include/mlx/backend/common/slicing.h
@@ -0,0 +1,20 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+std::tuple<int64_t, Strides> prepare_slice(
+    const array& in,
+    const Shape& start_indices,
+    const Shape& strides);
+
+void slice(
+    const array& in,
+    array& out,
+    const Shape& start_indices,
+    const Shape& strides);
+
+} // namespace mlx::core