mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
6
|
+
// Attn param classes
|
|
7
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
8
|
+
|
|
9
|
+
namespace mlx {
namespace steel {

/// Launch parameters shared by the steel attention kernels.
///
/// NOTE(review): field order and types determine this struct's byte layout,
/// which is presumably mirrored by the host code that fills it in; do not
/// reorder or retype fields without updating that side as well.
struct AttnParams {
  // ---- Problem shape -------------------------------------------------------
  int B; ///< Batch size
  int H; ///< Heads
  int D; ///< Head dimension

  int qL; ///< Query sequence length
  int kL; ///< Key sequence length

  int gqa_factor; ///< Group-query factor
  float scale; ///< Attention scale

  // ---- Block tiling --------------------------------------------------------
  int NQ; ///< Number of query blocks
  int NK; ///< Number of key/value blocks

  int NQ_aligned; ///< Number of full (non-partial) query blocks
  int NK_aligned; ///< Number of full (non-partial) key/value blocks

  int qL_rem; ///< Remainder in the last query block
  int kL_rem; ///< Remainder in the last key/value block
  int qL_off; ///< Offset of the query sequence start

  // ---- Memory strides over (B, H, L); the innermost D stride is 1 ----------
  int64_t Q_strides[3]; ///< Query strides
  int64_t K_strides[3]; ///< Key strides
  int64_t V_strides[3]; ///< Value strides
  int64_t O_strides[3]; ///< Output strides
};

/// Strides of the attention mask over (B, H, qL); the innermost kL stride is 1.
struct AttnMaskParams {
  int64_t M_strides[3]; ///< Mask strides
};

} // namespace steel
} // namespace mlx
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
6
|
+
|
|
7
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
8
|
+
// Transforms and Epilogues
|
|
9
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
10
|
+
|
|
11
|
+
namespace mlx {
|
|
12
|
+
namespace steel {
|
|
13
|
+
|
|
14
|
+
template <typename OutT, typename InT>
struct TransformNone {
  // Two-argument form: the existing output value is ignored and the result
  // is just the converted input.
  static METAL_FUNC OutT apply(InT x, OutT) {
    return apply(x);
  }

  // Convert the input value to the output type without any other change.
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};
|
|
24
|
+
|
|
25
|
+
template <typename OutT, typename InT>
struct TransformAdd {
  // Accepts (alpha, beta) for interface parity with TransformAxpby; both
  // values are ignored.
  TransformAdd(const float, const float) {}

  // Single-argument form: plain conversion to the output type.
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  // Two-argument form: converted input plus the existing output value `c`.
  static METAL_FUNC OutT apply(InT x, OutT c) {
    return apply(x) + c;
  }
};
|
|
37
|
+
|
|
38
|
+
template <typename OutT, typename InT>
struct TransformAxpby {
  // Scale applied to the input value.
  const float alpha;
  // Scale applied to the existing output value.
  const float beta;

  TransformAxpby(const float a, const float b) : alpha(a), beta(b) {}

  // Single-argument form: no scaling, plain conversion to the output type.
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  // Two-argument form: alpha * x + beta * c, converted to the output type.
  // The expression is kept as one statement so rounding is unchanged.
  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(x * alpha + (beta * c));
  }
};
|
|
54
|
+
|
|
55
|
+
template <typename T>
struct AccumHelper {
  // Accumulation type; always float, whatever the element type T is.
  using accum_type = float;
};
|
|
59
|
+
|
|
60
|
+
struct BlockSwizzle {
  // Remap a threadgroup grid position into (x, y) tile coordinates.
  // The low `swizzle_log` bits of tid.x are moved into y; the remaining bits
  // of tid.x become x. The result is returned as int2(x, y).
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int low_bits_mask = (1 << swizzle_log) - 1;

    const int swizzled_y = ((tid.y) << swizzle_log) + ((tid.x) & low_bits_mask);
    const int swizzled_x = (tid.x) >> swizzle_log;

