mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,476 @@
|
|
|
1
|
+
// Copyright © 2024-25 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
|
4
|
+
|
|
5
|
+
using namespace mlx::steel;
|
|
6
|
+
|
|
7
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
8
|
+
// Attention kernels
|
|
9
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
10
|
+
|
|
11
|
+
// Kernel specialization flags, supplied as Metal function constants when the
// pipeline is created (the host-side code that sets them is not visible here).

// align_Q: when false, the threadgroup with tid.x == params->NQ_aligned holds a
// partial query block; it loads Q and stores O with bounds checks (qL_rem rows).
constant bool align_Q [[function_constant(200)]];
// align_K: when false, the K/V block kb == params->NK_aligned is partial; it is
// loaded with bounds checks (kL_rem rows) and its padding columns are masked.
constant bool align_K [[function_constant(201)]];

// has_mask: mask_params (buffer 5) and mask (buffer 6) are bound and applied.
constant bool has_mask [[function_constant(300)]];
// do_causal: apply causal masking and skip K/V blocks past the diagonal.
constant bool do_causal [[function_constant(301)]];
// has_sinks: sinks (buffer 7) is bound; it is indexed by query head (tid.y).
constant bool has_sinks [[function_constant(302)]];
|
|
17
|
+
|
|
18
|
+
// Elementwise transform that multiplies each input by a fixed factor.
// The attention kernel applies it in place to the loaded Q tile, with the
// factor set to params->scale * log2(e).
template <typename T>
struct TransformScale {
  // Multiplier applied by apply().
  T scale;

  METAL_FUNC TransformScale(T s) : scale{s} {}

