mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/cuda/cublas_utils.h
@@ -0,0 +1,96 @@
+// Copyright © 2025 Apple Inc.
+#pragma once
+
+#include <cublasLt.h>
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/dtype_utils.h"
+
+namespace mlx::core {
+namespace cublas_utils {
+
+// Get the shared cublas preference for a device
+cublasLtMatmulPreference_t get_preference(cu::Device& device);
+
+void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size);
+
+cublasLtMatrixLayout_t create_matrix_layout(
+    cudaDataType_t type,
+    uint64_t rows,
+    uint64_t cols,
+    bool transposed,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride);
+
+inline cudaDataType_t dtype_to_cublas_type(Dtype dtype, std::string_view tag) {
+  switch (dtype) {
+    case float16:
+      return CUDA_R_16F;
+    case bfloat16:
+      return CUDA_R_16BF;
+    case float32:
+      return CUDA_R_32F;
+    case float64:
+      return CUDA_R_64F;
+    case complex64:
+      return CUDA_C_32F;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in {}: {}.", tag, dtype_to_string(dtype)));
+  }
+}
+
+} // namespace cublas_utils
+
+class CublasMatmulBase {
+ public:
+  virtual ~CublasMatmulBase();
+
+  void set_bias(cu::CommandEncoder& encoder, const array& bias);
+
+ protected:
+  CublasMatmulBase() = default;
+
+  // Common member variables shared by all matmul types
+  uint64_t M_;
+  uint64_t N_;
+  cudaDataType_t scale_type_;
+  cublasLtMatmulPreference_t pref_{nullptr};
+  cublasLtHandle_t handle_{nullptr};
+  cublasLtMatmulDesc_t matmul_desc_{nullptr};
+  cublasLtMatrixLayout_t a_desc_{nullptr};
+  cublasLtMatrixLayout_t b_desc_{nullptr};
+  cublasLtMatrixLayout_t c_desc_{nullptr};
+  cublasLtMatrixLayout_t out_desc_{nullptr};
+  cublasLtMatmulHeuristicResult_t heuristic_;
+
+  void init_base(
+      cu::Device& device,
+      cudaDataType_t scale_type,
+      cublasComputeType_t compute_type,
+      cudaDataType_t data_type,
+      cudaDataType_t output_type,
+      bool a_transposed,
+      uint64_t a_rows,
+      uint64_t a_cols,
+      int64_t lda,
+      bool b_transposed,
+      uint64_t b_rows,
+      uint64_t b_cols,
+      int64_t ldb,
+      int32_t batch_count,
+      int64_t a_batch_stride,
+      int64_t b_batch_stride);
+
+  void execute_matmul(
+      cu::CommandEncoder& encoder,
+      void* out,
+      const void* a,
+      const void* b,
+      const void* c,
+      const void* alpha_ptr,
+      const void* beta_ptr);
+};
+
+} // namespace mlx::core
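As a reading aid (not part of the package): dtype_to_cublas_type maps an MLX dtype to the matching cuBLAS enum, and the tag only labels the operation in the error message. A minimal caller sketch, assuming the usual cuBLASLt convention that ld is the stride between consecutive rows; make_a_layout is an invented name:

    #include "mlx/backend/cuda/cublas_utils.h"  // shipped under mlx/include/

    // Hypothetical helper: describe an M x K float32 operand for a GEMM.
    cublasLtMatrixLayout_t make_a_layout(uint64_t M, uint64_t K) {
      using namespace mlx::core;
      // float32 -> CUDA_R_32F; "gemm" names the op if the dtype is rejected.
      cudaDataType_t t = cublas_utils::dtype_to_cublas_type(float32, "gemm");
      return cublas_utils::create_matrix_layout(
          t,
          /* rows */ M,
          /* cols */ K,
          /* transposed */ false,
          /* ld */ static_cast<int64_t>(K),
          /* batch_count */ 1,
          /* batch_stride */ 0);
    }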
mlx/include/mlx/backend/cuda/cuda_utils.h
@@ -0,0 +1,89 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cublasLt.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cudnn.h>
+
+namespace mlx::core {
+
+// Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
+void check_cuda_error(const char* name, cudaError_t err);
+void check_cuda_error(const char* name, CUresult err);
+void check_cudnn_error(const char* name, cudnnStatus_t err);
+
+// The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
+#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    // Skip if there was an error to avoid throwing in the destructors
+    if (cudaPeekAtLastError() != cudaSuccess) {
+      return;
+    }
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+namespace cu {
+class Device;
+}; // namespace cu
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
+} // namespace mlx::core
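A usage sketch (not from the package) of the two pieces above: CHECK_CUDA_ERROR stringizes the call so failures report the exact command, and any CUDA object whose destroy function returns cudaError_t can be wrapped by instantiating CudaHandle. The instantiation below is the same one event.h applies to cudaEvent_t; record_once is an invented name:

    #include <cuda_runtime.h>
    #include "mlx/backend/cuda/cuda_utils.h"

    using EventHandle = mlx::core::CudaHandle<cudaEvent_t, cudaEventDestroy>;

    void record_once(cudaStream_t stream) {
      cudaEvent_t raw;
      CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&raw, cudaEventDisableTiming));
      EventHandle ev(raw);  // the wrapper owns the event from here on
      CHECK_CUDA_ERROR(cudaEventRecord(ev, stream));  // operator Handle() converts
    }  // ~CudaHandle() runs cudaEventDestroy (skipped if a CUDA error is pending)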
mlx/include/mlx/backend/cuda/cudnn_utils.h
@@ -0,0 +1,171 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/dtype_utils.h"
+
+#include <cudnn_frontend.h>
+#include <fmt/format.h>
+
+namespace mlx::core {
+
+namespace cu {
+class CommandEncoder;
+}
+
+namespace fe = cudnn_frontend;
+
+#define CHECK_CUDNN_FE_ERROR(cmd)                                    \
+  do {                                                               \
+    auto error = cmd;                                                \
+    if (!error.is_good()) {                                          \
+      throw std::runtime_error(                                      \
+          fmt::format("{} failed: {}.", #cmd, error.get_message())); \
+    }                                                                \
+  } while (0)
+
+// Return pointer alignment of |x|'s data.
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
+  for (; alignment < 32; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
+
+// Convert the type of elements in |vec| to |T|.
+template <typename T, typename Vec>
+inline std::vector<T> convert_vector(const Vec& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
+}
+
+// Map dtype to cudnn data type.
+inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return fe::DataType_t::INT8;
+    case int32:
+      return fe::DataType_t::INT32;
+    case uint8:
+      return fe::DataType_t::UINT8;
+    case float16:
+      return fe::DataType_t::HALF;
+    case bfloat16:
+      return fe::DataType_t::BFLOAT16;
+    case float32:
+      return fe::DataType_t::FLOAT;
+    case float64:
+      return fe::DataType_t::DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
+  }
+}
+
+// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
+//
+// There are 2 differences from the const_param util from kernel_utils.cuh:
+// 1. The rest of array is filled with 0.
+// 2. This util can be used in .cpp files.
+template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
+inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
+  if (vec.size() > NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim can not be larger than {}.", NDIM));
+  }
+  std::array<T, NDIM> result = {};
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
+// Extends cuDNN graph with helpers.
+class DnnGraph : public fe::graph::Graph {
+ public:
+  DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
+      : handle_(handle) {
+    set_io_data_type(dtype_to_cudnn_type(io_dtype));
+    set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
+    set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
+  }
+
+  // Create a cuDNN tensor description from MLX array |x|.
+  auto& tensor(
+      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+      int64_t uid,
+      const array& x) {
+    set_tensor_attrs(attrs, uid, x);
+    return attrs;
+  }
+  auto tensor(const char* name, int64_t uid, const array& x) {
+    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+    tensor(attrs, uid, x);
+    return attrs;
+  }
+
+  // Create a cuDNN tensor description from MLX array |x|, and transpose it from
+  // NHWC layout to NCHW.
+  auto& tensor_nchw(
+      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+      int64_t uid,
+      const array& x) {
+    set_tensor_attrs_nchw(attrs, uid, x);
+    return attrs;
+  }
+  auto tensor_nchw(const char* name, int64_t uid, const array& x) {
+    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+    tensor_nchw(attrs, uid, x);
+    return attrs;
+  }
+
+  // Create a cuDNN tensor for scalar.
+  auto scalar(const char* name, int64_t uid, Dtype dtype) {
+    return Graph::tensor(fe::graph::Tensor_attributes()
+                             .set_name(name)
+                             .set_uid(uid)
+                             .set_dim({1, 1, 1, 1})
+                             .set_stride({1, 1, 1, 1})
+                             .set_is_pass_by_value(true)
+                             .set_data_type(dtype_to_cudnn_type(dtype)));
+  }
+
+  // Call this before setting notes.
+  fe::error_t prepare();
+  // Call this after setting notes.
+  fe::error_t build();
+
+  // Add cuDNN graph to CUDA graph, using native CUDA graph API.
+  fe::error_t encode_graph(
+      cu::CommandEncoder& encoder,
+      std::unordered_map<int64_t, void*> variant_pack);
+  // Add cuDNN graph to CUDA graph, using stream capture.
+  fe::error_t encode_capturing(
+      cu::CommandEncoder& encoder,
+      std::unordered_map<int64_t, void*> variant_pack);
+
+ private:
+  void* prepare_workspace(cu::CommandEncoder& encoder);
+
+  void set_tensor_attrs(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x,
+      const std::vector<int64_t>& shape,
+      const std::vector<int64_t>& strides);
+  void set_tensor_attrs(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x);
+  void set_tensor_attrs_nchw(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x);
+
+  cudnnHandle_t handle_;
+};
+
+} // namespace mlx::core
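For illustration, vector_key's contract restated as a standalone program (std::vector substituted for the generic Vec parameter; vector_key_sketch is an invented name): entries are copied and the tail is zero-filled, so shapes of different lengths still form distinct, fixed-size, hashable keys.

    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    template <int NDIM = 10, typename T>  // 10 mirrors MAX_NDIM in device/config.h
    std::array<T, NDIM> vector_key_sketch(const std::vector<T>& vec) {
      if (vec.size() > static_cast<size_t>(NDIM)) {
        throw std::runtime_error("ndim can not be larger than NDIM.");
      }
      std::array<T, NDIM> result = {};  // value-initialized: tail stays zero
      std::copy_n(vec.begin(), vec.size(), result.begin());
      return result;
    }

    int main() {
      auto key = vector_key_sketch<10, int64_t>({2, 3, 4});
      // key == {2, 3, 4, 0, 0, 0, 0, 0, 0, 0}
    }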
mlx/include/mlx/backend/cuda/device/config.h
@@ -0,0 +1,12 @@
+// Copyright © 2025 Apple Inc.
+
+// This file is used by both CUDA kernel code and host-only C++ code.
+
+#pragma once
+
+// The maximum dimensions of shape/strides passed as kernel parameters.
+#define MAX_NDIM 10
+
+// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
+// warpSize variable exists, using it would prevent compile-time optimizations.
+#define WARP_SIZE 32
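The compile-time constant matters because loop bounds built from it can unroll. A standard warp-reduction sketch (not from the package) that relies on this:

    // With WARP_SIZE a macro, the compiler can unroll this loop into five
    // fixed __shfl_down_sync steps; the built-in warpSize is a runtime value
    // and would block that.
    __device__ float warp_sum(float v) {
      for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        v += __shfl_down_sync(0xffffffff, v, offset);
      }
      return v;  // lane 0 ends up holding the sum over all 32 lanes
    }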
mlx/include/mlx/backend/cuda/device.h
@@ -0,0 +1,189 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/lru_cache.h"
+#include "mlx/backend/cuda/worker.h"
+#include "mlx/stream.h"
+
+#include <cublasLt.h>
+#include <cuda.h>
+#include <cudnn.h>
+#include <thrust/execution_policy.h>
+
+#include <unordered_map>
+
+namespace mlx::core::cu {
+
+class CommandEncoder {
+ public:
+  struct CaptureContext {
+    CaptureContext(CommandEncoder& enc);
+    ~CaptureContext();
+    CudaGraph graph;
+    CommandEncoder& enc;
+    bool discard{false};
+  };
+  struct ConcurrentContext {
+    ConcurrentContext(CommandEncoder& enc);
+    ~ConcurrentContext();
+    CommandEncoder& enc;
+  };
+
+  explicit CommandEncoder(Device& d);
+
+  CommandEncoder(const CommandEncoder&) = delete;
+  CommandEncoder& operator=(const CommandEncoder&) = delete;
+
+  CaptureContext capture_context() {
+    return CaptureContext{*this};
+  }
+  ConcurrentContext concurrent_context() {
+    return ConcurrentContext{*this};
+  }
+
+  void set_input_array(const array& arr);
+  void set_output_array(const array& arr);
+
+  template <typename F, typename... Params>
+  void add_kernel_node(
+      F* func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      Params&&... params) {
+    constexpr size_t num = sizeof...(Params);
+    void* ptrs[num];
+    size_t i = 0;
+    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
+         std::forward<Params>(params)),
+     ...);
+    add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
+  }
+
+  void add_kernel_node(
+      CUfunction func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      void** params);
+
+  void add_kernel_node(
+      void* func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      void** params);
+
+  void add_graph_node(cudaGraph_t child);
+
+  void add_temporary(const array& arr) {
+    temporaries_.push_back(arr.data_shared_ptr());
+  }
+
+  void add_completed_handler(std::function<void()> task);
+  bool needs_commit();
+  void commit();
+
+  Device& device() {
+    return device_;
+  }
+
+  CudaStream& stream() {
+    return stream_;
+  }
+
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
+ private:
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+
+  struct GraphNode {
+    cudaGraphNode_t node;
+    // K = kernel
+    // E = empty
+    // () = subgraph (with metadata)
+    // Symbols ':', '-' are reserved as separators
+    std::string node_type;
+    std::string id;
+  };
+
+  void insert_graph_dependencies(GraphNode node);
+  void insert_graph_dependencies(std::vector<GraphNode> nodes);
+
+  Device& device_;
+  CudaStream stream_;
+  CudaGraph graph_;
+  Worker worker_;
+  char node_count_{0};
+  bool in_concurrent_{false};
+  std::vector<cudaGraphNode_t> from_nodes_;
+  std::vector<cudaGraphNode_t> to_nodes_;
+  std::string graph_nodes_key_;
+  std::string graph_deps_key_;
+  std::vector<GraphNode> concurrent_nodes_;
+  std::vector<std::shared_ptr<array::Data>> temporaries_;
+  LRUCache<std::string, CudaGraphExec> graph_cache_;
+  std::vector<std::uintptr_t> active_deps_;
+  std::vector<std::uintptr_t> active_outputs_;
+  std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  size_t bytes_in_graph_{0};
+  bool is_graph_updatable_{true};
+  int max_ops_per_graph_;
+  int max_mb_per_graph_;
+};
+
+class Device {
+ public:
+  explicit Device(int device);
+  ~Device();
+
+  Device(const Device&) = delete;
+  Device& operator=(const Device&) = delete;
+
+  // Make this device the current cuda device, this method is thread-safe.
+  void make_current();
+
+  CommandEncoder& get_command_encoder(Stream s);
+
+  int cuda_device() const {
+    return device_;
+  }
+  int compute_capability_major() const {
+    return compute_capability_major_;
+  }
+  int compute_capability_minor() const {
+    return compute_capability_minor_;
+  }
+  cublasLtHandle_t lt_handle() const {
+    return lt_;
+  }
+  cudnnHandle_t cudnn_handle() const {
+    return cudnn_;
+  }
+
+ private:
+  int device_;
+  int compute_capability_major_;
+  int compute_capability_minor_;
+  std::string device_name_;
+  cublasLtHandle_t lt_;
+  cudnnHandle_t cudnn_;
+  std::unordered_map<int, CommandEncoder> encoders_;
+};
+
+Device& device(mlx::core::Device device);
+CommandEncoder& get_command_encoder(Stream s);
+
+// Return an execution policy that does not sync for result.
+// Note that not all thrust APIs support async policy, confirm before using.
+inline auto thrust_policy(cudaStream_t stream) {
+  // TODO: Connect thrust's custom allocator with mlx's allocator.
+  return thrust::cuda::par_nosync.on(stream);
+}
+
+} // namespace mlx::core::cu
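The variadic add_kernel_node above stores the address of each forwarded argument in ptrs and hands the array to the untyped overload, matching the void** calling convention of CUDA graph kernel nodes; since no conversions happen through the pointers, arguments must match the kernel's parameter types exactly. A hypothetical caller (axpy and encode_axpy are invented names):

    __global__ void axpy(float* out, const float* x, float a, int n);  // assumed

    void encode_axpy(
        mlx::core::cu::CommandEncoder& enc,
        float* out,
        const float* x,
        float a,
        int n) {
      dim3 block(256);
      dim3 grid((n + block.x - 1) / block.x);
      // Each argument is captured by address, so `a` must already be a float
      // and `n` an int; passing a double or size_t here would corrupt the call.
      enc.add_kernel_node(axpy, grid, block, /* smem_bytes */ 0, out, x, a, n);
    }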
mlx/include/mlx/backend/cuda/event.h
@@ -0,0 +1,78 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/stream.h"
+
+#include <memory>
+
+#include <cuda_runtime.h>
+#include <cuda/atomic>
+
+namespace mlx::core::cu {
+
+class Device;
+
+// RAII-managed move-only wrapper of cudaEvent_t.
+struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
+  CudaEventHandle(Device& d, int flags);
+  Device& device;
+  int flags;
+};
+
+// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
+// on GPU stream in CPU stream, but can not wait on CPU stream.
+class CudaEvent {
+ public:
+  CudaEvent(Device& d, int flags);
+  ~CudaEvent();
+
+  CudaEvent(CudaEvent&&) = default;
+  CudaEvent& operator=(CudaEvent&&) = default;
+
+  CudaEvent(const CudaEvent&) = delete;
+  CudaEvent& operator=(const CudaEvent&) = delete;
+
+  void wait();
+  void wait(cudaStream_t stream);
+  void record(cudaStream_t stream);
+
+  // Return whether the recorded kernels have completed. Note that this method
+  // returns true if record() has not been called.
+  bool completed() const;
+
+  // Internal: make sure event pool is initialized.
+  static void init_pool();
+
+ private:
+  CudaEventHandle event_;
+};
+
+// Event that can synchronize between CPU and GPU. It is much slower than
+// CudaEvent so the latter should always be preferred when possible.
+class AtomicEvent {
+ public:
+  using Atomic = cuda::atomic<uint64_t>;
+
+  AtomicEvent();
+
+  void wait(uint64_t value);
+  void wait(cudaStream_t stream, uint64_t value);
+  void wait(Stream s, uint64_t value);
+  void signal(uint64_t value);
+  void signal(cudaStream_t stream, uint64_t value);
+  void signal(Stream s, uint64_t value);
+  bool is_signaled(uint64_t value) const;
+  uint64_t value() const;
+
+ private:
+  Atomic* atomic() const {
+    return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
+  }
+
+  std::shared_ptr<allocator::Buffer> buf_;
+};
+
+} // namespace mlx::core::cu
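One pattern the AtomicEvent interface above supports, sketched under the assumption (consistent with the class comment) that signal on a stream enqueues the store after prior GPU work in that stream and the value-taking wait blocks the CPU until the value is reached; cpu_waits_on_gpu is an invented name:

    void cpu_waits_on_gpu(mlx::core::cu::AtomicEvent& ev, cudaStream_t stream) {
      // ... kernels previously enqueued on `stream` ...
      ev.signal(stream, 1);  // GPU sets the shared atomic to 1 in stream order
      ev.wait(1);            // CPU blocks until the atomic reaches 1
    }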