    return int2(swizzled_x, swizzled_y);
  }
};
|
|
69
|
+
|
|
70
|
+
} // namespace steel
|
|
71
|
+
} // namespace mlx
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
6
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
7
|
+
|
|
8
|
+
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
|
|
9
|
+
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
|
10
|
+
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
|
11
|
+
|
|
12
|
+
using namespace metal;
|
|
13
|
+
using namespace mlx::steel;
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <metal_stdlib>
|
|
4
|
+
|
|
5
|
+
using namespace metal;
|
|
6
|
+
|
|
7
|
+
// Implicit-GEMM 2D convolution kernel.
//
// Each threadgroup computes one BM x BN tile of the output C from A (input,
// gathered by a Conv2DInputBlockLoader*) and B (weights, read transposed,
// transpose_b = true) over gemm_params->gemm_k_iterations blocks of BK.
// tid.z selects the convolution group.
//
// Template parameters:
//   N_CHANNELS   - when in (0, 4], use the small-channel loaders.
//   SMALL_FILTER - otherwise, choose the small- vs large-filter input loader.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    int N_CHANNELS = 0,
    bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using namespace mlx::steel;

  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  // Pad each threadgroup row by 16 bytes worth of elements.
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DInputBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_a>,

      // Else go to general loader
      typename metal::conditional_t<
          // Check if filter size is small enough
          SMALL_FILTER,

          // Go to small filter specialization
          Conv2DInputBlockLoaderSmallFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>,

          // Else go to large filter generalization
          Conv2DInputBlockLoaderLargeFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>>>;

  // Weight loader
  using loader_b_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DWeightBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_b>,

      // Else go to general loader
      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  // Swizzled tile coordinates (same mapping as BlockSwizzle::swizzle).
  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  // Whole threadgroup exits together (tid is uniform), before any barrier.
  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;
  const int N = gemm_params->N;
  const int C_per_group = params->C / params->groups;
  // Leading dimension of C: each output row holds N columns per group.
  const int ldc = N * params->groups;

  // Groups
  A += tid.z * C_per_group;
  B += tid.z * N * K;
  C += tid.z * N;

  B += c_col * K;
  // c_row * ldc counts every output element before this tile's first row and
  // can exceed INT_MAX for large outputs, so compute it in 64-bit.
  C += int64_t(c_row) * ldc + c_col;

  const int2 offsets_a(0, c_row);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
  loader_b_t loader_b(
      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations = gemm_params->gemm_k_iterations;
  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory, clipping the tile at the M / N edges.
  short tgp_bm = min(BM, gemm_params->M - c_row);
  short tgp_bn = min(BN, gemm_params->N - c_col);
  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
|
4
|
+
|
|
5
|
+
// Function constant (index 200) fixed when the pipeline is specialized.
// When true, implicit_gemm_conv_2d_general loads every K block with
// load_unsafe. When false, it finishes with a load_safe pass bounded by
// params->C % BK. Presumably the host sets it when C is a multiple of BK —
// TODO confirm against the dispatch code.
constant bool align_C [[function_constant(200)]];
|
|
6
|
+
|
|
7
|
+
// General implicit-GEMM 2D convolution kernel.
//
// Each threadgroup computes one BM x BN tile of the output. tid.z selects a
// (base_oh, base_ow) pair via jump_params->f_out_jump_w. base_h / base_w
// supply the weight offsets and sizes for that pair. Results are scattered
// to C through params->out_strides.
//
// The K loop has two modes, selected by the `align_C` function constant:
// all-unsafe loads, or unsafe loads followed by a final load_safe pass that
// handles the params->C % BK tail.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    typename AccumType = float,
    typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  // Pad each threadgroup row by 16 bytes worth of elements.
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t =
      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;

  // Weight loader
  using loader_b_t =
      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  // Swizzled tile coordinates (same mapping as BlockSwizzle::swizzle).
  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  // Whole threadgroup exits together (tid is uniform), before any barrier.
  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int tid_z = tid.z;

  const int base_oh = tid_z / jump_params->f_out_jump_w;
  const int base_ow = tid_z % jump_params->f_out_jump_w;

  const int base_wh = base_h[base_oh].weight_base;
  const int base_ww = base_w[base_ow].weight_base;

  const int base_wh_size = base_h[base_oh].weight_size;
  const int base_ww_size = base_w[base_ow].weight_size;

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;

  B += c_col * K;

  const int4 offsets_a(0, c_row, base_oh, base_ow);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A,
      As,
      offsets_a,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);
  loader_b_t loader_b(
      B,
      Bs,
      offsets_b,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  if (align_C) {
    int gemm_k_iterations =
        base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
  } else {
    // All but the last K block: unchecked loads.
    for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
      for (int j = 0; j < base_wh_size * base_ww_size; j++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
    }
    // Last K block: bounds-checked loads limited to the C % BK tail.
    const short remaining_k = params->C % BK;
    for (int j = 0; j < base_wh_size * base_ww_size; j++) {
      // Load elements into threadgroup
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_a.load_safe(remaining_k);
      loader_b.load_safe(remaining_k);
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);
      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  {
    // Adjust for simdgroup and thread location
    int offset_m = c_row + mma_op.sm;
    int offset_n = c_col + mma_op.sn;
    C += offset_n;

    if (offset_n >= gemm_params->N)
      return;

    // Number of valid output columns from offset_n onward. Kept as int: a
    // short truncates (possibly to a negative value) once N - offset_n
    // exceeds 32767, which would silently drop stores.
    const int diff = gemm_params->N - offset_n;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < mma_t::TM; i++) {
      int cm = offset_m + i * mma_t::TM_stride;

      int n = cm / jump_params->adj_out_hw;
      int hw = cm % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
        // Output element offset, computed in 64-bit so large outputs do not
        // overflow or truncate.
        const int64_t offset_cm = int64_t(n) * params->out_strides[0] +
            int64_t(oh) * params->out_strides[1] +
            int64_t(ow) * params->out_strides[2];

        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < mma_t::TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = mma_op.Ctile.frag_at(i, j);
          const int64_t offset = offset_cm + (j * mma_t::TN_stride);

          constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;

          // Apply epilogue and output C, skipping columns past N
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * mma_t::TN_stride + k) < diff) {
              C[offset + k] = Epilogue::apply(accum[k]);
            }
          }
        }
      }
    }
  }
}
|