mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
template <int NDIM>
// Parameters describing an NDIM-dimensional convolution. Plain aggregate
// shared between host setup code and kernels; all fields are immutable
// once constructed.
struct MLXConvParams {
  const int N; // Batch size
  const int C; // In channels
  const int O; // Out channels
  const int iS[NDIM]; // Input spatial dim
  const int wS[NDIM]; // Weight spatial dim
  const int oS[NDIM]; // Output spatial dim
  const int str[NDIM]; // Kernel strides
  const int pad[NDIM]; // Input padding
  const int kdil[NDIM]; // Kernel dilation
  const int idil[NDIM]; // Input dilation
  const int64_t in_strides[NDIM + 2]; // In strides (batch + NDIM spatial + channel)
  const int64_t wt_strides[NDIM + 2]; // Wt strides
  const int64_t out_strides[NDIM + 2]; // Out strides
  const int groups; // Input channel groups
  const bool flip; // If true, kernel is flipped (true convolution vs cross-correlation) — TODO confirm against kernel usage
};
|
|
23
|
+
|
|
24
|
+
namespace mlx {
|
|
25
|
+
namespace steel {
|
|
26
|
+
|
|
27
|
+
// Parameters for the implicit-GEMM formulation of 2D convolution:
// the conv is executed as a GEMM of logical shape (M x K) * (K x N)
// without materializing the im2col matrix.
struct ImplicitGemmConv2DParams {
  const int M; // GEMM M dimension
  const int N; // GEMM N dimension
  const int K; // GEMM K (reduction) dimension

  const int gemm_k_iterations; // Number of K-block iterations per threadgroup

  // Pointer advance amounts for walking the input while traversing the
  // implicit K dimension (per filter column, row, and channel step) —
  // presumably precomputed on host; verify against the conv loaders.
  const int inp_jump_w;
  const int inp_jump_h;
  const int inp_jump_c;

  // Threadgroup tile grid and swizzle factor (log2) used to remap
  // threadgroup ids for better cache locality.
  const int tiles_n;
  const int tiles_m;
  const int swizzle_log;
};
|
|
42
|
+
|
|
43
|
+
// Jump/adjustment parameters for the "general" 2D convolution kernel
// (arbitrary stride/dilation). Values are host-precomputed offsets —
// NOTE(review): exact semantics depend on loader_general.h; confirm there.
struct Conv2DGeneralJumpParams {
  const int f_wgt_jump_h; // Weight pointer jump per filter-row step
  const int f_wgt_jump_w; // Weight pointer jump per filter-column step

  const int f_out_jump_h; // Output-space jump per filter-row step
  const int f_out_jump_w; // Output-space jump per filter-column step

  const int adj_out_h; // Adjusted output height
  const int adj_out_w; // Adjusted output width
  const int adj_out_hw; // Adjusted output height * width
  const int adj_implicit_m; // Adjusted implicit-GEMM M extent
};
|
|
55
|
+
|
|
56
|
+
// Base offset and extent of the weight region touched by one tile of the
// general 2D convolution kernel.
struct Conv2DGeneralBaseInfo {
  int weight_base; // Starting offset into the weight tensor
  int weight_size; // Number of weight elements covered
};
|
|
60
|
+
|
|
61
|
+
} // namespace steel
|
|
62
|
+
} // namespace mlx
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
|
|
6
|
+
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
|
7
|
+
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
|
8
|
+
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
|
9
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
10
|
+
|
|
11
|
+
using namespace metal;
|
|
12
|
+
|
|
13
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
14
|
+
// GEMM kernel class
|
|
15
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
16
|
+
|
|
17
|
+
namespace mlx {
|
|
18
|
+
namespace steel {
|
|
19
|
+
|
|
20
|
+
// Empty tag type that carries tile-alignment flags as template parameters.
// Passed (default-constructed) to GEMMKernel::gemm_loop purely to select the
// right instantiation; it has no runtime state.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
|
|
22
|
+
|
|
23
|
+
// Tiled GEMM kernel: each threadgroup computes a BM x BN tile of
// D = op(A) * op(B) by stepping over K in blocks of BK, staging A/B tiles
// in threadgroup memory and accumulating with simdgroup MMA ops.
//
// Template parameters:
//   T / U         — input / output element types
//   BM, BN, BK    — threadgroup tile sizes
//   WM, WN        — simdgroup grid within the threadgroup
//   transpose_a/b — whether A / B are stored transposed
//   MN_aligned    — M and N are exact multiples of BM / BN (no edge tiles)
//   K_aligned     — K is an exact multiple of BK (no K remainder)
//   AccumType     — accumulator element type
//   Epilogue      — transform applied when storing results
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
  // Pad the leading dimension of each threadgroup tile by 16 bytes worth of
  // elements (bank-conflict avoidance — assumed; standard for these tiles).
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
  // Threadgroup memory footprint (elements) for each staged tile; the padded
  // dimension depends on the storage orientation.
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // Total threads per threadgroup: WM x WN simdgroups of 32 threads each.
  STEEL_CONST short tgp_size = WM * WN * 32;

