mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
using namespace mlx::steel;
|
|
4
|
+
|
|
5
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
6
|
+
// GEMM kernels
|
|
7
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
8
|
+
|
|
9
|
+
// Pipeline-specialization switches (Metal function constants) read by the
// gemm kernel below.
// has_batch: when true the kernel derives A/B/C batch offsets from
// batch_shape/batch_strides; otherwise it uses the single batch strides
// stored in the params structs.
constant bool has_batch [[function_constant(10)]];
|
|
10
|
+
|
|
11
|
+
// use_out_source: gates the C and addmm_params kernel arguments and the
// epilogue that combines C with the accumulated product.
constant bool use_out_source [[function_constant(100)]];
|
|
12
|
+
// do_axpby: selects the axpby epilogue transform instead of the add one.
constant bool do_axpby [[function_constant(110)]];
|
|
13
|
+
|
|
14
|
+
// align_M / align_N / align_K: when true the kernel skips the corresponding
// bounds handling (full BM / BN tiles, no K remainder pass).
constant bool align_M [[function_constant(200)]];
|
|
15
|
+
constant bool align_N [[function_constant(201)]];
|
|
16
|
+
constant bool align_K [[function_constant(202)]];
|
|
17
|
+
|
|
18
|
+
// clang-format off
|
|
19
|
+
// Tiled GEMM kernel. Each threadgroup produces the BM x BN tile of D whose
// tile coordinates are recovered from the (swizzled) threadgroup position.
//   A, B          input matrices (transpose_a / transpose_b select the layout
//                 used when offsetting into them)
//   C             optional epilogue source, bound only when use_out_source
//   D             output matrix
//   params        sizes, leading dimensions, tile counts and batch strides
//   addmm_params  ldc/fdc, batch stride and alpha/beta for the epilogue,
//                 bound only when use_out_source
//   batch_shape / batch_strides
//                 bound only when has_batch; the strides for A, B and C are
//                 stored back-to-back, params->batch_ndim entries each
template <
|
|
20
|
+
typename T,
|
|
21
|
+
int BM,
|
|
22
|
+
int BN,
|
|
23
|
+
int BK,
|
|
24
|
+
int WM,
|
|
25
|
+
int WN,
|
|
26
|
+
bool transpose_a,
|
|
27
|
+
bool transpose_b,
|
|
28
|
+
typename AccumType = float>
|
|
29
|
+
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
|
30
|
+
const device T* A [[buffer(0)]],
|
|
31
|
+
const device T* B [[buffer(1)]],
|
|
32
|
+
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
|
33
|
+
device T* D [[buffer(3)]],
|
|
34
|
+
const constant GEMMParams* params [[buffer(4)]],
|
|
35
|
+
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
|
36
|
+
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
|
|
37
|
+
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
|
|
38
|
+
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
39
|
+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
40
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
41
|
+
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
|
42
|
+
// Pacifying compiler
|
|
43
|
+
(void)lid;
|
|
44
|
+
|
|
45
|
+
using gemm_kernel = GEMMKernel<
|
|
46
|
+
T,
|
|
47
|
+
T,
|
|
48
|
+
BM,
|
|
49
|
+
BN,
|
|
50
|
+
BK,
|
|
51
|
+
WM,
|
|
52
|
+
WN,
|
|
53
|
+
transpose_a,
|
|
54
|
+
transpose_b,
|
|
55
|
+
true,
|
|
56
|
+
true,
|
|
57
|
+
AccumType>;
|
|
58
|
+
|
|
59
|
+
using loader_a_t = typename gemm_kernel::loader_a_t;
|
|
60
|
+
using loader_b_t = typename gemm_kernel::loader_b_t;
|
|
61
|
+
using mma_t = typename gemm_kernel::mma_t;
|
|
62
|
+
|
|
63
|
+
// Find block
// The low swizzle_log bits of tid.x are folded into the row-tile index and
// the remaining bits of tid.x become the column-tile index.
const int tid_y = ((tid.y) << params->swizzle_log) +
|
|
65
|
+
((tid.x) & ((1 << params->swizzle_log) - 1));
|
|
66
|
+
const int tid_x = (tid.x) >> params->swizzle_log;
|
|
67
|
+
|
|
68
|
+
// Exit early if out of bounds
|
|
69
|
+
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
70
|
+
return;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
// Adjust for batch
|
|
74
|
+
if (has_batch) {
|
|
75
|
+
const constant auto* A_bstrides = batch_strides;
|
|
76
|
+
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
|
77
|
+
|
|
78
|
+
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
79
|
+
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
|
80
|
+
|
|
81
|
+
A += batch_offsets.x;
|
|
82
|
+
B += batch_offsets.y;
|
|
83
|
+
|
|
84
|
+
if (use_out_source) {
|
|
85
|
+
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
|
86
|
+
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
|
87
|
+
}
|
|
88
|
+
} else {
|
|
89
|
+
A += params->batch_stride_a * tid.z;
|
|
90
|
+
B += params->batch_stride_b * tid.z;
|
|
91
|
+
|
|
92
|
+
if (use_out_source) {
|
|
93
|
+
C += addmm_params->batch_stride_c * tid.z;
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
// D always uses the single batch stride, in both batch modes.
D += params->batch_stride_d * tid.z;
|
|
98
|
+
|
|
99
|
+
// Prepare threadgroup memory
|
|
100
|
+
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
101
|
+
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
102
|
+
|
|
103
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
104
|
+
|
|
105
|
+
// Find block in A, B, C
|
|
106
|
+
const int c_row = tid_y * BM;
|
|
107
|
+
const int c_col = tid_x * BN;
|
|
108
|
+
// Widen to size_t before multiplying by the leading dimensions below.
const size_t c_row_long = size_t(c_row);
|
|
109
|
+
const size_t c_col_long = size_t(c_col);
|
|
110
|
+
|
|
111
|
+
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
|
112
|
+
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
|
113
|
+
D += c_row_long * params->ldd + c_col_long;
|
|
114
|
+
|
|
115
|
+
if (use_out_source) {
|
|
116
|
+
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
// Prepare threadgroup mma operation
|
|
120
|
+
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
121
|
+
|
|
122
|
+
// Prepare threadgroup loading operations
|
|
123
|
+
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
124
|
+
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
125
|
+
|
|
126
|
+
// Prepare threadgroup bounds
// The remaining extent is clamped to the tile size in int and only then
// narrowed to short.
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
|
128
|
+
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
|
129
|
+
|
|
130
|
+
// Prepare iterations
|
|
131
|
+
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
132
|
+
|
|
133
|
+
// Do unaligned K iterations first
// The K remainder (K - gemm_k_iterations_aligned * BK) is loaded with
// bounds-checked loads and accumulated once; the loader sources are then
// moved back so the aligned iterations below start from the beginning of K.
if (!align_K) {
|
|
135
|
+
const int k_last = params->gemm_k_iterations_aligned * BK;
|
|
136
|
+
const int k_remain = params->K - k_last;
|
|
137
|
+
const size_t k_jump_a =
|
|
138
|
+
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
|
139
|
+
const size_t k_jump_b =
|
|
140
|
+
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
|
141
|
+
|
|
142
|
+
// Move loader source ahead to end
|
|
143
|
+
loader_a.src += k_jump_a;
|
|
144
|
+
loader_b.src += k_jump_b;
|
|
145
|
+
|
|
146
|
+
// Load tile
|
|
147
|
+
const short2 tile_dims_A =
|
|
148
|
+
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
|
149
|
+
const short2 tile_dims_B =
|
|
150
|
+
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
|
151
|
+
|
|
152
|
+
loader_a.load_safe(tile_dims_A);
|
|
153
|
+
loader_b.load_safe(tile_dims_B);
|
|
154
|
+
|
|
155
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
156
|
+
|
|
157
|
+
// Do matmul
|
|
158
|
+
mma_op.mma(As, Bs);
|
|
159
|
+
|
|
160
|
+
// Reset source back to start
|
|
161
|
+
loader_a.src -= k_jump_a;
|
|
162
|
+
loader_b.src -= k_jump_b;
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
// NOTE(review): addmm_params is gated by use_out_source in the signature but
// is dereferenced here unconditionally; the two transforms are only used
// inside `if (use_out_source)` below - confirm this is safe when the buffer
// is not bound.
const TransformAdd<AccumType, AccumType> epilogue_op_add(
|
|
166
|
+
addmm_params->alpha, addmm_params->beta);
|
|
167
|
+
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
|
|
168
|
+
addmm_params->alpha, addmm_params->beta);
|
|
169
|
+
|
|
170
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
171
|
+
// MNK aligned loop
|
|
172
|
+
if (align_M && align_N) {
|
|
173
|
+
// Do gemm
|
|
174
|
+
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
175
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
176
|
+
// Load elements into threadgroup
|
|
177
|
+
loader_a.load_unsafe();
|
|
178
|
+
loader_b.load_unsafe();
|
|
179
|
+
|
|
180
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
181
|
+
|
|
182
|
+
// Multiply and accumulate threadgroup elements
|
|
183
|
+
mma_op.mma(As, Bs);
|
|
184
|
+
|
|
185
|
+
// Prepare for next iteration
|
|
186
|
+
loader_a.next();
|
|
187
|
+
loader_b.next();
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
191
|
+
|
|
192
|
+
// Do epilogue
|
|
193
|
+
if (use_out_source) {
|
|
194
|
+
if (do_axpby) {
|
|
195
|
+
mma_op.apply_epilogue(
|
|
196
|
+
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
|
197
|
+
} else {
|
|
198
|
+
mma_op.apply_epilogue(
|
|
199
|
+
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
|
200
|
+
}
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
// Store results to device memory
|
|
204
|
+
return mma_op.store_result(D, params->ldd);
|
|
205
|
+
|
|
206
|
+
}
|
|
207
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
208
|
+
// MN unaligned loop
|
|
209
|
+
else { // Loop over K - unaligned case
|
|
210
|
+
// No K leftover is passed to gemm_loop: any remainder was consumed by the
// unaligned-K pass above, and every LoopAlignment below is K-aligned.
const int leftover_bk = 0;
|
|
211
|
+
|
|
212
|
+
// This threadgroup's tile is full in both M and N: unchecked epilogue/store.
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
|
213
|
+
// Do gemm
|
|
214
|
+
gemm_kernel::gemm_loop(
|
|
215
|
+
As,
|
|
216
|
+
Bs,
|
|
217
|
+
gemm_k_iterations,
|
|
218
|
+
loader_a,
|
|
219
|
+
loader_b,
|
|
220
|
+
mma_op,
|
|
221
|
+
tgp_bm,
|
|
222
|
+
tgp_bn,
|
|
223
|
+
leftover_bk,
|
|
224
|
+
LoopAlignment<true, true, true>{});
|
|
225
|
+
|
|
226
|
+
// Do epilogue
|
|
227
|
+
if (use_out_source) {
|
|
228
|
+
if (do_axpby) {
|
|
229
|
+
mma_op.apply_epilogue(
|
|
230
|
+
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
|
231
|
+
} else {
|
|
232
|
+
mma_op.apply_epilogue(
|
|
233
|
+
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
|
234
|
+
}
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
// Store results to device memory
|
|
238
|
+
return mma_op.store_result(D, params->ldd);
|
|
239
|
+
|
|
240
|
+
// Full in N only (partial in M): bounds-checked epilogue/store.
} else if (align_N || tgp_bn == BN) {
|
|
241
|
+
gemm_kernel::gemm_loop(
|
|
242
|
+
As,
|
|
243
|
+
Bs,
|
|
244
|
+
gemm_k_iterations,
|
|
245
|
+
loader_a,
|
|
246
|
+
loader_b,
|
|
247
|
+
mma_op,
|
|
248
|
+
tgp_bm,
|
|
249
|
+
tgp_bn,
|
|
250
|
+
leftover_bk,
|
|
251
|
+
LoopAlignment<false, true, true>{});
|
|
252
|
+
|
|
253
|
+
// Do epilogue
|
|
254
|
+
if (use_out_source) {
|
|
255
|
+
if (do_axpby) {
|
|
256
|
+
mma_op.apply_epilogue_safe(
|
|
257
|
+
C,
|
|
258
|
+
addmm_params->ldc,
|
|
259
|
+
addmm_params->fdc,
|
|
260
|
+
short2(tgp_bn, tgp_bm),
|
|
261
|
+
epilogue_op_axpby);
|
|
262
|
+
} else {
|
|
263
|
+
mma_op.apply_epilogue_safe(
|
|
264
|
+
C,
|
|
265
|
+
addmm_params->ldc,
|
|
266
|
+
addmm_params->fdc,
|
|
267
|
+
short2(tgp_bn, tgp_bm),
|
|
268
|
+
epilogue_op_add);
|
|
269
|
+
}
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
// Store results to device memory
|
|
273
|
+
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
274
|
+
|
|
275
|
+
// Full in M only (partial in N): bounds-checked epilogue/store.
} else if (align_M || tgp_bm == BM) {
|
|
276
|
+
gemm_kernel::gemm_loop(
|
|
277
|
+
As,
|
|
278
|
+
Bs,
|
|
279
|
+
gemm_k_iterations,
|
|
280
|
+
loader_a,
|
|
281
|
+
loader_b,
|
|
282
|
+
mma_op,
|
|
283
|
+
tgp_bm,
|
|
284
|
+
tgp_bn,
|
|
285
|
+
leftover_bk,
|
|
286
|
+
LoopAlignment<true, false, true>{});
|
|
287
|
+
|
|
288
|
+
// Do epilogue
|
|
289
|
+
if (use_out_source) {
|
|
290
|
+
if (do_axpby) {
|
|
291
|
+
mma_op.apply_epilogue_safe(
|
|
292
|
+
C,
|
|
293
|
+
addmm_params->ldc,
|
|
294
|
+
addmm_params->fdc,
|
|
295
|
+
short2(tgp_bn, tgp_bm),
|
|
296
|
+
epilogue_op_axpby);
|
|
297
|
+
} else {
|
|
298
|
+
mma_op.apply_epilogue_safe(
|
|
299
|
+
C,
|
|
300
|
+
addmm_params->ldc,
|
|
301
|
+
addmm_params->fdc,
|
|
302
|
+
short2(tgp_bn, tgp_bm),
|
|
303
|
+
epilogue_op_add);
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
// Store results to device memory
|
|
308
|
+
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
309
|
+
|
|
310
|
+
// Partial in both M and N.
} else {
|
|
311
|
+
gemm_kernel::gemm_loop(
|
|
312
|
+
As,
|
|
313
|
+
Bs,
|
|
314
|
+
gemm_k_iterations,
|
|
315
|
+
loader_a,
|
|
316
|
+
loader_b,
|
|
317
|
+
mma_op,
|
|
318
|
+
tgp_bm,
|
|
319
|
+
tgp_bn,
|
|
320
|
+
leftover_bk,
|
|
321
|
+
LoopAlignment<false, false, true>{});
|
|
322
|
+
|
|
323
|
+
// Do epilogue
|
|
324
|
+
if (use_out_source) {
|
|
325
|
+
if (do_axpby) {
|
|
326
|
+
mma_op.apply_epilogue_safe(
|
|
327
|
+
C,
|
|
328
|
+
addmm_params->ldc,
|
|
329
|
+
addmm_params->fdc,
|
|
330
|
+
short2(tgp_bn, tgp_bm),
|
|
331
|
+
epilogue_op_axpby);
|
|
332
|
+
} else {
|
|
333
|
+
mma_op.apply_epilogue_safe(
|
|
334
|
+
C,
|
|
335
|
+
addmm_params->ldc,
|
|
336
|
+
addmm_params->fdc,
|
|
337
|
+
short2(tgp_bn, tgp_bm),
|
|
338
|
+
epilogue_op_add);
|
|
339
|
+
}
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
// Store results to device memory
|
|
343
|
+
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
}
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
using namespace mlx::steel;
|
|
4
|
+
|
|
5
|
+
// Pipeline-specialization switches (Metal function constants) read by
// gemm_epilogue and the gemm kernel below.
// has_batch: when true the kernel derives A/B/C batch offsets from
// batch_shape/batch_strides; otherwise it uses the single batch strides
// stored in the params structs.
constant bool has_batch [[function_constant(10)]];
|
|
6
|
+
|
|
7
|
+
// use_out_source: gates the C and addmm_params kernel arguments and the call
// to gemm_epilogue.
constant bool use_out_source [[function_constant(100)]];
|
|
8
|
+
// do_axpby: epilogue computes alpha * D + beta * C instead of D + C.
constant bool do_axpby [[function_constant(110)]];
|
|
9
|
+
|
|
10
|
+
// align_M / align_N / align_K: when true the kernel treats the corresponding
// dimension as aligned and takes the unchecked load/store paths.
constant bool align_M [[function_constant(200)]];
|
|
11
|
+
constant bool align_N [[function_constant(201)]];
|
|
12
|
+
constant bool align_K [[function_constant(202)]];
|
|
13
|
+
|
|
14
|
+
// clang-format off
|
|
15
|
+
// Folds the matching region of C into the accumulator tile Dtile, one
// UM x UN sub-tile at a time:
//   do_axpby : D = alpha * D + beta * C
//   otherwise: D = D + C
//   kAlignedM / kAlignedN  when both are true C is read with the unchecked
//                          load; otherwise load_safe is given sgp_sm / sgp_sn
//   Dtile                  accumulator tile, updated through elems()
//   C                      source matrix, addressed with addmm_params->ldc/fdc
//   params                 unused
//   sgp_sm, sgp_sn         extents forwarded to load_safe
template <
|
|
16
|
+
bool kAlignedM,
|
|
17
|
+
bool kAlignedN,
|
|
18
|
+
typename NAXTile_t,
|
|
19
|
+
typename T>
|
|
20
|
+
void gemm_epilogue(
|
|
21
|
+
thread NAXTile_t& Dtile,
|
|
22
|
+
const device T* C,
|
|
23
|
+
const constant GEMMParams* params,
|
|
24
|
+
const constant GEMMAddMMParams* addmm_params,
|
|
25
|
+
const short sgp_sm,
|
|
26
|
+
const short sgp_sn) { // clang-format on
|
|
27
|
+
|
|
28
|
+
(void)params;
|
|
29
|
+
|
|
30
|
+
constexpr short UM = NAXTile_t::kSubTileRows;
|
|
31
|
+
constexpr short UN = NAXTile_t::kSubTileCols;
|
|
32
|
+
// C is read in its storage type T and converted to the accumulator element
// type V per element below.
using CSubTile = NAXSubTile<T, UM, UN>;
|
|
33
|
+
|
|
34
|
+
using V = typename NAXTile_t::elem_type;
|
|
35
|
+
|
|
36
|
+
constexpr short TM = NAXTile_t::kTileRows;
|
|
37
|
+
constexpr short TN = NAXTile_t::kTileCols;
|
|
38
|
+
constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile;
|
|
39
|
+
|
|
40
|
+
STEEL_PRAGMA_UNROLL
|
|
41
|
+
for (short mm = 0; mm < TM; mm++) {
|
|
42
|
+
STEEL_PRAGMA_UNROLL
|
|
43
|
+
for (short nn = 0; nn < TN; nn++) {
|
|
44
|
+
// Row/column offset of sub-tile (mm, nn), passed to the C load.
const short m = mm * UM;
|
|
45
|
+
const short n = nn * UN;
|
|
46
|
+
|
|
47
|
+
CSubTile CTile;
|
|
48
|
+
|
|
49
|
+
if constexpr (kAlignedM && kAlignedN) {
|
|
50
|
+
CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n);
|
|
51
|
+
} else {
|
|
52
|
+
CTile.load_safe(
|
|
53
|
+
C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
// NOTE(review): assumes elems() returns a pointer-like handle, so the
// writes through delems land in Dtile rather than in a copy - confirm
// against NAXSubTile.
auto delems = Dtile.subtile_at(mm, nn).elems();
|
|
57
|
+
auto celems = CTile.elems();
|
|
58
|
+
|
|
59
|
+
STEEL_PRAGMA_UNROLL
|
|
60
|
+
for (short i = 0; i < kElemsPerSubTile; i++) {
|
|
61
|
+
if (do_axpby) {
|
|
62
|
+
delems[i] = addmm_params->alpha * delems[i] +
|
|
63
|
+
addmm_params->beta * static_cast<V>(celems[i]);
|
|
64
|
+
} else {
|
|
65
|
+
delems[i] += static_cast<V>(celems[i]);
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
// clang-format off
|
|
73
|
+
// NAX tiled GEMM kernel. Each threadgroup owns the BM x BN tile of D selected
// by the (swizzled) threadgroup position, and each simdgroup computes an
// SM x SN sub-block of it (SM = BM / WM, SN = BN / WN) located at (tm, tn).
//   A, B          input matrices (transpose_a / transpose_b select the layout
//                 used when offsetting into them)
//   C             optional epilogue source, bound only when use_out_source
//   D             output matrix
//   params        sizes, leading dimensions, tile counts and batch strides
//   addmm_params  ldc/fdc, batch stride and alpha/beta for the epilogue,
//                 bound only when use_out_source
//   batch_shape / batch_strides
//                 bound only when has_batch; the strides for A, B and C are
//                 stored back-to-back, params->batch_ndim entries each
template <
|
|
74
|
+
typename T,
|
|
75
|
+
int BM,
|
|
76
|
+
int BN,
|
|
77
|
+
int BK,
|
|
78
|
+
int WM,
|
|
79
|
+
int WN,
|
|
80
|
+
bool transpose_a,
|
|
81
|
+
bool transpose_b,
|
|
82
|
+
typename AccumType = float>
|
|
83
|
+
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
|
84
|
+
const device T* A [[buffer(0)]],
|
|
85
|
+
const device T* B [[buffer(1)]],
|
|
86
|
+
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
|
87
|
+
device T* D [[buffer(3)]],
|
|
88
|
+
const constant GEMMParams* params [[buffer(4)]],
|
|
89
|
+
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
|
90
|
+
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
|
|
91
|
+
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
|
|
92
|
+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
93
|
+
uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
|
|
94
|
+
// Find block
// The low swizzle_log bits of tid.x are folded into the row-tile index and
// the remaining bits of tid.x become the column-tile index.
const int tid_y = ((tid.y) << params->swizzle_log) +
|
|
96
|
+
((tid.x) & ((1 << params->swizzle_log) - 1));
|
|
97
|
+
const int tid_x = (tid.x) >> params->swizzle_log;
|
|
98
|
+
|
|
99
|
+
// Exit early if out of bounds
|
|
100
|
+
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
101
|
+
return;
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
// Adjust for batch
|
|
105
|
+
if (has_batch) {
|
|
106
|
+
const constant auto* A_bstrides = batch_strides;
|
|
107
|
+
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
|
|
108
|
+
|
|
109
|
+
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
110
|
+
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
|
111
|
+
|
|
112
|
+
A += batch_offsets.x;
|
|
113
|
+
B += batch_offsets.y;
|
|
114
|
+
|
|
115
|
+
if (use_out_source) {
|
|
116
|
+
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
|
|
117
|
+
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
|
118
|
+
}
|
|
119
|
+
} else {
|
|
120
|
+
A += params->batch_stride_a * tid.z;
|
|
121
|
+
B += params->batch_stride_b * tid.z;
|
|
122
|
+
|
|
123
|
+
if (use_out_source) {
|
|
124
|
+
C += addmm_params->batch_stride_c * tid.z;
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
// D always uses the single batch stride, in both batch modes.
D += params->batch_stride_d * tid.z;
|
|
129
|
+
|
|
130
|
+
// Prepare threadgroup memory
|
|
131
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
132
|
+
|
|
133
|
+
// Find block in A, B, C
|
|
134
|
+
const int c_row = tid_y * BM;
|
|
135
|
+
const int c_col = tid_x * BN;
|
|
136
|
+
// Widen to size_t before multiplying by the leading dimensions below.
const size_t c_row_long = size_t(c_row);
|
|
137
|
+
const size_t c_col_long = size_t(c_col);
|
|
138
|
+
|
|
139
|
+
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
|
140
|
+
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
|
141
|
+
D += c_row_long * params->ldd + c_col_long;
|
|
142
|
+
|
|
143
|
+
if (use_out_source) {
|
|
144
|
+
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// Sub-tile (UM x UN x UK) and per-simdgroup (SM x SN x SK) block sizes.
constexpr short UM = 16;
|
|
148
|
+
constexpr short UN = 32;
|
|
149
|
+
constexpr short UK = 16;
|
|
150
|
+
constexpr short SM = BM / WM;
|
|
151
|
+
constexpr short SN = BN / WN;
|
|
152
|
+
constexpr short SK = 32;
|
|
153
|
+
|
|
154
|
+
constexpr short TM = SM / UM;
|
|
155
|
+
constexpr short TN = SN / UN;
|
|
156
|
+
|
|
157
|
+
// Offset of this simdgroup's block inside the threadgroup tile.
const short tm = SM * (simd_group_id / WN);
|
|
158
|
+
const short tn = SN * (simd_group_id % WN);
|
|
159
|
+
|
|
160
|
+
// Clamp the remaining row count in int BEFORE narrowing to short. Narrowing
// first wraps for remainders above 32767 (large unaligned M) and yields a
// bogus, possibly negative, extent for the bounds-checked loads and stores.
const short sgp_sm = align_M ? SM : short(min(int(SM), params->M - (c_row + tm)));
|
|
161
|
+
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
|
|
162
|
+
|
|
163
|
+
// Same clamp-then-narrow order for the remaining column count.
const short sgp_sn = align_N ? SN : short(min(int(SN), params->N - (c_col + tn)));
|
|
164
|
+
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
|
|
165
|
+
|
|
166
|
+
A += transpose_a ? tm : (tm * params->lda);
|
|
167
|
+
B += transpose_b ? (tn * params->ldb) : tn;
|
|
168
|
+
D += tm * params->ldd + tn;
|
|
169
|
+
|
|
170
|
+
if (use_out_source) {
|
|
171
|
+
C += tm * addmm_params->ldc + tn * addmm_params->fdc;
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
using DSubTile = NAXSubTile<AccumType, UM, UN>;
|
|
175
|
+
NAXTile<AccumType, TM, TN, DSubTile> Dtile;
|
|
176
|
+
|
|
177
|
+
// Turn the runtime alignment flags into compile-time booleans so gemm_loop,
// gemm_epilogue and the store are instantiated per alignment combination.
dispatch_bool(align_K, [&](auto kAlignedK) {
|
|
178
|
+
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
|
|
179
|
+
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
|
|
180
|
+
Dtile = gemm_loop<
|
|
181
|
+
T,
|
|
182
|
+
SM,
|
|
183
|
+
SN,
|
|
184
|
+
SK,
|
|
185
|
+
BK,
|
|
186
|
+
transpose_a,
|
|
187
|
+
transpose_b,
|
|
188
|
+
kAlignedM.value,
|
|
189
|
+
kAlignedN.value,
|
|
190
|
+
kAlignedK.value,
|
|
191
|
+
UM,
|
|
192
|
+
UN,
|
|
193
|
+
UK,
|
|
194
|
+
AccumType>(A, B, params, sgp_sm, sgp_sn);
|
|
195
|
+
if (use_out_source) {
|
|
196
|
+
gemm_epilogue<kAlignedM.value, kAlignedN.value>(
|
|
197
|
+
Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
|
|
198
|
+
}
|
|
199
|
+
if constexpr (kAlignedM && kAlignedN) {
|
|
200
|
+
Dtile.store(D, int(params->ldd));
|
|
201
|
+
} else {
|
|
202
|
+
Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
|
|
203
|
+
}
|
|
204
|
+
});
|
|
205
|
+
});
|
|
206
|
+
});
|
|
207
|
+
}
|