mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <metal_math>
|
|
6
|
+
|
|
7
|
+
#include "mlx/backend/metal/kernels/bf16.h"
|
|
8
|
+
#include "mlx/backend/metal/kernels/bf16_math.h"
|
|
9
|
+
#include "mlx/backend/metal/kernels/complex.h"
|
|
10
|
+
#include "mlx/backend/metal/kernels/defines.h"
|
|
11
|
+
|
|
12
|
+
// Common alias for Metal's 16-bit float type.
using float16_t = half;
|
|
13
|
+
|
|
14
|
+
// Work per thread values for different types. The values here are expected to
|
|
15
|
+
// match get_work_per_thread in mlx/backend/metal/utils.h
|
|
16
|
+
template <typename U>
|
|
17
|
+
struct WorkPerThread {
|
|
18
|
+
static_assert(sizeof(U) <= 8, "Type too large");
|
|
19
|
+
static constexpr int constant n = 8 / sizeof(U);
|
|
20
|
+
};
|
|
21
|
+
|
|
22
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
23
|
+
// Type limits utils
|
|
24
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
25
|
+
|
|
26
|
+
template <typename U>
|
|
27
|
+
struct Limits {
|
|
28
|
+
static const constant U max = metal::numeric_limits<U>::max();
|
|
29
|
+
static const constant U min = metal::numeric_limits<U>::min();
|
|
30
|
+
static const constant U finite_max = metal::numeric_limits<U>::max();
|
|
31
|
+
static const constant U finite_min = metal::numeric_limits<U>::min();
|
|
32
|
+
};
|
|
33
|
+
|
|
34
|
+
#define instantiate_default_limit(type) \
|
|
35
|
+
template <> \
|
|
36
|
+
struct Limits<type> { \
|
|
37
|
+
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
|
38
|
+
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
|
39
|
+
static constexpr constant type finite_max = \
|
|
40
|
+
metal::numeric_limits<type>::max(); \
|
|
41
|
+
static constexpr constant type finite_min = \
|
|
42
|
+
metal::numeric_limits<type>::min(); \
|
|
43
|
+
};
|
|
44
|
+
|
|
45
|
+
instantiate_default_limit(uint8_t);
|
|
46
|
+
instantiate_default_limit(uint16_t);
|
|
47
|
+
instantiate_default_limit(uint32_t);
|
|
48
|
+
instantiate_default_limit(uint64_t);
|
|
49
|
+
instantiate_default_limit(int8_t);
|
|
50
|
+
instantiate_default_limit(int16_t);
|
|
51
|
+
instantiate_default_limit(int32_t);
|
|
52
|
+
instantiate_default_limit(int64_t);
|
|
53
|
+
|
|
54
|
+
#define instantiate_float_limit(type) \
|
|
55
|
+
template <> \
|
|
56
|
+
struct Limits<type> { \
|
|
57
|
+
static constexpr constant type max = \
|
|
58
|
+
metal::numeric_limits<type>::infinity(); \
|
|
59
|
+
static constexpr constant type min = \
|
|
60
|
+
-metal::numeric_limits<type>::infinity(); \
|
|
61
|
+
static constexpr constant type finite_max = \
|
|
62
|
+
metal::numeric_limits<type>::max(); \
|
|
63
|
+
static constexpr constant type finite_min = \
|
|
64
|
+
-metal::numeric_limits<type>::max(); \
|
|
65
|
+
};
|
|
66
|
+
|
|
67
|
+
instantiate_float_limit(half);
|
|
68
|
+
instantiate_float_limit(float);
|
|
69
|
+
instantiate_float_limit(bfloat16_t);
|
|
70
|
+
|
|
71
|
+
template <>
|
|
72
|
+
struct Limits<bool> {
|
|
73
|
+
static constexpr constant bool max = true;
|
|
74
|
+
static constexpr constant bool min = false;
|
|
75
|
+
};
|
|
76
|
+
|
|
77
|
+
template <>
|
|
78
|
+
struct Limits<complex64_t> {
|
|
79
|
+
static constexpr constant complex64_t max = complex64_t(
|
|
80
|
+
metal::numeric_limits<float>::infinity(),
|
|
81
|
+
metal::numeric_limits<float>::infinity());
|
|
82
|
+
static constexpr constant complex64_t min = complex64_t(
|
|
83
|
+
-metal::numeric_limits<float>::infinity(),
|
|
84
|
+
-metal::numeric_limits<float>::infinity());
|
|
85
|
+
};
|
|
86
|
+
|
|
87
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
88
|
+
// Indexing utils
|
|
89
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
90
|
+
|
|
91
|
+
// Ask the Metal (clang) compiler to fully unroll the loop that follows.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
|
92
|
+
|
|
93
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
94
|
+
// Single Array with generic dims
|
|
95
|
+
|
|
96
|
+
template <typename IdxT = int64_t>
|
|
97
|
+
METAL_FUNC IdxT elem_to_loc(
|
|
98
|
+
IdxT elem,
|
|
99
|
+
constant const int* shape,
|
|
100
|
+
constant const int64_t* strides,
|
|
101
|
+
int ndim) {
|
|
102
|
+
IdxT loc = 0;
|
|
103
|
+
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
|
104
|
+
loc += (elem % shape[i]) * IdxT(strides[i]);
|
|
105
|
+
elem /= shape[i];
|
|
106
|
+
}
|
|
107
|
+
return loc;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
// Non templated version to handle arbitrary dims
|
|
111
|
+
template <typename IdxT = int64_t>
|
|
112
|
+
METAL_FUNC IdxT elem_to_loc(
|
|
113
|
+
uint3 elem,
|
|
114
|
+
constant const int* shape,
|
|
115
|
+
constant const int64_t* strides,
|
|
116
|
+
int ndim) {
|
|
117
|
+
IdxT loc =
|
|
118
|
+
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
|
|
119
|
+
for (int d = ndim - 3; d >= 0; --d) {
|
|
120
|
+
loc += (elem.z % shape[d]) * IdxT(strides[d]);
|
|
121
|
+
elem.z /= shape[d];
|
|
122
|
+
}
|
|
123
|
+
return loc;
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
127
|
+
// Single Array with fixed N dims
|
|
128
|
+
|
|
129
|
+
template <typename IdxT = int64_t>
|
|
130
|
+
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
|
|
131
|
+
return elem * IdxT(stride);
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
template <typename IdxT = int64_t>
|
|
135
|
+
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
|
|
136
|
+
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
template <typename IdxT = int64_t>
|
|
140
|
+
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
|
|
141
|
+
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
|
|
142
|
+
elem.z * IdxT(strides[0]);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
146
|
+
// Multiple Arrays with generic dims
|
|
147
|
+
|
|
148
|
+
template <typename IdxT = int64_t>
|
|
149
|
+
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
|
|
150
|
+
uint3 elem,
|
|
151
|
+
constant const int* shape,
|
|
152
|
+
constant const int64_t* a_strides,
|
|
153
|
+
constant const int64_t* b_strides,
|
|
154
|
+
int ndim) {
|
|
155
|
+
vec<IdxT, 2> loc = {
|
|
156
|
+
IdxT(
|
|
157
|
+
elem.x * IdxT(a_strides[ndim - 1]) +
|
|
158
|
+
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
|
|
159
|
+
IdxT(
|
|
160
|
+
elem.x * IdxT(b_strides[ndim - 1]) +
|
|
161
|
+
elem.y * IdxT(b_strides[ndim - 2]))};
|
|
162
|
+
for (int d = ndim - 3; d >= 0; --d) {
|
|
163
|
+
uint l = elem.z % shape[d];
|
|
164
|
+
loc.x += l * IdxT(a_strides[d]);
|
|
165
|
+
loc.y += l * IdxT(b_strides[d]);
|
|
166
|
+
elem.z /= shape[d];
|
|
167
|
+
}
|
|
168
|
+
return loc;
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
template <typename IdxT = int64_t>
|
|
172
|
+
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
|
|
173
|
+
uint3 elem,
|
|
174
|
+
constant const int* shape,
|
|
175
|
+
constant const int64_t* a_strides,
|
|
176
|
+
constant const int64_t* b_strides,
|
|
177
|
+
constant const int64_t* c_strides,
|
|
178
|
+
int ndim) {
|
|
179
|
+
vec<IdxT, 3> loc = {
|
|
180
|
+
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
|
|
181
|
+
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
|
|
182
|
+
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
|
|
183
|
+
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
|
|
184
|
+
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
|
|
185
|
+
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
|
|
186
|
+
for (int d = ndim - 3; d >= 0; --d) {
|
|
187
|
+
uint l = elem.z % shape[d];
|
|
188
|
+
loc.x += l * IdxT(a_strides[d]);
|
|
189
|
+
loc.y += l * IdxT(b_strides[d]);
|
|
190
|
+
loc.z += l * IdxT(c_strides[d]);
|
|
191
|
+
elem.z /= shape[d];
|
|
192
|
+
}
|
|
193
|
+
return loc;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
197
|
+
// Elem to loc in a loop utils
|
|
198
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
199
|
+
|
|
200
|
+
template <int DIM, typename OffsetT = size_t, bool General = true>
|
|
201
|
+
struct LoopedElemToLoc {
|
|
202
|
+
int dim;
|
|
203
|
+
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
|
|
204
|
+
OffsetT offset{0};
|
|
205
|
+
int index{0};
|
|
206
|
+
|
|
207
|
+
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
|
208
|
+
|
|
209
|
+
void next(const constant int* shape, const constant int64_t* strides) {
|
|
210
|
+
if (dim == 0) {
|
|
211
|
+
return;
|
|
212
|
+
}
|
|
213
|
+
index++;
|
|
214
|
+
offset += OffsetT(strides[dim - 1]);
|
|
215
|
+
if (index >= shape[dim - 1]) {
|
|
216
|
+
index = 0;
|
|
217
|
+
inner_looper.next(shape, strides);
|
|
218
|
+
offset = inner_looper.offset;
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
void next(int n, const constant int* shape, const constant int64_t* strides) {
|
|
223
|
+
if (dim == 0) {
|
|
224
|
+
return;
|
|
225
|
+
}
|
|
226
|
+
index += n;
|
|
227
|
+
offset += n * OffsetT(strides[dim - 1]);
|
|
228
|
+
|
|
229
|
+
if (index >= shape[dim - 1]) {
|
|
230
|
+
int extra = index - shape[dim - 1];
|
|
231
|
+
if (extra >= shape[dim - 1]) {
|
|
232
|
+
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
|
233
|
+
extra = extra % shape[dim - 1];
|
|
234
|
+
} else {
|
|
235
|
+
inner_looper.next(shape, strides);
|
|
236
|
+
}
|
|
237
|
+
index = 0;
|
|
238
|
+
offset = inner_looper.offset;
|
|
239
|
+
if (extra > 0) {
|
|
240
|
+
next(extra, shape, strides);
|
|
241
|
+
}
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
OffsetT location() {
|
|
246
|
+
return offset;
|
|
247
|
+
}
|
|
248
|
+
};
|
|
249
|
+
|
|
250
|
+
template <typename OffsetT>
|
|
251
|
+
struct LoopedElemToLoc<1, OffsetT, true> {
|
|
252
|
+
int dim;
|
|
253
|
+
OffsetT offset{0};
|
|
254
|
+
uint index{0};
|
|
255
|
+
|
|
256
|
+
LoopedElemToLoc(int dim) : dim(dim) {}
|
|
257
|
+
|
|
258
|
+
void next(const constant int* shape, const constant int64_t* strides) {
|
|
259
|
+
index++;
|
|
260
|
+
if (dim > 1) {
|
|
261
|
+
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
|
262
|
+
} else {
|
|
263
|
+
offset += OffsetT(strides[0]);
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
void next(int n, const constant int* shape, const constant int64_t* strides) {
|
|
268
|
+
index += n;
|
|
269
|
+
if (dim > 1) {
|
|
270
|
+
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
|
271
|
+
} else {
|
|
272
|
+
offset = index * OffsetT(strides[0]);
|
|
273
|
+
}
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
OffsetT location() {
|
|
277
|
+
return offset;
|
|
278
|
+
}
|
|
279
|
+
};
|
|
280
|
+
|
|
281
|
+
template <typename OffsetT>
|
|
282
|
+
struct LoopedElemToLoc<1, OffsetT, false> {
|
|
283
|
+
OffsetT offset{0};
|
|
284
|
+
|
|
285
|
+
LoopedElemToLoc(int) {}
|
|
286
|
+
|
|
287
|
+
void next(const constant int*, const constant int64_t* strides) {
|
|
288
|
+
offset += OffsetT(strides[0]);
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
void next(int n, const constant int*, const constant int64_t* strides) {
|
|
292
|
+
offset += n * OffsetT(strides[0]);
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
OffsetT location() {
|
|
296
|
+
return offset;
|
|
297
|
+
}
|
|
298
|
+
};
|
|
299
|
+
|
|
300
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
301
|
+
// Calculation utils
|
|
302
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
303
|
+
|
|
304
|
+
/** Compute ceil((float)N/(float)M) using only integer arithmetic. */
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  // Adding M - 1 before the truncating divide rounds the quotient up.
  return (N + M - 1) / M;
}
|
|
309
|
+
|
|
310
|
+
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
|
311
|
+
inline float log1p(float x) {
|
|
312
|
+
float xp1 = 1.0f + x;
|
|
313
|
+
if (xp1 == Limits<float>::max) {
|
|
314
|
+
return Limits<float>::max;
|
|
315
|
+
}
|
|
316
|
+
if (xp1 == 1.0f) {
|
|
317
|
+
return x;
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
inline bfloat16_t log1p(bfloat16_t x) {
|
|
324
|
+
float xp1 = 1.0f + static_cast<float>(x);
|
|
325
|
+
if (xp1 == Limits<float>::max) {
|
|
326
|
+
return Limits<bfloat16_t>::max;
|
|
327
|
+
}
|
|
328
|
+
if (xp1 == 1.0f) {
|
|
329
|
+
return x;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
inline complex64_t log1p(complex64_t in) {
|
|
336
|
+
float x = in.real;
|
|
337
|
+
float y = in.imag;
|
|
338
|
+
float zabs = metal::precise::sqrt(x * x + y * y);
|
|
339
|
+
float theta = metal::atan2(y, x + 1);
|
|
340
|
+
if (zabs < 0.5f) {
|
|
341
|
+
float r = x * (2 + x) + y * y;
|
|
342
|
+
if (r == 0) { // handle underflow
|
|
343
|
+
return {x, theta};
|
|
344
|
+
}
|
|
345
|
+
return {0.5f * log1p(r), theta};
|
|
346
|
+
} else {
|
|
347
|
+
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
|
|
348
|
+
return {metal::log(z0), theta};
|
|
349
|
+
}
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
353
|
+
// SIMD shuffle ops
|
|
354
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
355
|
+
|
|
356
|
+
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
|
|
357
|
+
return as_type<uint64_t>(
|
|
358
|
+
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
|
362
|
+
return as_type<int64_t>(
|
|
363
|
+
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
|
367
|
+
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
|
|
371
|
+
return complex64_t(
|
|
372
|
+
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
|
|
376
|
+
return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
|
|
380
|
+
return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
inline bool simd_shuffle_up(bool data, uint16_t delta) {
|
|
384
|
+
return simd_shuffle_up(static_cast<uint32_t>(data), delta);
|
|
385
|
+
}
|
|
386
|
+
|
|
387
|
+
inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
|
|
388
|
+
return complex64_t(
|
|
389
|
+
simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
inline uint64_t
|
|
393
|
+
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
|
|
394
|
+
return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
|
|
395
|
+
as_type<uint2>(data), as_type<uint2>(filling), delta));
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
inline int64_t
|
|
399
|
+
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
|
|
400
|
+
return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
|
|
401
|
+
as_type<uint2>(data), as_type<uint2>(filling), delta));
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
|
|
405
|
+
return simd_shuffle_and_fill_up(
|
|
406
|
+
static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
inline complex64_t simd_shuffle_and_fill_up(
|
|
410
|
+
complex64_t data,
|
|
411
|
+
complex64_t filling,
|
|
412
|
+
uint16_t delta) {
|
|
413
|
+
return complex64_t(
|
|
414
|
+
simd_shuffle_and_fill_up(data.real, filling.real, delta),
|
|
415
|
+
simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
|
|
419
|
+
return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
|
|
423
|
+
return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
inline bool simd_shuffle(bool data, uint16_t lane) {
|
|
427
|
+
return simd_shuffle(static_cast<uint32_t>(data), lane);
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
|
|
431
|
+
return complex64_t(
|
|
432
|
+
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
// std::conditional is not included with Metal; this is a minimal stand-in.
// ConditionalType<cond, T, U>::type is T when cond is true, U otherwise.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = U;
};

template <typename T, typename U>
struct ConditionalType<true, T, U> {
  using type = T;
};
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/device.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
template <bool CHECK_AB = true>
|
|
10
|
+
void steel_matmul_regular_axpby(
|
|
11
|
+
const Stream& s,
|
|
12
|
+
metal::Device& d,
|
|
13
|
+
const array& a,
|
|
14
|
+
const array& b,
|
|
15
|
+
const array& c,
|
|
16
|
+
array& out,
|
|
17
|
+
int M,
|
|
18
|
+
int N,
|
|
19
|
+
int K,
|
|
20
|
+
int batch_size_out,
|
|
21
|
+
int lda,
|
|
22
|
+
int ldb,
|
|
23
|
+
int ldd,
|
|
24
|
+
bool transpose_a,
|
|
25
|
+
bool transpose_b,
|
|
26
|
+
std::vector<array>& copies,
|
|
27
|
+
Shape batch_shape,
|
|
28
|
+
Strides batch_strides,
|
|
29
|
+
int64_t A_batch_stride,
|
|
30
|
+
int64_t B_batch_stride,
|
|
31
|
+
int64_t matrix_stride_out,
|
|
32
|
+
int64_t C_batch_stride = 0,
|
|
33
|
+
float alpha = 1.0f,
|
|
34
|
+
float beta = 0.0f);
|
|
35
|
+
|
|
36
|
+
inline void steel_matmul_regular(
|
|
37
|
+
const Stream& s,
|
|
38
|
+
metal::Device& d,
|
|
39
|
+
const array& a,
|
|
40
|
+
const array& b,
|
|
41
|
+
array& out,
|
|
42
|
+
int M,
|
|
43
|
+
int N,
|
|
44
|
+
int K,
|
|
45
|
+
int batch_size_out,
|
|
46
|
+
int lda,
|
|
47
|
+
int ldb,
|
|
48
|
+
int ldd,
|
|
49
|
+
bool transpose_a,
|
|
50
|
+
bool transpose_b,
|
|
51
|
+
std::vector<array>& copies,
|
|
52
|
+
Shape batch_shape,
|
|
53
|
+
Strides batch_strides,
|
|
54
|
+
int64_t A_batch_stride,
|
|
55
|
+
int64_t B_batch_stride,
|
|
56
|
+
int64_t matrix_stride_out) {
|
|
57
|
+
return steel_matmul_regular_axpby<false>(
|
|
58
|
+
/* const Stream& s = */ s,
|
|
59
|
+
/* metal::Device& d = */ d,
|
|
60
|
+
/* const array& a = */ a,
|
|
61
|
+
/* const array& b = */ b,
|
|
62
|
+
/* const array& c = */ b,
|
|
63
|
+
/* array& out = */ out,
|
|
64
|
+
/* int M = */ M,
|
|
65
|
+
/* int N = */ N,
|
|
66
|
+
/* int K = */ K,
|
|
67
|
+
/* int batch_size_out = */ batch_size_out,
|
|
68
|
+
/* int lda = */ lda,
|
|
69
|
+
/* int ldb = */ ldb,
|
|
70
|
+
/* int ldd = */ ldd,
|
|
71
|
+
/* bool transpose_a = */ transpose_a,
|
|
72
|
+
/* bool transpose_b = */ transpose_b,
|
|
73
|
+
/* std::vector<array>& copies = */ copies,
|
|
74
|
+
/* Shape batch_shape = */ batch_shape,
|
|
75
|
+
/* Strides batch_strides = */ batch_strides,
|
|
76
|
+
/* int64_t A_batch_stride = */ A_batch_stride,
|
|
77
|
+
/* int64_t B_batch_stride = */ B_batch_stride,
|
|
78
|
+
/* int64_t matrix_stride_out = */ matrix_stride_out);
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
// General steel GEMM with optional axpby accumulation.
//
// Presumably computes out = alpha * matmul(a, b) + beta * c when CHECK_AB is
// true, and a plain matmul when it is false — TODO confirm against the
// implementation; only the declaration is visible here.
//
// `copies` collects temporary arrays created for layout fixes so their
// lifetime extends past kernel dispatch. Batch shape/stride parameters
// default to empty for the unbatched case.
template <bool CHECK_AB = true>
void steel_matmul_axpby(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape = {},
    Strides A_batch_stride = {},
    Strides B_batch_stride = {},
    Strides C_batch_stride = {},
    float alpha = 1.0f,
    float beta = 0.0f);
|
|
104
|
+
|
|
105
|
+
inline void steel_matmul(
|
|
106
|
+
const Stream& s,
|
|
107
|
+
metal::Device& d,
|
|
108
|
+
const array& a,
|
|
109
|
+
const array& b,
|
|
110
|
+
array& out,
|
|
111
|
+
int M,
|
|
112
|
+
int N,
|
|
113
|
+
int K,
|
|
114
|
+
int batch_size_out,
|
|
115
|
+
int lda,
|
|
116
|
+
int ldb,
|
|
117
|
+
bool transpose_a,
|
|
118
|
+
bool transpose_b,
|
|
119
|
+
std::vector<array>& copies,
|
|
120
|
+
Shape batch_shape = {},
|
|
121
|
+
Strides A_batch_stride = {},
|
|
122
|
+
Strides B_batch_stride = {}) {
|
|
123
|
+
return steel_matmul_axpby<false>(
|
|
124
|
+
/* const Stream& s = */ s,
|
|
125
|
+
/* metal::Device& d = */ d,
|
|
126
|
+
/* const array& a = */ a,
|
|
127
|
+
/* const array& b = */ b,
|
|
128
|
+
/* const array& c = */ b,
|
|
129
|
+
/* array& out = */ out,
|
|
130
|
+
/* int M = */ M,
|
|
131
|
+
/* int N = */ N,
|
|
132
|
+
/* int K = */ K,
|
|
133
|
+
/* int batch_size_out = */ batch_size_out,
|
|
134
|
+
/* int lda = */ lda,
|
|
135
|
+
/* int ldb = */ ldb,
|
|
136
|
+
/* bool transpose_a = */ transpose_a,
|
|
137
|
+
/* bool transpose_b = */ transpose_b,
|
|
138
|
+
/* std::vector<array>& copies = */ copies,
|
|
139
|
+
/* Shape batch_shape = */ batch_shape,
|
|
140
|
+
/* Strides A_batch_stride = */ A_batch_stride,
|
|
141
|
+
/* Strides B_batch_stride = */ B_batch_stride);
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <string>
#include <unordered_map>
#include <variant>

namespace mlx::core::metal {

/** Check if the Metal backend is available. */
bool is_available();

/**
 * Capture a GPU trace, saving it to an absolute file `path`.
 * An empty `path` uses a default location — TODO confirm against the
 * implementation; only the declaration is visible here.
 */
void start_capture(std::string path = "");

/** Stop a GPU trace previously started with `start_capture`. */
void stop_capture();

/**
 * Get information about the GPU and system settings.
 * Returns a reference to a map of attribute name to a string or size_t value.
 */
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info();

} // namespace mlx::core::metal
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/stream.h"

namespace mlx::core {

using metal::CommandEncoder;

// Dispatch a full reduction of `in` into a single output value in `out`.
// `op_name` selects the reduction kernel (e.g. which operator to apply —
// exact naming convention lives in the implementation).
void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);

// Dispatch a reduction over contiguous rows, as described by `plan` and the
// reduced `axes`.
void row_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);

// Dispatch a reduction over strided (non-contiguous) dimensions, as described
// by `plan` and the reduced `axes`.
void strided_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);

} // namespace mlx::core
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.

#pragma once

#include <unordered_set> // for unwired_set_; do not rely on a transitive include

#include "mlx/backend/metal/device.h"

namespace mlx::core::metal {

// Wrapper around an MTL::ResidencySet that keeps GPU allocations resident.
// Allocations beyond the wired capacity are tracked in `unwired_set_`.
// Non-copyable: owns the underlying Metal residency set.
class ResidencySet {
 public:
  ResidencySet(MTL::Device* d);
  ~ResidencySet();

  ResidencySet(const ResidencySet&) = delete;
  ResidencySet& operator=(const ResidencySet&) = delete;

  // Raw accessor for the underlying Metal residency set (may be null if
  // residency sets are unsupported — confirm against the implementation).
  const MTL::ResidencySet* mtl_residency_set() {
    return wired_set_;
  }

  // Track `buf` in the set; erase removes it again.
  void insert(MTL::Allocation* buf);
  void erase(MTL::Allocation* buf);

  // Adjust the wired capacity of the set.
  void resize(size_t size);

 private:
  MTL::ResidencySet* wired_set_{nullptr};
  std::unordered_set<const MTL::Allocation*> unwired_set_;
  size_t capacity_{0};
};

} // namespace mlx::core::metal
|