mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// This file includes utilities that are used by C++ code (i.e. .cpp files).
|
|
4
|
+
|
|
5
|
+
#pragma once
|
|
6
|
+
|
|
7
|
+
#include "mlx/array.h"
|
|
8
|
+
#include "mlx/backend/cuda/allocator.h"
|
|
9
|
+
#include "mlx/backend/cuda/cuda_utils.h"
|
|
10
|
+
|
|
11
|
+
namespace mlx::core {
|
|
12
|
+
|
|
13
|
+
template <typename T>
|
|
14
|
+
inline uint max_occupancy_block_dim(T kernel) {
|
|
15
|
+
int _, block_dim;
|
|
16
|
+
if constexpr (std::is_same_v<T, CUfunction>) {
|
|
17
|
+
CHECK_CUDA_ERROR(
|
|
18
|
+
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
|
19
|
+
} else {
|
|
20
|
+
CHECK_CUDA_ERROR(
|
|
21
|
+
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
|
22
|
+
}
|
|
23
|
+
return block_dim;
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
template <typename T>
|
|
27
|
+
inline T* gpu_ptr(array& arr) {
|
|
28
|
+
return reinterpret_cast<T*>(
|
|
29
|
+
static_cast<char*>(
|
|
30
|
+
static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
|
|
31
|
+
arr.offset());
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
// For const array, keep constness in pointer unless it is untyped.
|
|
35
|
+
template <typename T>
|
|
36
|
+
inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
|
|
37
|
+
const array& arr) {
|
|
38
|
+
return gpu_ptr<T>(const_cast<array&>(arr));
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
// Forward declaration; only a reference to Dtype is needed below.
struct Dtype;

// Returns, as a C string, the name of the CUDA C++ type that corresponds to
// |dtype|.
const char* dtype_to_cuda_type(const Dtype& dtype);
|
|
45
|
+
|
|
46
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/cuda/event.h"
|
|
6
|
+
|
|
7
|
+
#include <condition_variable>
|
|
8
|
+
#include <functional>
|
|
9
|
+
#include <map>
|
|
10
|
+
#include <mutex>
|
|
11
|
+
#include <thread>
|
|
12
|
+
|
|
13
|
+
namespace mlx::core::cu {
|
|
14
|
+
|
|
15
|
+
// Run tasks in worker thread, synchronized with cuda stream.
|
|
16
|
+
class Worker {
|
|
17
|
+
public:
|
|
18
|
+
explicit Worker(Device& d);
|
|
19
|
+
~Worker();
|
|
20
|
+
|
|
21
|
+
Worker(const Worker&) = delete;
|
|
22
|
+
Worker& operator=(const Worker&) = delete;
|
|
23
|
+
|
|
24
|
+
// Add a pending |task| that will run when consumed or commited.
|
|
25
|
+
void add_task(std::function<void()> task);
|
|
26
|
+
|
|
27
|
+
// Inform worker thread to run current batches after kernels in |stream|
|
|
28
|
+
// finish running.
|
|
29
|
+
void commit(cudaStream_t stream);
|
|
30
|
+
|
|
31
|
+
private:
|
|
32
|
+
static void signal(void*);
|
|
33
|
+
|
|
34
|
+
void thread_fn();
|
|
35
|
+
std::mutex mtx_;
|
|
36
|
+
std::condition_variable cond_;
|
|
37
|
+
|
|
38
|
+
uint64_t committed_batch_{0};
|
|
39
|
+
uint64_t signaled_batch_{0};
|
|
40
|
+
|
|
41
|
+
// Cuda stream and event for signaling kernel completion.
|
|
42
|
+
CudaStream signal_stream_;
|
|
43
|
+
CudaEvent signal_event_;
|
|
44
|
+
|
|
45
|
+
bool stop_{false};
|
|
46
|
+
|
|
47
|
+
// Tasks are put in |pending_tasks_| first, and then moved to
|
|
48
|
+
// |worker_tasks_| when end_batch() is called.
|
|
49
|
+
using Tasks = std::vector<std::function<void()>>;
|
|
50
|
+
Tasks pending_tasks_;
|
|
51
|
+
std::map<uint64_t, Tasks> worker_tasks_;
|
|
52
|
+
std::thread worker_;
|
|
53
|
+
};
|
|
54
|
+
|
|
55
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/common/copy.h"
|
|
6
|
+
#include "mlx/stream.h"
|
|
7
|
+
|
|
8
|
+
#include <optional>
|
|
9
|
+
|
|
10
|
+
namespace mlx::core {
|
|
11
|
+
|
|
12
|
+
// Generic copy inplace
|
|
13
|
+
void copy_gpu_inplace(
|
|
14
|
+
const array& in,
|
|
15
|
+
array& out,
|
|
16
|
+
const Shape& data_shape,
|
|
17
|
+
const Strides& i_strides,
|
|
18
|
+
const Strides& o_strides,
|
|
19
|
+
int64_t i_offset,
|
|
20
|
+
int64_t o_offset,
|
|
21
|
+
CopyType ctype,
|
|
22
|
+
const Stream& s,
|
|
23
|
+
std::optional<array> dynamic_i_offset = std::nullopt,
|
|
24
|
+
std::optional<array> dynamic_o_offset = std::nullopt);
|
|
25
|
+
|
|
26
|
+
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
|
27
|
+
void copy_gpu(const array& src, array& out, CopyType ctype);
|
|
28
|
+
|
|
29
|
+
void copy_gpu_inplace(
|
|
30
|
+
const array& in,
|
|
31
|
+
array& out,
|
|
32
|
+
CopyType ctype,
|
|
33
|
+
const Stream& s);
|
|
34
|
+
|
|
35
|
+
void copy_gpu_inplace(
|
|
36
|
+
const array& in,
|
|
37
|
+
array& out,
|
|
38
|
+
const Strides& i_strides,
|
|
39
|
+
int64_t i_offset,
|
|
40
|
+
CopyType ctype,
|
|
41
|
+
const Stream& s);
|
|
42
|
+
|
|
43
|
+
// Fill the output with the scalar val
|
|
44
|
+
void fill_gpu(const array& val, array& out, const Stream& s);
|
|
45
|
+
|
|
46
|
+
// Return a contiguous array with same shape that copies the data of |arr|.
|
|
47
|
+
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
|
48
|
+
|
|
49
|
+
// Copy data from |in| and transpose to |out|'s shape.
|
|
50
|
+
void reshape_gpu(const array& in, array& out, Stream s);
|
|
51
|
+
|
|
52
|
+
// Like the normal ops but safe to call in eval_gpu.
|
|
53
|
+
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
|
54
|
+
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
|
55
|
+
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
|
56
|
+
|
|
57
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <future>
|
|
6
|
+
#include <memory>
|
|
7
|
+
|
|
8
|
+
#include "mlx/array.h"
|
|
9
|
+
#include "mlx/stream.h"
|
|
10
|
+
|
|
11
|
+
namespace mlx::core::gpu {
|
|
12
|
+
|
|
13
|
+
void new_stream(Stream stream);
|
|
14
|
+
void eval(array& arr);
|
|
15
|
+
void finalize(Stream s);
|
|
16
|
+
void synchronize(Stream s);
|
|
17
|
+
|
|
18
|
+
} // namespace mlx::core::gpu
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/array.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
void slice_gpu(
|
|
10
|
+
const array& in,
|
|
11
|
+
array& out,
|
|
12
|
+
const Shape& start_indices,
|
|
13
|
+
const Shape& strides,
|
|
14
|
+
const Stream& s);
|
|
15
|
+
|
|
16
|
+
void concatenate_gpu(
|
|
17
|
+
const std::vector<array>& inputs,
|
|
18
|
+
array& out,
|
|
19
|
+
int axis,
|
|
20
|
+
const Stream& s);
|
|
21
|
+
|
|
22
|
+
void pad_gpu(
|
|
23
|
+
const array& in,
|
|
24
|
+
const array& val,
|
|
25
|
+
array& out,
|
|
26
|
+
const std::vector<int>& axes,
|
|
27
|
+
const Shape& low_pad_size,
|
|
28
|
+
const Stream& s);
|
|
29
|
+
|
|
30
|
+
array compute_dynamic_offset(
|
|
31
|
+
const array& indices,
|
|
32
|
+
const Strides& strides,
|
|
33
|
+
const std::vector<int>& axes,
|
|
34
|
+
const Stream& s);
|
|
35
|
+
|
|
36
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <map>
|
|
6
|
+
#include <mutex>
|
|
7
|
+
#include <vector>
|
|
8
|
+
|
|
9
|
+
#include "mlx/allocator.h"
|
|
10
|
+
#include "mlx/backend/common/buffer_cache.h"
|
|
11
|
+
#include "mlx/backend/metal/device.h"
|
|
12
|
+
#include "mlx/backend/metal/resident.h"
|
|
13
|
+
|
|
14
|
+
namespace mlx::core::metal {
|
|
15
|
+
|
|
16
|
+
using allocator::Buffer;
|
|
17
|
+
|
|
18
|
+
class MetalAllocator : public allocator::Allocator {
|
|
19
|
+
/** Allocator for Metal GPUs. */
|
|
20
|
+
public:
|
|
21
|
+
virtual Buffer malloc(size_t size) override;
|
|
22
|
+
virtual void free(Buffer buffer) override;
|
|
23
|
+
virtual size_t size(Buffer buffer) const override;
|
|
24
|
+
virtual Buffer make_buffer(void* ptr, size_t size) override;
|
|
25
|
+
virtual void release(Buffer buffer) override;
|
|
26
|
+
|
|
27
|
+
size_t get_active_memory() {
|
|
28
|
+
return active_memory_;
|
|
29
|
+
};
|
|
30
|
+
size_t get_peak_memory() {
|
|
31
|
+
return peak_memory_;
|
|
32
|
+
};
|
|
33
|
+
void reset_peak_memory() {
|
|
34
|
+
std::unique_lock lk(mutex_);
|
|
35
|
+
peak_memory_ = 0;
|
|
36
|
+
};
|
|
37
|
+
size_t get_cache_memory() {
|
|
38
|
+
return buffer_cache_.cache_size();
|
|
39
|
+
};
|
|
40
|
+
size_t set_cache_limit(size_t limit);
|
|
41
|
+
size_t set_memory_limit(size_t limit);
|
|
42
|
+
size_t get_memory_limit();
|
|
43
|
+
size_t set_wired_limit(size_t limit);
|
|
44
|
+
void clear_cache();
|
|
45
|
+
|
|
46
|
+
private:
|
|
47
|
+
MTL::Device* device_;
|
|
48
|
+
|
|
49
|
+
// The size of allocations which go on the heap until it is full. This size
|
|
50
|
+
// is chosen because it is the actual minimum size of a buffer allocated from
|
|
51
|
+
// the heap, a heap can have at most heap.size() / 256 buffers.
|
|
52
|
+
static constexpr int small_size_ = 256;
|
|
53
|
+
static constexpr int heap_size_ = 1 << 20;
|
|
54
|
+
MTL::Heap* heap_;
|
|
55
|
+
MetalAllocator();
|
|
56
|
+
~MetalAllocator();
|
|
57
|
+
friend MetalAllocator& allocator();
|
|
58
|
+
|
|
59
|
+
// Caching allocator
|
|
60
|
+
BufferCache<MTL::Buffer> buffer_cache_;
|
|
61
|
+
|
|
62
|
+
ResidencySet residency_set_;
|
|
63
|
+
|
|
64
|
+
// Allocation stats
|
|
65
|
+
size_t block_limit_;
|
|
66
|
+
size_t gc_limit_;
|
|
67
|
+
size_t active_memory_{0};
|
|
68
|
+
size_t peak_memory_{0};
|
|
69
|
+
size_t max_pool_size_;
|
|
70
|
+
size_t wired_limit_{0};
|
|
71
|
+
size_t num_resources_{0};
|
|
72
|
+
size_t resource_limit_{0};
|
|
73
|
+
|
|
74
|
+
std::mutex mutex_;
|
|
75
|
+
};
|
|
76
|
+
|
|
77
|
+
MetalAllocator& allocator();
|
|
78
|
+
|
|
79
|
+
} // namespace mlx::core::metal
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/array.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
void binary_op_gpu(
|
|
10
|
+
const std::vector<array>& inputs,
|
|
11
|
+
std::vector<array>& outputs,
|
|
12
|
+
const char* op,
|
|
13
|
+
const Stream& s);
|
|
14
|
+
|
|
15
|
+
void binary_op_gpu(
|
|
16
|
+
const std::vector<array>& inputs,
|
|
17
|
+
array& out,
|
|
18
|
+
const char* op,
|
|
19
|
+
const Stream& s);
|
|
20
|
+
|
|
21
|
+
void binary_op_gpu_inplace(
|
|
22
|
+
const std::vector<array>& inputs,
|
|
23
|
+
std::vector<array>& outputs,
|
|
24
|
+
const char* op,
|
|
25
|
+
const Stream& s);
|
|
26
|
+
|
|
27
|
+
void binary_op_gpu_inplace(
|
|
28
|
+
const std::vector<array>& inputs,
|
|
29
|
+
array& out,
|
|
30
|
+
const char* op,
|
|
31
|
+
const Stream& s);
|
|
32
|
+
|
|
33
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <Metal/Metal.hpp>
|
|
6
|
+
#include <functional>
|
|
7
|
+
#include <mutex>
|
|
8
|
+
#include <shared_mutex>
|
|
9
|
+
#include <string>
|
|
10
|
+
#include <unordered_map>
|
|
11
|
+
#include <unordered_set>
|
|
12
|
+
|
|
13
|
+
#include "mlx/array.h"
|
|
14
|
+
#include "mlx/device.h"
|
|
15
|
+
|
|
16
|
+
namespace mlx::core::metal {
|
|
17
|
+
|
|
18
|
+
using MTLFCList =
|
|
19
|
+
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
|
20
|
+
|
|
21
|
+
struct DeviceStream;
|
|
22
|
+
|
|
23
|
+
struct CommandEncoder {
|
|
24
|
+
explicit CommandEncoder(DeviceStream& stream);
|
|
25
|
+
CommandEncoder(const CommandEncoder&) = delete;
|
|
26
|
+
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
|
27
|
+
|
|
28
|
+
struct ConcurrentContext {
|
|
29
|
+
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
|
|
30
|
+
enc.concurrent_ = true;
|
|
31
|
+
}
|
|
32
|
+
~ConcurrentContext() {
|
|
33
|
+
enc.concurrent_ = false;
|
|
34
|
+
enc.prev_outputs_.insert(
|
|
35
|
+
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
|
|
36
|
+
enc.concurrent_outputs_.clear();
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
private:
|
|
40
|
+
CommandEncoder& enc;
|
|
41
|
+
};
|
|
42
|
+
|
|
43
|
+
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
|
44
|
+
void set_output_array(array& a, int idx, int64_t offset = 0);
|
|
45
|
+
void register_output_array(const array& a);
|
|
46
|
+
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
|
47
|
+
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
|
|
48
|
+
void maybeInsertBarrier();
|
|
49
|
+
void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
|
|
50
|
+
|
|
51
|
+
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
|
|
52
|
+
enc_->setComputePipelineState(kernel);
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
void wait_for_fence(MTL::Fence* fence) {
|
|
56
|
+
enc_->waitForFence(fence);
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
void update_fence(MTL::Fence* fence) {
|
|
60
|
+
enc_->updateFence(fence);
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
|
|
64
|
+
void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
|
|
65
|
+
enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
|
|
66
|
+
}
|
|
67
|
+
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
|
|
68
|
+
void set_vector_bytes(const Vec& vec, int idx) {
|
|
69
|
+
return set_vector_bytes(vec, vec.size(), idx);
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
template <typename T>
|
|
73
|
+
void set_bytes(const T* v, int n, int idx) {
|
|
74
|
+
return enc_->setBytes(v, n * sizeof(T), idx);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
template <typename T>
|
|
78
|
+
void set_bytes(const T& v, int idx) {
|
|
79
|
+
return enc_->setBytes(&v, sizeof(T), idx);
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
void set_threadgroup_memory_length(size_t length, int idx) {
|
|
83
|
+
enc_->setThreadgroupMemoryLength(length, idx);
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
ConcurrentContext start_concurrent() {
|
|
87
|
+
return ConcurrentContext(*this);
|
|
88
|
+
}
|
|
89
|
+
~CommandEncoder();
|
|
90
|
+
|
|
91
|
+
// Inputs to all kernels in the encoder including temporaries
|
|
92
|
+
std::unordered_set<const void*>& inputs() {
|
|
93
|
+
return all_inputs_;
|
|
94
|
+
};
|
|
95
|
+
|
|
96
|
+
// Outputs of all kernels in the encoder including temporaries
|
|
97
|
+
std::unordered_set<const void*>& outputs() {
|
|
98
|
+
return all_outputs_;
|
|
99
|
+
};
|
|
100
|
+
|
|
101
|
+
void barrier();
|
|
102
|
+
|
|
103
|
+
private:
|
|
104
|
+
DeviceStream& stream_;
|
|
105
|
+
MTL::ComputeCommandEncoder* enc_;
|
|
106
|
+
bool needs_barrier_{false};
|
|
107
|
+
bool concurrent_{false};
|
|
108
|
+
std::unordered_set<MTL::Resource*> prev_outputs_;
|
|
109
|
+
std::unordered_set<MTL::Resource*> next_outputs_;
|
|
110
|
+
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
|
111
|
+
std::unordered_set<const void*> all_inputs_;
|
|
112
|
+
std::unordered_set<const void*> all_outputs_;
|
|
113
|
+
};
|
|
114
|
+
|
|
115
|
+
struct Fence {
|
|
116
|
+
Fence(MTL::Fence* fence) : fence(fence) {}
|
|
117
|
+
~Fence() {
|
|
118
|
+
fence->release();
|
|
119
|
+
}
|
|
120
|
+
MTL::Fence* fence;
|
|
121
|
+
};
|
|
122
|
+
|
|
123
|
+
struct DeviceStream {
|
|
124
|
+
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
|
|
125
|
+
~DeviceStream() {
|
|
126
|
+
queue->release();
|
|
127
|
+
if (buffer != nullptr) {
|
|
128
|
+
buffer->release();
|
|
129
|
+
}
|
|
130
|
+
};
|
|
131
|
+
MTL::CommandQueue* queue;
|
|
132
|
+
// A map of prior command encoder outputs to their corresponding fence
|
|
133
|
+
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
|
|
134
|
+
// Used to allow thread-safe access to the outputs map
|
|
135
|
+
std::mutex fence_mtx;
|
|
136
|
+
|
|
137
|
+
// Data updated between command buffers
|
|
138
|
+
MTL::CommandBuffer* buffer{nullptr};
|
|
139
|
+
int buffer_ops{0};
|
|
140
|
+
size_t buffer_sizes{0};
|
|
141
|
+
|
|
142
|
+
// The command encoder, fence, and temporaries are updated between command
|
|
143
|
+
// encoders
|
|
144
|
+
std::unique_ptr<CommandEncoder> encoder{nullptr};
|
|
145
|
+
std::shared_ptr<Fence> fence;
|
|
146
|
+
std::vector<array> temporaries;
|
|
147
|
+
};
|
|
148
|
+
|
|
149
|
+
class Device {
|
|
150
|
+
public:
|
|
151
|
+
Device();
|
|
152
|
+
Device(const Device&) = delete;
|
|
153
|
+
Device& operator=(const Device&) = delete;
|
|
154
|
+
~Device();
|
|
155
|
+
|
|
156
|
+
MTL::Device* mtl_device() {
|
|
157
|
+
return device_;
|
|
158
|
+
};
|
|
159
|
+
|
|
160
|
+
const std::string& get_architecture() {
|
|
161
|
+
return arch_;
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
int get_architecture_gen() const {
|
|
165
|
+
return arch_gen_;
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
void new_queue(int index);
|
|
169
|
+
|
|
170
|
+
MTL::CommandQueue* get_queue(Stream stream);
|
|
171
|
+
|
|
172
|
+
MTL::CommandBuffer* get_command_buffer(int index);
|
|
173
|
+
bool command_buffer_needs_commit(int index);
|
|
174
|
+
void commit_command_buffer(int index);
|
|
175
|
+
CommandEncoder& get_command_encoder(int index);
|
|
176
|
+
void end_encoding(int index);
|
|
177
|
+
|
|
178
|
+
MTL::Library* get_library(
|
|
179
|
+
const std::string& name,
|
|
180
|
+
const std::string& path = "");
|
|
181
|
+
|
|
182
|
+
MTL::Library* get_library(
|
|
183
|
+
const std::string& name,
|
|
184
|
+
const std::function<std::string(void)>& builder);
|
|
185
|
+
|
|
186
|
+
void clear_library(const std::string& name);
|
|
187
|
+
|
|
188
|
+
MTL::ComputePipelineState* get_kernel(
|
|
189
|
+
const std::string& base_name,
|
|
190
|
+
MTL::Library* mtl_lib,
|
|
191
|
+
const std::string& hash_name = "",
|
|
192
|
+
const MTLFCList& func_consts = {},
|
|
193
|
+
const std::vector<MTL::Function*>& linked_functions = {});
|
|
194
|
+
|
|
195
|
+
MTL::ComputePipelineState* get_kernel(
|
|
196
|
+
const std::string& base_name,
|
|
197
|
+
const std::string& hash_name = "",
|
|
198
|
+
const MTLFCList& func_consts = {},
|
|
199
|
+
const std::vector<MTL::Function*>& linked_functions = {});
|
|
200
|
+
|
|
201
|
+
MTL::ArgumentEncoder* argument_encoder(
|
|
202
|
+
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
|
203
|
+
|
|
204
|
+
// Record temporary arrays for the given stream index
|
|
205
|
+
void add_temporary(array arr, int index);
|
|
206
|
+
void add_temporaries(std::vector<array> arrays, int index);
|
|
207
|
+
|
|
208
|
+
void set_residency_set(const MTL::ResidencySet* residency_set);
|
|
209
|
+
|
|
210
|
+
private:
|
|
211
|
+
DeviceStream& get_stream_(int index) {
|
|
212
|
+
return stream_map_.find(index)->second;
|
|
213
|
+
}
|
|
214
|
+
MTL::Library* get_library_cache_(const std::string& name);
|
|
215
|
+
|
|
216
|
+
MTL::Library* get_library_(const std::string& name);
|
|
217
|
+
MTL::Library* build_library_(const std::string& source_string);
|
|
218
|
+
|
|
219
|
+
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
|
220
|
+
|
|
221
|
+
MTL::Function* get_function_(
|
|
222
|
+
const std::string& name,
|
|
223
|
+
const std::string& specialized_name,
|
|
224
|
+
const MTLFCList& func_consts,
|
|
225
|
+
MTL::Library* mtl_lib);
|
|
226
|
+
|
|
227
|
+
MTL::LinkedFunctions* get_linked_functions_(
|
|
228
|
+
const std::vector<MTL::Function*>& funcs);
|
|
229
|
+
|
|
230
|
+
MTL::ComputePipelineState* get_kernel_(
|
|
231
|
+
const std::string& name,
|
|
232
|
+
const MTL::Function* mtl_function);
|
|
233
|
+
|
|
234
|
+
MTL::ComputePipelineState* get_kernel_(
|
|
235
|
+
const std::string& name,
|
|
236
|
+
const MTL::Function* mtl_function,
|
|
237
|
+
const MTL::LinkedFunctions* linked_functions);
|
|
238
|
+
|
|
239
|
+
MTL::ComputePipelineState* get_kernel_(
|
|
240
|
+
const std::string& base_name,
|
|
241
|
+
MTL::Library* mtl_lib,
|
|
242
|
+
const std::string& hash_name,
|
|
243
|
+
const MTLFCList& func_consts = {},
|
|
244
|
+
const std::vector<MTL::Function*>& linked_functions = {});
|
|
245
|
+
|
|
246
|
+
MTL::Device* device_;
|
|
247
|
+
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
|
248
|
+
|
|
249
|
+
std::shared_mutex kernel_mtx_;
|
|
250
|
+
std::shared_mutex library_mtx_;
|
|
251
|
+
std::unordered_map<std::string, MTL::Library*> library_map_;
|
|
252
|
+
MTL::Library* default_library_;
|
|
253
|
+
std::unordered_map<
|
|
254
|
+
MTL::Library*,
|
|
255
|
+
std::unordered_map<std::string, MTL::ComputePipelineState*>>
|
|
256
|
+
library_kernels_;
|
|
257
|
+
const MTL::ResidencySet* residency_set_{nullptr};
|
|
258
|
+
std::string arch_;
|
|
259
|
+
int arch_gen_;
|
|
260
|
+
int max_ops_per_buffer_;
|
|
261
|
+
int max_mb_per_buffer_;
|
|
262
|
+
};
|
|
263
|
+
|
|
264
|
+
// Returns a reference to the metal::Device corresponding to the given core
// device. NOTE(review): presumably a lazily created, shared instance -- confirm
// in the defining translation unit.
Device& device(mlx::core::Device dev);

// Returns an owning, type-erased handle. Its std::function<void(void*)> deleter
// runs when the handle is destroyed. NOTE(review): presumably this scopes a
// memory (autorelease) pool to the handle's lifetime -- confirm.
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
|
267
|
+
|
|
268
|
+
inline bool is_nax_available() {
|
|
269
|
+
auto _check_nax = []() {
|
|
270
|
+
bool can_use_nax = false;
|
|
271
|
+
if (__builtin_available(
|
|
272
|
+
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
|
273
|
+
can_use_nax = true;
|
|
274
|
+
}
|
|
275
|
+
can_use_nax &=
|
|
276
|
+
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
|
|
277
|
+
return can_use_nax;
|
|
278
|
+
};
|
|
279
|
+
static bool is_nax_available_ = _check_nax();
|
|
280
|
+
return is_nax_available_;
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
} // namespace mlx::core::metal
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
// Accessors that each take no arguments and return a `const char*`.
// Their definitions live elsewhere. NOTE(review): presumably each returns
// source text for the similarly named kernel family, used when building
// libraries at runtime -- confirm against the defining translation unit.
namespace mlx::core::metal {

// --- utils / *_ops / reduce_utils and indexing helpers -----------------------
const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* masked_scatter();

// --- arange through reduce ---------------------------------------------------
const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* gather_axis();
const char* gather_front();
const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
const char* fp_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();
const char* softmax();
const char* sort();
const char* reduce();

// --- gemm / steel_* / conv / gemv / attention families -----------------------
const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* steel_gemm_segmented();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
const char* gemv_masked();
const char* steel_attention();

// --- "_nax" variants ---------------------------------------------------------
const char* gemm_nax();
const char* steel_gemm_fused_nax();
const char* steel_gemm_gather_nax();

const char* quantized_nax();
const char* fp_quantized_nax();

const char* steel_attention_nax();

} // namespace mlx::core::metal
|