  // Cooperative loaders that stage A and B tiles into threadgroup memory;
  // rows/cols and padded leading dim depend on transposition.
  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;
  // Simdgroup matrix-multiply-accumulate operator over the staged tiles.
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Inner K loop: stage tiles, multiply-accumulate, advance loaders.
   * M_aligned / N_aligned select unsafe (full-tile) vs bounds-checked loads;
   * K_aligned_ == false appends one bounds-checked tail iteration of lbk
   * columns. tgp_bm / tgp_bn are the valid extents of this edge tile. */
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler
    (void)l;

    // Valid (cols, rows) extents of the staged tiles for bounds-checked loads.
    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      // Wait until the previous iteration's MMA has consumed As/Bs.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      // Ensure the tiles are fully staged before any simdgroup reads them.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // K remainder: one final bounds-checked block of lbk columns.
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Swizzle threadgroup ids: interleave 2^swizzle_log consecutive x-tiles
    // along y to improve L2 locality of neighboring tiles.
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    // Swizzled id may fall outside the tile grid; such threadgroups exit.
    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;
    const size_t c_row_long = size_t(c_row);
    const size_t c_col_long = size_t(c_col);

    // Advance base pointers to this threadgroup's tile. The non-leading
    // coordinate is scaled by the leading dimension (lda/ldb/ldd).
    A += transpose_a ? c_row_long : c_row_long * params->lda;
    B += transpose_b ? c_col_long * params->ldb : c_col_long;
    D += c_row_long * params->ldd + c_col_long;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////////
    // MNK aligned loop: every tile is full, so all loads are unchecked.
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Loop tail
      if (!K_aligned) {
        // lbk = leftover columns beyond the last full BK block.
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////////
    // MN unaligned loop: edge tiles may be partial, so dispatch to the
    // gemm_loop instantiation matching this tile's alignment and use
    // bounds-checked stores where needed.
    else { // Loop over K - unaligned case
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      if (tgp_bm == BM && tgp_bn == BN) {
        // Interior tile: full in both M and N.
        gemm_loop<true, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        // Partial in M only.
        gemm_loop<false, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        // Partial in N only.
        gemm_loop<true, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        // Corner tile: partial in both M and N.
        gemm_loop<false, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};
|
|
293
|
+
|
|
294
|
+
} // namespace steel
|
|
295
|
+
} // namespace mlx
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
|
|
6
|
+
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
|
7
|
+
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
|
8
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
9
|
+
|
|
10
|
+
using namespace metal;
|
|
11
|
+
|
|
12
|
+
namespace mlx::steel {
|
|
13
|
+
|
|
14
|
+
// NAX-tile GEMM inner loop: accumulates D = op(A) * op(B) for one
// simdgroup's SM x SN output tile and returns the accumulator tile.
//
// Template parameters:
//   T              — input element type
//   SM, SN, SK     — simdgroup tile extents (M, N, K)
//   BK             — threadgroup K-block size (stepped in SK chunks)
//   transpose_a/b  — storage orientation of A / B
//   kAlignedM/N/K  — whether M / N / K need no bounds checking
//   UM, UN, UK     — NAX subtile extents
//   AccumType      — accumulator element type
//
// sgp_sm / sgp_sn are this simdgroup's valid M / N extents (only consulted
// when the corresponding kAligned* flag is false).
template <
    typename T,
    short SM,
    short SN,
    short SK,
    short BK,
    bool transpose_a,
    bool transpose_b,
    bool kAlignedM,
    bool kAlignedN,
    bool kAlignedK,
    short UM,
    short UN,
    short UK,
    typename AccumType = float>
auto gemm_loop(
    const device T* A,
    const device T* B,
    const constant GEMMParams* params [[buffer(4)]],
    const short sgp_sm,
    const short sgp_sn) {
  // Subtile grid dimensions within the simdgroup tile.
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;
  constexpr short TK = SK / UK;

  // Row/column subtile counts of the A tile as stored (swap if transposed).
  constexpr int RA = transpose_a ? TK : TM;
  constexpr int CA = transpose_a ? TM : TK;

  constexpr int RB = transpose_b ? TN : TK;
  constexpr int CB = transpose_b ? TK : TN;

  using DSubTile = NAXSubTile<AccumType, UM, UN>;
  using ASubTile =
      NAXSubTile<T, (transpose_a ? UK : UM), (transpose_a ? UM : UK)>;
  using BSubTile =
      NAXSubTile<T, (transpose_b ? UN : UK), (transpose_b ? UK : UN)>;

  // Accumulator tile, zero-initialized.
  NAXTile<AccumType, TM, TN, DSubTile> Dtile;
  Dtile.clear();

  int gemm_k_iterations_ = params->gemm_k_iterations_aligned;

  // Main loop over full BK blocks of K.
  STEEL_PRAGMA_NO_UNROLL
  for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
    threadgroup_barrier(mem_flags::mem_none);

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < BK; kk1 += SK) {
      NAXTile<T, RA, CA, ASubTile> Atile;
      NAXTile<T, RB, CB, BTile_unused_name_guard /* see below */> Btile;
      const int k = kk1;

      // Unused volatile local acts as a compiler scheduling barrier around
      // the load + matmad sequence (referenced after the matmad below).
      volatile int compiler_barrier;

      // Offset of the k-th column block; the non-leading coordinate is
      // scaled by the leading dimension.
      const int A_offset = transpose_a ? k * params->lda : k;
      const int B_offset = transpose_b ? k : k * params->ldb;

      if constexpr (kAlignedM) {
        Atile.load(A + A_offset, params->lda);
      } else {
        // Bounds-checked load clipped to this simdgroup's valid M extent.
        const short rmax = transpose_a ? SK : sgp_sm;
        const short cmax = transpose_a ? sgp_sm : SK;
        Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
      }

      if constexpr (kAlignedN) {
        Btile.load(B + B_offset, params->ldb);
      } else {
        const short rmax = transpose_b ? sgp_sn : SK;
        const short cmax = transpose_b ? SK : sgp_sn;
        Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
      }

      // D += op(A) * op(B) over all subtiles.
      tile_matmad_nax(
          Dtile,
          Atile,
          metal::bool_constant<transpose_a>{},
          Btile,
          metal::bool_constant<transpose_b>{});

      (void)compiler_barrier;
    }