  // Returns scale * x.
  METAL_FUNC T apply(T x) const { return scale * x; }
};
|
|
27
|
+
|
|
28
|
+
// Binary functor returning the larger operand (via metal::max).
// Used as the reducer for the row-wise max of the score tile.
struct MaxOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return metal::max(lhs, rhs);
  }
};
|
|
34
|
+
|
|
35
|
+
// Binary functor returning lhs + rhs.
// Used as the reducer for the row-wise sum of the exponentiated scores.
struct SumOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs + rhs;
  }
};
|
|
41
|
+
|
|
42
|
+
// Binary functor returning lhs * rhs.
// Used to rescale each row of the output accumulator by its softmax factor.
struct MulOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs * rhs;
  }
};
|
|
48
|
+
|
|
49
|
+
// Binary functor returning lhs - rhs.
// (Not referenced by the kernel visible in this file.)
struct SubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs - rhs;
  }
};
|
|
55
|
+
|
|
56
|
+
// Binary functor returning 2^(lhs - rhs), computed with fast::exp2.
// Despite the name, the base is 2. The attention kernel pre-multiplies its
// scores by log2(e), so this equals the natural exp of the unscaled difference.
struct ExpSubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    const auto diff = lhs - rhs;
    return fast::exp2(diff);
  }
};
|
|
62
|
+
|
|
63
|
+
// Binary functor returning lhs / rhs.
// Used to normalize each output row by its accumulated softmax sum.
struct DivOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs / rhs;
  }
};
|
|
69
|
+
|
|
70
|
+
// Fused attention forward kernel: O = softmax(scale * Q @ K^T, masked) @ V.
// It uses an online (streaming) softmax over blocks of the key/value sequence,
// so the full score matrix is never materialized.
//
// Grid mapping, from the pointer offsets below:
//   - tid.x selects a block of BQ query rows;
//   - tid.y selects the query head;
//   - tid.z selects the batch entry.
// Each threadgroup has WM * WN simdgroups of 32 threads. Each simdgroup owns
// kFragSize (8) consecutive query rows of the block (see the TQ == 1 assert).
//
// Scores are kept in base 2: Q is pre-multiplied by scale * log2(e), and the
// softmax uses exp2. This is equivalent to the natural-exp softmax.
//
// Template parameters:
//   T          element type of Q, K, V, O and sinks.
//   BQ         query rows handled per threadgroup.
//   BK         key/value rows processed per loop iteration.
//   BD         head dimension covered by the tiles. There is no loop over D
//              blocks, so this presumably equals the full head dimension
//              (TODO confirm against the host dispatch code).
//   WM, WN     simdgroup grid; only the product WM * WN is used here.
//   MaskType   element type of `mask`. bool means keep/drop; any other type
//              is added to the scores as a bias.
//   AccumType  type of the scores, the softmax statistics and the O
//              accumulator.
//
// Buffers:
//   Q, K, V, O   attention inputs/output. Addressed through params->*_strides,
//                where index 0 is batch, 1 is head and 2 is sequence.
//   params       shapes, strides and scale (AttnParams).
//   mask_params  strides of `mask`; bound only when has_mask.
//   mask         optional mask, bounds-checked against (params->qL, params->kL);
//                bound only when has_mask.
//   sinks        optional sink logits indexed by query head; bound only when
//                has_sinks.
// clang-format off
template <
    typename T,
    int BQ,
    int BK,
    int BD,
    int WM,
    int WN,
    typename MaskType = float,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on

  // Pacifying compiler: lid is not otherwise used.
  (void)lid;

  // Move to correct block: offset each base pointer to this threadgroup's
  // (batch, head, query block). tid is widened to 64 bits before it is
  // multiplied by the strides.
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // Grouped-query attention: gqa_factor consecutive query heads share one
  // K/V head. K and V are only offset by batch/head here; the loaders walk
  // the key sequence via next() at the end of each loop iteration.
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  // The mask is offset by the query head (tidl.y), not by kv_head_idx.
  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }

  // Prepare threadgroup memory.
  // Each row is padded by 16 bytes (16 / sizeof(T) elements), presumably to
  // reduce threadgroup-memory bank conflicts.
  constexpr short padQ = 16 / sizeof(T);
  constexpr short padK = 16 / sizeof(T);
  constexpr short padV = 16 / sizeof(T);

  // Row strides (in elements) of the tiles in threadgroup memory. Q and V are
  // stored row-major; K is stored transposed (BD rows of BK + padK).
  constexpr short LDQ_tgp = BD + padQ;
  constexpr short LDK_tgp = BK + padK;
  constexpr short LDV_tgp = BD + padV;

  // The transposed K block and the V block share one buffer, sized for the
  // larger of the two layouts.
  constexpr short tgp_mem_0 = (BK + padK) * (BD);
  constexpr short tgp_mem_1 = BK * (BD + padV);
  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;

  threadgroup T Q_smem[BQ * (BD + padQ)];
  threadgroup T KV_smem[tgp_mem_s];

  threadgroup T* Qs = Q_smem;
  // Ks and Vs alias the same storage. The threadgroup barriers in the KV loop
  // order all K reads before the V writes, and all V reads before the next
  // K writes.
  threadgroup T* Ks = KV_smem;
  threadgroup T* Vs = KV_smem;

  // Prepare block loaders
  using QBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BQ,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ LDQ_tgp,
      /* short kDstStrCol = */ 1,
      /* short reduction_dim = */ 1,
      /* short tgp_size = */ WM * WN * 32>;

  // K is loaded transposed: with kDstStrRow = 1 and kDstStrCol = LDK_tgp, the
  // element at (row r, col c) is written to r + c * LDK_tgp.
  using KBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BK,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ 1,
      /* short kDstStrCol = */ LDK_tgp,
      /* short reduction_dim = */ 0,
      /* short tgp_size = */ WM * WN * 32>;

  using VBlockLoader = BlockLoaderT<
      /* typename T = */ T,
      /* short BROWS = */ BK,
      /* short BCOLS = */ BD,
      /* short kDstStrRow = */ LDV_tgp,
      /* short kDstStrCol = */ 1,
      /* short reduction_dim = */ 0,
      /* short tgp_size = */ WM * WN * 32>;

  // The source row stride of each loader is the tensor's sequence stride.
  QBlockLoader loader_q(
      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
  KBlockLoader loader_k(
      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
  VBlockLoader loader_v(
      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);

  // Fold both the softmax scale and the log2(e) factor needed by exp2 into Q.
  TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));

  // Prepare MMA tiles
  constexpr short kFragSize = 8; // MMAFrag size
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
      "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");

  // Q seq frags per warp
  constexpr int TQ = BQ / (kNWarps * kFragSize);
  // KV sequence frags (all warps load the same frags)
  constexpr int TK = BK / kFragSize;
  // HeadDim frags (all warps load the same frags)
  constexpr int TD = BD / kFragSize;

  static_assert(TQ == 1, "Check TQ");

  // Per-simdgroup register tiles, sized in 8x8 fragments:
  //   Qtile  one Q fragment, reloaded for each head-dim slice;
  //   Ktile  one row of TK fragments of K^T;
  //   Stile  scores, later probabilities, for this simdgroup's rows (TQ x TK);
  //   Vtile  one V fragment, reloaded inside the P @ V loop;
  //   Otile  running output accumulator (TQ x TD).
  MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
  MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
  MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
  MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
  MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;

  Otile.clear();

  // Prepare mma tile offsets.
  // (sm, sn) is this lane's (row, col) inside an 8x8 fragment. tm is the first
  // query row owned by this simdgroup within the BQ block.
  const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = kFragSize * TQ * simd_group_id;

  const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
  const short Ks_offset = sm * LDK_tgp + sn;
  const short Vs_offset = sm * LDV_tgp + sn;

  // Step between successive head-dim fragments: along the columns of Qs, and
  // along the rows of the transposed Ks.
  constexpr short Qs_tile_stride = kFragSize;
  constexpr short Ks_tile_stride = kFragSize * LDK_tgp;

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Load the Q block and apply the scale. When !align_Q, the last query block
  // is partial and only qL_rem rows are loaded.
  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    loader_q.load_safe(short2(BD, params->qL_rem));
  } else {
    loader_q.load_unsafe();
  }
  loader_q.apply_inplace_op(ts);

  // Init row reduction variables: the online-softmax state for the rows this
  // thread holds, i.e. the running max and the running sum of
  // exp2(score - max).
  constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;

  AccumType max_score[kRowsPT];
  AccumType sum_score[kRowsPT] = {0};

  // Init to "-Inf". In practice this is the lowest finite value, presumably
  // so that max_score - new_max never evaluates to -inf - (-inf) = NaN.
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits<AccumType>::finite_min;
  }

  // Attention sinks: seed the running max with the sink logit, converted to
  // base 2, and the running sum with exp2(sink - max) = 1. The sink therefore
  // only adds to the softmax denominator and contributes nothing to O.
  if (has_sinks) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
      sum_score[i] = 1;
    }
  }

  int kb_lim = params->NK;

  // Causal: stop after the last K block that can contain a key at or before
  // this block's last query row. Query positions are shifted by qL_off
  // (presumably to align the diagonal with the end of the key sequence;
  // TODO confirm).
  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);
  }

  // Loop over KV seq length
  for (int kb = 0; kb < kb_lim; kb++) {
    // Load the K block. No scaling happens here; the scale was folded into Q.
    // The barrier keeps the previous iteration's reads of Vs (which aliases
    // Ks) from racing with this write.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!align_K && kb == (params->NK_aligned)) {
      loader_k.load_safe(short2(BD, params->kL_rem));
    } else {
      loader_k.load_unsafe();
    }

    // Do S = Q @ K.T
    Stile.clear();

    // Wait until the K block is fully written to threadgroup memory.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Accumulate over the head dimension, one 8-wide fragment at a time.
    STEEL_PRAGMA_UNROLL
    for (short dd = 0; dd < TD; dd++) {
      simdgroup_barrier(mem_flags::mem_none);

      Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
          &Qs[Qs_offset + dd * Qs_tile_stride]);
      Ktile.template load<T, 1, 1, LDK_tgp, 1>(
          &Ks[Ks_offset + dd * Ks_tile_stride]);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Stile, Qtile, Ktile, Stile);
    }

    // Mask out length sequence. In the partial last K block, columns at or
    // beyond kL_rem are padding, so their scores are forced to the lowest
    // finite value.
    if (!align_K && kb == (params->NK_aligned)) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = Limits<selem_t>::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          short col_pos = sn + (j * stile_t::kFragCols);
          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if ((col_pos + jj) >= params->kL_rem) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }

    // Mask out if causal. Only the last ceil(BQ / BK) blocks before kb_lim can
    // intersect the diagonal (plus one more block when K is unaligned), so
    // earlier blocks skip this check. An entry is masked when its key
    // position is greater than its (qL_off-shifted) query position.
    if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = Limits<selem_t>::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        const int row_pos =
            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
            if (row_pos < (col_pos + jj)) {
              Stile.frag_at(i, j)[jj] = neg_inf;
            }
          }
        }
      }
    }

    // Other masking as needed.
    //   - bool mask: keep the score where the mask is true; otherwise set it
    //     to the lowest finite value.
    //   - Other mask types: add the mask value as a bias, scaled by log2(e) to
    //     match the base-2 scores.
    // Rows are indexed without qL_off here, i.e. relative to the start of the
    // query sequence.
    if (has_mask) {
      using stile_t = decltype(Stile);
      using selem_t = typename stile_t::elem_type;
      constexpr auto neg_inf = Limits<selem_t>::finite_min;

      constexpr bool is_bool = is_same_v<MaskType, bool>;
      using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;

      using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
      using frag_t = typename MMAFrag_mask_t::frag_type;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < stile_t::kTileRows; i++) {
        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < stile_t::kTileCols; j++) {
          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);

          frag_t mfrag;

          // Bounds-checked fragment load against (qL, kL). What out-of-range
          // elements contain is defined by BaseMMAFrag::load_safe.
          MMAFrag_mask_t::load_safe(
              mfrag,
              mask,
              int64_t(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              Stile.frag_at(i, j)[jj] =
                  mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
            } else {
              Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
            }
          }
        }
      }
    }

    // Every simdgroup must finish reading Ks before the V load overwrites
    // the aliased storage.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load V blocks
    if (!align_K && kb == (params->NK_aligned)) {
      loader_v.load_safe(short2(BD, params->kL_rem));
    } else {
      loader_v.load_unsafe();
    }

    // Do softmax: online update of the running max and sum. This works only on
    // registers, so it does not depend on the V load above.

    // Temp variables
    AccumType new_max[kRowsPT];
    AccumType factor[kRowsPT];
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max, including the previous running max.
    Stile.template row_reduce<MaxOp>(new_max);

    // S <- exp2(S - new_max). Because the scores are in base 2, these are the
    // unnormalized softmax probabilities P.
    Stile.template row_bin_op<ExpSubOp>(new_max);

    // Rescale factor exp2(old_max - new_max), i.e.
    // exp(rowmax(S_{i-1}) - rowmax(S_i)) in natural units. It is applied to the
    // previous sum and to the previous output.
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
    }

    // Save max for next iteration
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = new_max[i];
    }

    // Row sum of this block's probabilities.
    AccumType sum_score_tmp[kRowsPT] = {0};
    Stile.template row_reduce<SumOp>(sum_score_tmp);

    // Update the norm: rescale the old sum and add this block's sum.
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
    }

    // Update O: rescale the accumulated output by the same factor.
    Otile.template row_bin_op<MulOp>(factor);

    // Wait for the V block to be fully written, then accumulate O += P @ V one
    // fragment at a time. Vtile is reloaded from Vs into registers at each
    // step.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          // For BD == 128 the V fragment load is bracketed by simdgroup
          // barriers (presumably a scheduling hint; TODO confirm).
          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          const short kk = ik * kFragSize;
          const short dd = id * kFragSize;

          Vtile.template load<T, 1, 1, LDV_tgp, 1>(
              &Vs[Vs_offset + kk * LDV_tgp + dd]);

          if constexpr (BD == 128) {
            simdgroup_barrier(mem_flags::mem_none);
          }

          MMAFrag_acc_t::mma(
              Otile.frag_at(iq, id),
              Stile.frag_at(iq, ik),
              Vtile.frag_at(0, 0),
              Otile.frag_at(iq, id));
        }
      }
    }

    // Prepare for next iteration: advance the K and V loaders by one block.
    loader_k.next();
    loader_v.next();
  }

  // Normalize output by the accumulated softmax denominator.
  Otile.template row_bin_op<DivOp>(sum_score);
  threadgroup_barrier(mem_flags::mem_none);

  // Store results. First move O to this lane's first element
  // (row tm + sm, column sn).
  O += (tm + sm) * params->O_strides[2] + sn;

  // In a partial last query block, store only rows < qL_rem. A lane whose
  // starting row is at or beyond qL_rem has nothing to write and returns;
  // no barriers follow, so the early exit is safe.
  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
  } else {
    Otile.template store<T, 1, 1>(O, params->O_strides[2]);
  }
}
|