    // Advance A and B past the consumed BK block.
    A += transpose_a ? (BK * params->lda) : BK;
    B += transpose_b ? BK : (BK * params->ldb);
  }

  // K remainder: process the trailing rem_bk columns one subtile at a time
  // with fully bounds-checked loads.
  if constexpr (!kAlignedK) {
    simdgroup_barrier(mem_flags::mem_none);

    const short rem_bk = params->K - gemm_k_iterations_ * BK;

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
      NAXTile<T, 1, 1, ASubTile> Atile;
      NAXTile<T, 1, 1, BSubTile> Btile;

      STEEL_PRAGMA_UNROLL
      for (int mm = 0; mm < TM; mm++) {
        STEEL_PRAGMA_UNROLL
        for (int nn = 0; nn < TN; nn++) {
          STEEL_PRAGMA_UNROLL
          for (int kk = 0; kk < TK; kk++) {
            const int m = mm * UM;
            const int n = nn * UN;
            const int k = kk1 + kk * UK;
            // Remaining valid K extent for this subtile (0 if none).
            const short psk = max(0, rem_bk - k);

            const int A_offset =
                transpose_a ? (m + k * params->lda) : (m * params->lda + k);
            const int B_offset =
                transpose_b ? (k + n * params->ldb) : (k * params->ldb + n);

            {
              // Valid M extent for this subtile.
              const short psm = kAlignedM ? SM : max(0, sgp_sm - m);
              const short rmax = transpose_a ? psk : psm;
              const short cmax = transpose_a ? psm : psk;
              Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
            }

            {
              // Valid N extent for this subtile.
              const short psn = kAlignedN ? SN : max(0, sgp_sn - n);
              const short rmax = transpose_b ? psn : psk;
              const short cmax = transpose_b ? psk : psn;
              Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
            }

            // Accumulate into the (mm, nn) output subtile.
            subtile_matmad_nax(
                Dtile.subtile_at(mm, nn),
                Atile.subtile_at(0, 0),
                metal::bool_constant<transpose_a>{},
                Btile.subtile_at(0, 0),
                metal::bool_constant<transpose_b>{});
          }
        }
      }
    }
  }

  return Dtile;
}
|
|
155
|
+
|
|
156
|
+
} // namespace mlx::steel
|